diff --git a/examples/agents/bee_advanced.ts b/examples/agents/bee_advanced.ts index 562b090f..f57c1ef6 100644 --- a/examples/agents/bee_advanced.ts +++ b/examples/agents/bee_advanced.ts @@ -8,16 +8,6 @@ import { DuckDuckGoSearchToolSearchType, } from "bee-agent-framework/tools/search/duckDuckGoSearch"; import { OpenMeteoTool } from "bee-agent-framework/tools/weather/openMeteo"; -import { - BeeAssistantPrompt, - BeeSchemaErrorPrompt, - BeeSystemPrompt, - BeeToolErrorPrompt, - BeeToolInputErrorPrompt, - BeeToolNoResultsPrompt, - BeeUserEmptyPrompt, -} from "bee-agent-framework/agents/bee/prompts"; -import { PromptTemplate } from "bee-agent-framework/template"; import { BAMChatLLM } from "bee-agent-framework/adapters/bam/chat"; import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory"; import { z } from "zod"; @@ -32,40 +22,30 @@ const agent = new BeeAgent({ memory: new UnconstrainedMemory(), // You can override internal templates templates: { - user: new PromptTemplate({ - schema: z - .object({ - input: z.string(), - }) - .passthrough(), - template: `User: {{input}}`, - }), - system: BeeSystemPrompt.fork((old) => ({ - ...old, - defaults: { - instructions: "You are a helpful assistant that uses tools to answer questions.", - }, - })), - toolError: BeeToolErrorPrompt, - toolInputError: BeeToolInputErrorPrompt, - toolNoResultError: BeeToolNoResultsPrompt.fork((old) => ({ - ...old, - template: `${old.template}\nPlease reformat your input.`, - })), - toolNotFoundError: new PromptTemplate({ - schema: z - .object({ - tools: z.array(z.object({ name: z.string() }).passthrough()), - }) - .passthrough(), - template: `Tool does not exist! + user: (template) => + template.fork((config) => { + config.schema = z.object({ input: z.string() }).passthrough(); + config.template = `User: {{input}}`; + }), + system: (template) => + template.fork((config) => { + config.defaults.instructions = + "You are a helpful assistant that uses tools to answer questions."; + }), + toolNoResultError: (template) => + template.fork((config) => { + config.template += `\nPlease reformat your input.`; + }), + toolNotFoundError: (template) => + template.fork((config) => { + config.schema = z + .object({ tools: z.array(z.object({ name: z.string() }).passthrough()) }) + .passthrough(); + config.template = `Tool does not exist! {{#tools.length}} Use one of the following tools: {{#trim}}{{#tools}}{{name}},{{/tools}}{{/trim}} -{{/tools.length}}`, - }), - schemaError: BeeSchemaErrorPrompt, - assistant: BeeAssistantPrompt, - userEmpty: BeeUserEmptyPrompt, +{{/tools.length}}`; + }), }, tools: [ new DuckDuckGoSearchTool({ diff --git a/src/agents/bee/agent.ts b/src/agents/bee/agent.ts index 5ed76105..ae677267 100644 --- a/src/agents/bee/agent.ts +++ b/src/agents/bee/agent.ts @@ -37,12 +37,18 @@ import { GraniteRunner } from "@/agents/bee/runners/granite/runner.js"; import { DefaultRunner } from "@/agents/bee/runners/default/runner.js"; import { ValueError } from "@/errors.js"; +export type BeeTemplateFactory = ( + template: BeeAgentTemplates[K], +) => BeeAgentTemplates[K]; + export interface BeeInput { llm: ChatLLM; tools: AnyTool[]; memory: BaseMemory; meta?: Omit; - templates?: Partial; + templates?: Partial<{ + [K in keyof BeeAgentTemplates]: BeeAgentTemplates[K] | BeeTemplateFactory; + }>; execution?: BeeAgentExecutionConfig; } diff --git a/src/agents/bee/runners/base.ts b/src/agents/bee/runners/base.ts index dc4ec42b..a7a630c0 100644 --- a/src/agents/bee/runners/base.ts +++ b/src/agents/bee/runners/base.ts @@ -15,9 +15,9 @@ */ import { Serializable } from "@/internals/serializable.js"; +import type { BeeAgentTemplates } from "@/agents/bee/types.js"; import { BeeAgentRunIteration, - BeeAgentTemplates, BeeCallbacks, BeeIterationToolResult, BeeMeta, @@ -31,6 +31,9 @@ import { shallowCopy } from "@/serializer/utils.js"; import { BaseMemory } from "@/memory/base.js"; import { GetRunContext } from "@/context.js"; import { Emitter } from "@/emitter/emitter.js"; +import { Cache } from "@/cache/decoratorCache.js"; +import { getProp, mapObj } from "@/internals/helpers/object.js"; +import { PromptTemplate } from "@/template.js"; export interface BeeRunnerLLMInput { meta: BeeMeta; @@ -95,7 +98,18 @@ export abstract class BaseRunner extends Serializable { abstract tool(input: BeeRunnerToolInput): Promise<{ output: string; success: boolean }>; - abstract get templates(): BeeAgentTemplates; + public abstract get defaultTemplates(): BeeAgentTemplates; + + @Cache({ enumerable: false }) + public get templates(): BeeAgentTemplates { + return mapObj(this.defaultTemplates)((key, defaultTemplate) => { + const override = getProp(this.input.templates, [key], defaultTemplate); + if (override instanceof PromptTemplate) { + return override; + } + return override(defaultTemplate) ?? defaultTemplate; + }); + } protected abstract initMemory(input: BeeRunInput): Promise; diff --git a/src/agents/bee/runners/default/runner.ts b/src/agents/bee/runners/default/runner.ts index 4b786959..fb5a1dd6 100644 --- a/src/agents/bee/runners/default/runner.ts +++ b/src/agents/bee/runners/default/runner.ts @@ -15,12 +15,7 @@ */ import { BaseRunner, BeeRunnerLLMInput, BeeRunnerToolInput } from "@/agents/bee/runners/base.js"; -import type { - BeeAgentRunIteration, - BeeAgentTemplates, - BeeParserInput, - BeeRunInput, -} from "@/agents/bee/types.js"; +import type { BeeAgentRunIteration, BeeParserInput, BeeRunInput } from "@/agents/bee/types.js"; import { Retryable } from "@/internals/helpers/retryable.js"; import { AgentError } from "@/agents/base.js"; import { @@ -48,6 +43,21 @@ import { Cache } from "@/cache/decoratorCache.js"; import { shallowCopy } from "@/serializer/utils.js"; export class DefaultRunner extends BaseRunner { + @Cache({ enumerable: false }) + public get defaultTemplates() { + return { + system: BeeSystemPrompt, + assistant: BeeAssistantPrompt, + user: BeeUserPrompt, + schemaError: BeeSchemaErrorPrompt, + toolNotFoundError: BeeToolNotFoundPrompt, + toolError: BeeToolErrorPrompt, + toolInputError: BeeToolInputErrorPrompt, + userEmpty: BeeUserEmptyPrompt, + toolNoResultError: BeeToolNoResultsPrompt, + }; + } + static { this.register(); } @@ -369,23 +379,6 @@ export class DefaultRunner extends BaseRunner { return memory; } - @Cache({ enumerable: false }) - get templates(): BeeAgentTemplates { - const customTemplates = this.input.templates ?? {}; - - return { - system: customTemplates.system ?? BeeSystemPrompt, - assistant: customTemplates.assistant ?? BeeAssistantPrompt, - user: customTemplates.user ?? BeeUserPrompt, - userEmpty: customTemplates.userEmpty ?? BeeUserEmptyPrompt, - toolError: customTemplates.toolError ?? BeeToolErrorPrompt, - toolInputError: customTemplates.toolInputError ?? BeeToolInputErrorPrompt, - toolNoResultError: customTemplates.toolNoResultError ?? BeeToolNoResultsPrompt, - toolNotFoundError: customTemplates.toolNotFoundError ?? BeeToolNotFoundPrompt, - schemaError: customTemplates.schemaError ?? BeeSchemaErrorPrompt, - }; - } - protected createParser(tools: AnyTool[]) { const parserRegex = isEmpty(tools) ? new RegExp(`Thought: .+\\nFinal Answer: [\\s\\S]+`) diff --git a/src/agents/bee/runners/granite/prompts.ts b/src/agents/bee/runners/granite/prompts.ts index 97fa1cf5..5e11deb9 100644 --- a/src/agents/bee/runners/granite/prompts.ts +++ b/src/agents/bee/runners/granite/prompts.ts @@ -24,28 +24,20 @@ import { BeeUserPrompt, } from "@/agents/bee/prompts.js"; -export const GraniteBeeAssistantPrompt = BeeAssistantPrompt.fork((config) => ({ - ...config, - template: `{{#thought}}Thought: {{.}}\n{{/thought}}{{#toolName}}Tool Name: {{.}}\n{{/toolName}}{{#toolInput}}Tool Input: {{.}}\n{{/toolInput}}{{#finalAnswer}}Final Answer: {{.}}{{/finalAnswer}}`, -})); +export const GraniteBeeAssistantPrompt = BeeAssistantPrompt.fork((config) => { + config.template = `{{#thought}}Thought: {{.}}\n{{/thought}}{{#toolName}}Tool Name: {{.}}\n{{/toolName}}{{#toolInput}}Tool Input: {{.}}\n{{/toolInput}}{{#finalAnswer}}Final Answer: {{.}}{{/finalAnswer}}`; +}); -export const GraniteBeeSystemPrompt = BeeSystemPrompt.fork((config) => ({ - ...config, - defaults: { - ...config.defaults, - instructions: "", - }, - functions: { - ...config.functions, - formatDate: function () { - const date = this.createdAt ? new Date(this.createdAt) : new Date(); - return new Intl.DateTimeFormat("en-US", { - dateStyle: "full", - timeStyle: "medium", - }).format(date); - }, - }, - template: `You are an AI assistant. +export const GraniteBeeSystemPrompt = BeeSystemPrompt.fork((config) => { + config.defaults.instructions = ""; + config.functions.formatDate = function () { + const date = this.createdAt ? new Date(this.createdAt) : new Date(); + return new Intl.DateTimeFormat("en-US", { + dateStyle: "full", + timeStyle: "medium", + }).format(date); + }; + config.template = `You are an AI assistant. When the user sends a message figure out a solution and provide a final answer. {{#tools.length}} You have access to a set of tools that can be used to retrieve information and perform actions. @@ -85,38 +77,33 @@ You do not need a tool to get the current Date and Time. Use the information ava # Additional instructions {{.}} {{/instructions}} -`, -})); +`; +}); -export const GraniteBeeSchemaErrorPrompt = BeeSchemaErrorPrompt.fork((config) => ({ - ...config, - template: `Error: The generated response does not adhere to the communication structure mentioned in the system prompt. -You communicate only in instruction lines. Valid instruction lines are 'Thought' followed by 'Tool Name' and then 'Tool Input' or 'Thought' followed by 'Final Answer'.`, -})); +export const GraniteBeeSchemaErrorPrompt = BeeSchemaErrorPrompt.fork((config) => { + config.template = `Error: The generated response does not adhere to the communication structure mentioned in the system prompt. +You communicate only in instruction lines. Valid instruction lines are 'Thought' followed by 'Tool Name' and then 'Tool Input' or 'Thought' followed by 'Final Answer'.`; +}); -export const GraniteBeeUserPrompt = BeeUserPrompt.fork((config) => ({ - ...config, - template: `{{input}}`, -})); +export const GraniteBeeUserPrompt = BeeUserPrompt.fork((config) => { + config.template = `{{input}}`; +}); -export const GraniteBeeToolNotFoundPrompt = BeeToolNotFoundPrompt.fork((config) => ({ - ...config, - template: `Tool does not exist! +export const GraniteBeeToolNotFoundPrompt = BeeToolNotFoundPrompt.fork((config) => { + config.template = `Tool does not exist! {{#tools.length}} Use one of the following tools: {{#trim}}{{#tools}}{{name}},{{/tools}}{{/trim}} -{{/tools.length}}`, -})); +{{/tools.length}}`; +}); -export const GraniteBeeToolErrorPrompt = BeeToolErrorPrompt.fork((config) => ({ - ...config, - template: `The tool has failed; the error log is shown below. If the tool cannot accomplish what you want, use a different tool or explain why you can't use it. +export const GraniteBeeToolErrorPrompt = BeeToolErrorPrompt.fork((config) => { + config.template = `The tool has failed; the error log is shown below. If the tool cannot accomplish what you want, use a different tool or explain why you can't use it. -{{reason}}`, -})); +{{reason}}`; +}); -export const GraniteBeeToolInputErrorPrompt = BeeToolInputErrorPrompt.fork((config) => ({ - ...config, - template: `{{reason}} +export const GraniteBeeToolInputErrorPrompt = BeeToolInputErrorPrompt.fork((config) => { + config.template = `{{reason}} -HINT: If you're convinced that the input was correct but the tool cannot process it then use a different tool or say I don't know.`, -})); +HINT: If you're convinced that the input was correct but the tool cannot process it then use a different tool or say I don't know.`; +}); diff --git a/src/agents/bee/runners/granite/runner.ts b/src/agents/bee/runners/granite/runner.ts index e928ff34..2ac8b26d 100644 --- a/src/agents/bee/runners/granite/runner.ts +++ b/src/agents/bee/runners/granite/runner.ts @@ -15,16 +15,11 @@ */ import { BaseMessage, Role } from "@/llms/primitives/message.js"; -import type { AnyTool } from "@/tools/base.js"; import { isEmpty } from "remeda"; +import type { AnyTool } from "@/tools/base.js"; import { DefaultRunner } from "@/agents/bee/runners/default/runner.js"; import { BaseMemory } from "@/memory/base.js"; -import type { - BeeAgentTemplates, - BeeParserInput, - BeeRunInput, - BeeRunOptions, -} from "@/agents/bee/types.js"; +import type { BeeParserInput, BeeRunInput, BeeRunOptions } from "@/agents/bee/types.js"; import { BeeAgent, BeeInput } from "@/agents/bee/agent.js"; import type { GetRunContext } from "@/context.js"; import { @@ -36,9 +31,26 @@ import { GraniteBeeToolNotFoundPrompt, GraniteBeeUserPrompt, } from "@/agents/bee/runners/granite/prompts.js"; +import { BeeToolNoResultsPrompt, BeeUserEmptyPrompt } from "@/agents/bee/prompts.js"; import { Cache } from "@/cache/decoratorCache.js"; export class GraniteRunner extends DefaultRunner { + @Cache({ enumerable: false }) + public get defaultTemplates() { + return { + system: GraniteBeeSystemPrompt, + assistant: GraniteBeeAssistantPrompt, + user: GraniteBeeUserPrompt, + schemaError: GraniteBeeSchemaErrorPrompt, + toolNotFoundError: GraniteBeeToolNotFoundPrompt, + toolError: GraniteBeeToolErrorPrompt, + toolInputError: GraniteBeeToolInputErrorPrompt, + // Note: These are from bee + userEmpty: BeeUserEmptyPrompt, + toolNoResultError: BeeToolNoResultsPrompt, + }; + } + static { this.register(); } @@ -89,22 +101,6 @@ export class GraniteRunner extends DefaultRunner { return memory; } - @Cache({ enumerable: false }) - get templates(): BeeAgentTemplates { - const customTemplates = this.input.templates ?? {}; - - return { - ...super.templates, - user: customTemplates.user ?? GraniteBeeUserPrompt, - system: customTemplates.system ?? GraniteBeeSystemPrompt, - assistant: customTemplates.assistant ?? GraniteBeeAssistantPrompt, - schemaError: customTemplates.schemaError ?? GraniteBeeSchemaErrorPrompt, - toolNotFoundError: customTemplates.toolNotFoundError ?? GraniteBeeToolNotFoundPrompt, - toolError: customTemplates.toolError ?? GraniteBeeToolErrorPrompt, - toolInputError: customTemplates.toolInputError ?? GraniteBeeToolInputErrorPrompt, - }; - } - protected createParser(tools: AnyTool[]) { const { parser } = super.createParser(tools); diff --git a/src/experimental/workflows/agent.ts b/src/experimental/workflows/agent.ts index 7d7acd40..4f3b22fa 100644 --- a/src/experimental/workflows/agent.ts +++ b/src/experimental/workflows/agent.ts @@ -19,7 +19,6 @@ import { Workflow, WorkflowRunOptions } from "@/experimental/workflows/workflow. import { BaseMessage } from "@/llms/primitives/message.js"; import { AnyTool } from "@/tools/base.js"; import { AnyChatLLM } from "@/llms/chat.js"; -import { BeeSystemPrompt } from "@/agents/bee/prompts.js"; import { BaseMemory, ReadOnlyMemory } from "@/memory/base.js"; import { z } from "zod"; import { UnconstrainedMemory } from "@/memory/unconstrainedMemory.js"; @@ -100,13 +99,10 @@ export class AgentWorkflow { execution: input.execution, ...(input.instructions && { templates: { - system: BeeSystemPrompt.fork((config) => ({ - ...config, - defaults: { - ...config.defaults, - instructions: input.instructions || config.defaults.instructions, - }, - })), + system: (template) => + template.fork((config) => { + config.defaults.instructions = input.instructions || config.defaults.instructions; + }), }, }), }); diff --git a/src/internals/helpers/object.ts b/src/internals/helpers/object.ts index 7354501b..944de46c 100644 --- a/src/internals/helpers/object.ts +++ b/src/internals/helpers/object.ts @@ -147,3 +147,14 @@ export function customMerge>( } return finalResult; } + +export function mapObj(obj: T) { + return function (fn: (key: K, value: T[K]) => T[K]): T { + const updated: T = Object.assign({}, obj); + for (const pair of Object.entries(obj)) { + const [key, value] = pair as [K, T[K]]; + updated[key] = fn(key, value); + } + return updated; + }; +} diff --git a/src/template.test.ts b/src/template.test.ts index 4b495d38..e80161f5 100644 --- a/src/template.test.ts +++ b/src/template.test.ts @@ -15,7 +15,7 @@ */ import { PromptTemplateError, PromptTemplate, ValidationPromptTemplateError } from "@/template.js"; -import { z } from "zod"; +import { z, ZodType } from "zod"; describe("Prompt Template", () => { describe("Rendering", () => { @@ -187,7 +187,20 @@ describe("Prompt Template", () => { expect(cloned).toEqual(template); }); - it("Forks", () => { + it.each([ + (template: PromptTemplate) => + template.fork((config) => ({ + ...config, + template: "Hi {{name}}!", + customTags: ["{{", "}}"], + functions: { formatDate: () => "Today" }, + })), + (template: PromptTemplate) => + template.fork((config) => { + config.template = "Hi {{name}}!"; + config.customTags = ["{{", "}}"]; + }), + ])("Forks", (forkFn) => { const template = new PromptTemplate({ template: `Hello <>!`, schema: z.object({ @@ -196,13 +209,9 @@ describe("Prompt Template", () => { customTags: ["<<", ">>"], escape: false, }); - const forked = template.fork((config) => ({ - ...config, - template: "Hello {{name}}!", - customTags: ["{{", "}}"], - })); - - expect(template.render({ name: "Tomas" })).toEqual(forked.render({ name: "Tomas" })); + const forked = forkFn(template); + expect(template.render({ name: "Tomas" })).toEqual("Hello Tomas!"); + expect(forked.render({ name: "Tomas" })).toEqual("Hi Tomas!"); }); }); test("Custom function", () => { diff --git a/src/template.ts b/src/template.ts index 05f3dbcc..76ed3ab5 100644 --- a/src/template.ts +++ b/src/template.ts @@ -14,16 +14,14 @@ * limitations under the License. */ -import { FrameworkError } from "@/errors.js"; +import { FrameworkError, ValueError } from "@/errors.js"; import { ObjectLike, PlainObject } from "@/internals/types.js"; -import * as R from "remeda"; +import { clone, identity, isPlainObject, pickBy } from "remeda"; import Mustache from "mustache"; import { Serializable } from "@/internals/serializable.js"; import { z, ZodType } from "zod"; import { createSchemaValidator, toJsonSchema } from "@/internals/helpers/schema.js"; import type { SchemaObject, ValidateFunction } from "ajv"; -import { shallowCopy } from "@/serializer/utils.js"; -import { pickBy } from "remeda"; import { getProp } from "@/internals/helpers/object.js"; type PostInfer = T extends PlainObject @@ -55,9 +53,9 @@ type PromptTemplateConstructor = N extends ZodType } : Omit, "schema"> & { schema: T | SchemaObject }; -type Customizer = ( - config: Required>, -) => PromptTemplateConstructor; +type Customizer = + | ((config: Required>) => PromptTemplateConstructor) + | ((config: Required>) => void); export class PromptTemplateError extends FrameworkError { template: PromptTemplate; @@ -122,8 +120,11 @@ export class PromptTemplate extends Serializable { fork( customizer: Customizer | Customizer, ): PromptTemplate { - const config = shallowCopy(this.config); + const config = clone(this.config); const newConfig = customizer?.(config) ?? config; + if (!isPlainObject(newConfig)) { + throw new ValueError("Return type from customizer must be a config or nothing."); + } return new PromptTemplate(newConfig); } @@ -147,7 +148,7 @@ export class PromptTemplate extends Serializable { { tags: this.config.customTags, ...(!this.config.escape && { - escape: R.identity(), + escape: identity(), }), }, );