From 085d69ad0d9292bc853844202722cc7e889d4b26 Mon Sep 17 00:00:00 2001 From: Tomas Dvorak Date: Mon, 9 Sep 2024 18:33:01 +0200 Subject: [PATCH] feat(observability): propagate current instance to async scope --- src/agents/base.ts | 7 ++++--- src/context.ts | 35 ++++++++++++++++++----------------- src/llms/base.ts | 8 ++++---- src/tools/base.ts | 9 ++++----- 4 files changed, 30 insertions(+), 29 deletions(-) diff --git a/src/agents/base.ts b/src/agents/base.ts index 1eb0417d..f2061c7a 100644 --- a/src/agents/base.ts +++ b/src/agents/base.ts @@ -17,7 +17,7 @@ import { FrameworkError } from "@/errors.js"; import { AgentMeta } from "@/agents/types.js"; import { Serializable } from "@/internals/serializable.js"; -import { GetRunContext, GetRunInstance, Run, RunContext } from "@/context.js"; +import { GetRunContext, RunContext } from "@/context.js"; import { Emitter } from "@/emitter/emitter.js"; import { BaseMemory } from "@/memory/base.js"; @@ -40,13 +40,14 @@ export abstract class BaseAgent< ...[input, options]: Partial extends TOptions ? [input: TInput, options?: TOptions] : [input: TInput, options: TOptions] - ): Run, [input: TInput, options?: TOptions]> { + ) { if (this.isRunning) { throw new AgentError("Agent is already running!"); } return RunContext.enter( - { self: this, signal: options?.signal, params: [input, options] as const }, + this, + { signal: options?.signal, params: [input, options] as const }, async (context) => { try { // @ts-expect-error diff --git a/src/context.ts b/src/context.ts index e376b1e5..5c06305c 100644 --- a/src/context.ts +++ b/src/context.ts @@ -36,21 +36,21 @@ export interface RunContextCallbacks { finish: Callback; } -export type GetRunContext = T extends RunInstance ? RunContext

: never; +export type GetRunContext = T extends RunInstance ? RunContext : never; export type GetRunInstance = T extends RunInstance ? P : never; -export class Run extends LazyPromise { +export class Run extends LazyPromise { constructor( - handler: () => Promise, - protected readonly runContext: RunContext, + handler: () => Promise, + protected readonly runContext: GetRunContext, ) { super(handler); } readonly [Symbol.toStringTag] = "Promise"; - observe(fn: (emitter: Emitter) => void) { - fn(this.runContext.emitter as any); + observe(fn: (emitter: Emitter>) => void) { + fn(this.runContext.emitter); return this; } @@ -60,19 +60,18 @@ export class Run extends LazyPromise { return this; } - middleware(fn: (context: RunContext) => void) { + middleware(fn: (context: GetRunContext) => void) { fn(this.runContext); return this; } } -export interface RunContextInput { - self: RunInstance; +export interface RunContextInput

{ params: P; signal?: AbortSignal; } -export class RunContext extends Serializable { +export class RunContext extends Serializable { static #storage = new AsyncLocalStorage>(); protected readonly controller: AbortController; @@ -92,7 +91,8 @@ export class RunContext extends Serializable { } constructor( - protected readonly input: RunContextInput, + public readonly instance: T, + protected readonly input: RunContextInput

, parent?: RunContext, ) { super(); @@ -107,7 +107,7 @@ export class RunContext extends Serializable { this.controller = new AbortController(); registerSignals(this.controller, [input.signal, parent?.signal]); - this.emitter = input.self.emitter.child({ + this.emitter = instance.emitter.child>({ context: this.context, trace: { id: this.groupId, @@ -125,14 +125,15 @@ export class RunContext extends Serializable { this.controller.abort(new FrameworkError("Context destroyed.")); } - static enter( - input: RunContextInput, - fn: (context: RunContext) => Promise, + static enter( + instance: C2, + input: RunContextInput, + fn: (context: GetRunContext) => Promise, ) { const parent = RunContext.#storage.getStore(); - const runContext = new RunContext(input, parent); + const runContext = new RunContext(instance, input, parent) as GetRunContext; - return new Run(async () => { + return new Run(async () => { const emitter = runContext.emitter.child({ namespace: ["run"], creator: runContext, diff --git a/src/llms/base.ts b/src/llms/base.ts index 94ad3f5d..7b82154b 100644 --- a/src/llms/base.ts +++ b/src/llms/base.ts @@ -131,7 +131,8 @@ export abstract class BaseLLM< generate(input: TInput, options?: TGenerateOptions) { return RunContext.enter( - { self: this, params: [input, options] as const, signal: options?.signal }, + this, + { params: [input, options] as const, signal: options?.signal }, async (run) => { try { await run.emitter.emit("start", { input, options }); @@ -147,7 +148,6 @@ export abstract class BaseLLM< ...options, signal: controller.signal, }, - // @ts-expect-error wrong types run, )) { chunks.push(chunk); @@ -190,9 +190,9 @@ export abstract class BaseLLM< async *stream(input: TInput, options?: StreamGenerateOptions): AsyncStream { return yield* emitterToGenerator(async ({ emit }) => { return RunContext.enter( - { self: this, params: [input, options] as const, signal: options?.signal }, + this, + { params: [input, options] as const, signal: options?.signal }, async (run) => { - // @ts-expect-error wrong types for await (const token of this._stream(input, options ?? {}, run)) { emit(token); } diff --git a/src/tools/base.ts b/src/tools/base.ts index e41b4dd2..1897ad12 100644 --- a/src/tools/base.ts +++ b/src/tools/base.ts @@ -204,7 +204,8 @@ export abstract class Tool< run(input: ToolInputRaw, options?: TRunOptions): Promise { return RunContext.enter( - { self: this, signal: options?.signal, params: [input, options] as const }, + this, + { signal: options?.signal, params: [input, options] as const }, async (run) => { const meta = { input, options }; let errorPropagated = false; @@ -217,10 +218,8 @@ export abstract class Tool< errorPropagated = false; await run.emitter.emit("start", { ...meta }); return this.cache.enabled - ? // @ts-expect-error wrong types - await this._runCached(input, options, run) - : // @ts-expect-error wrong types - await this._run(input, options, run); + ? await this._runCached(input, options, run) + : await this._run(input, options, run); }, onError: async (error) => { errorPropagated = true;