Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement structured generators #19

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/llms/structured.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ import "dotenv/config.js";
import { z } from "zod";
import { BaseMessage, Role } from "bee-agent-framework/llms/primitives/message";
import { OllamaChatLLM } from "bee-agent-framework/adapters/ollama/chat";
import { JsonDriver } from "bee-agent-framework/drivers/json";

const llm = new OllamaChatLLM();
const response = await llm.generateStructured(
const driver = new JsonDriver(llm);
const response = await driver.generate(
z.union([
z.object({
firstName: z.string().min(1),
Expand Down
3 changes: 3 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@
"fast-xml-parser": "^4.4.1",
"header-generator": "^2.1.54",
"joplin-turndown-plugin-gfm": "^1.0.12",
"js-yaml": "^4.1.0",
"json-schema-to-typescript": "^15.0.2",
"mathjs": "^13.1.1",
"mustache": "^4.2.0",
"object-hash": "^3.0.0",
Expand Down Expand Up @@ -129,6 +131,7 @@
"@types/eslint": "^9.6.1",
"@types/eslint-config-prettier": "^6.11.3",
"@types/eslint__js": "^8.42.3",
"@types/js-yaml": "^4.0.9",
"@types/mustache": "^4",
"@types/needle": "^3.3.0",
"@types/node": "^20.16.1",
Expand Down
108 changes: 108 additions & 0 deletions src/drivers/base.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import {
AnySchemaLike,
FromSchemaLike,
createSchemaValidator,
toJsonSchema,
} from "@/internals/helpers/schema.js";
import { GenerateOptions, LLMError } from "@/llms/base.js";
import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js";
import { BaseMessage, Role } from "@/llms/primitives/message.js";
import { Retryable } from "@/internals/helpers/retryable.js";
import { PromptTemplate } from "@/template.js";
import { SchemaObject } from "ajv";
import { z } from "zod";

export interface GenerateSchemaInput<T> {
maxRetries?: number;
options?: T;
}

export abstract class BaseDriver<TGenerateOptions extends GenerateOptions = GenerateOptions> {
protected abstract template: PromptTemplate.infer<{ schema: string }>;
protected errorTemplate = new PromptTemplate({
schema: z.object({
errors: z.string(),
expected: z.string(),
received: z.string(),
}),
template: `Generated response does not match the expected schema!
Validation Errors: "{{errors}}"`,
});

constructor(protected llm: ChatLLM<ChatLLMOutput, TGenerateOptions>) {}

protected abstract parseResponse(textResponse: string): unknown;
protected abstract schemaToString(schema: SchemaObject): Promise<string> | string;
protected guided(schema: SchemaObject): GenerateOptions["guided"] | undefined {

Check warning on line 36 in src/drivers/base.ts

View workflow job for this annotation

GitHub Actions / Lint & Build & Test

'schema' is defined but never used. Allowed unused args must match /^_/u
return undefined;
}

async generate<T extends AnySchemaLike>(
schema: T,
input: BaseMessage[],
{ maxRetries = 3, options }: GenerateSchemaInput<TGenerateOptions> = {},
): Promise<FromSchemaLike<T>> {
const jsonSchema = toJsonSchema(schema);
const validator = createSchemaValidator(jsonSchema);
const schemaString = await this.schemaToString(jsonSchema);

const messages: BaseMessage[] = [
BaseMessage.of({
role: Role.SYSTEM,
text: this.template.render({ schema: schemaString }),
}),
...input,
];

return new Retryable({
executor: async () => {
const rawResponse = await this.llm.generate(messages, {
guided: this.guided(jsonSchema),
...options,
} as TGenerateOptions);
const textResponse = rawResponse.getTextContent();
let parsedResponse: any;

try {
parsedResponse = this.parseResponse(textResponse);
} catch (error) {
throw new LLMError(`Failed to parse the generated response.`, [], {
isFatal: false,
isRetryable: true,
context: { error: (error as Error).message, received: textResponse },
});
}

const success = validator(parsedResponse);
if (!success) {
const context = {
expected: schemaString,
received: textResponse,
errors: JSON.stringify(validator.errors ?? []),
};

messages.push(
BaseMessage.of({
role: Role.USER,
text: this.errorTemplate.render(context),
}),
);
throw new LLMError(
"Failed to generate a structured response adhering to the provided schema.",
[],
{
isFatal: false,
isRetryable: true,
context,
},
);
}
return parsedResponse as FromSchemaLike<T>;
},
config: {
signal: options?.signal,
maxRetries,
},
}).get();
}
}
36 changes: 36 additions & 0 deletions src/drivers/json.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { parseBrokenJson } from "@/internals/helpers/schema.js";
import { GenerateOptions } from "@/llms/base.js";
import { PromptTemplate } from "@/template.js";
import { BaseDriver } from "./base.js";
import { SchemaObject } from "ajv";
import { z } from "zod";

export class JsonDriver<
TGenerateOptions extends GenerateOptions = GenerateOptions,
> extends BaseDriver<TGenerateOptions> {
protected template = new PromptTemplate({
schema: z.object({
schema: z.string(),
}),
template: `You are a helpful assistant that generates only valid JSON adhering to the following JSON Schema.

\`\`\`
{{schema}}
\`\`\`

IMPORTANT: Every message must be a parsable JSON string without additional output.
`,
});

protected parseResponse(textResponse: string): unknown {
return parseBrokenJson(textResponse);
}

protected schemaToString(schema: SchemaObject): string {
return JSON.stringify(schema, null, 2);
}

protected guided(schema: SchemaObject) {
return { json: schema } as const;
}
}
37 changes: 37 additions & 0 deletions src/drivers/typescript.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { parseBrokenJson } from "@/internals/helpers/schema.js";
import { GenerateOptions } from "@/llms/base.js";
import { PromptTemplate } from "@/template.js";
import { BaseDriver } from "./base.js";
import * as jsonSchemaToTypescript from "json-schema-to-typescript";
import { SchemaObject } from "ajv";
import { z } from "zod";

export class TypescriptDriver<
TGenerateOptions extends GenerateOptions = GenerateOptions,
> extends BaseDriver<TGenerateOptions> {
protected template = new PromptTemplate({
schema: z.object({
schema: z.string(),
}),
template: `You are a helpful assistant that generates only valid JSON adhering to the following TypeScript type.

\`\`\`
{{schema}}
\`\`\`

IMPORTANT: Every message must be a parsable JSON string without additional output.
`,
});

protected parseResponse(textResponse: string): unknown {
return parseBrokenJson(textResponse);
}

protected async schemaToString(schema: SchemaObject): Promise<string> {
return await jsonSchemaToTypescript.compile(schema, "Output");
}

protected guided(schema: SchemaObject) {
return { json: schema } as const;
}
}
32 changes: 32 additions & 0 deletions src/drivers/yaml.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { GenerateOptions } from "@/llms/base.js";
import { PromptTemplate } from "@/template.js";
import { BaseDriver } from "./base.js";
import yaml from "js-yaml";
import { SchemaObject } from "ajv";
import { z } from "zod";

export class YamlDriver<
TGenerateOptions extends GenerateOptions = GenerateOptions,
> extends BaseDriver<TGenerateOptions> {
protected template = new PromptTemplate({
schema: z.object({
schema: z.string(),
}),
template: `You are a helpful assistant that generates only valid YAML adhering to the following schema.

\`\`\`
{{schema}}
\`\`\`

IMPORTANT: Every message must be a parsable YAML string without additional output.
`,
});

protected parseResponse(textResponse: string): unknown {
return yaml.load(textResponse);
}

protected schemaToString(schema: SchemaObject): string {
return yaml.dump(schema);
}
}
89 changes: 3 additions & 86 deletions src/llms/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,97 +14,14 @@
* limitations under the License.
*/

import { BaseLLM, BaseLLMOutput, GenerateOptions, LLMError } from "@/llms/base.js";
import { BaseMessage, Role } from "@/llms/primitives/message.js";
import {
AnySchemaLike,
createSchemaValidator,
FromSchemaLike,
parseBrokenJson,
toJsonSchema,
} from "@/internals/helpers/schema.js";
import { Retryable } from "@/internals/helpers/retryable.js";
import { GeneratedStructuredErrorTemplate, GeneratedStructuredTemplate } from "@/llms/prompts.js";
import { BaseLLM, BaseLLMOutput, GenerateOptions } from "@/llms/base.js";
import { BaseMessage } from "@/llms/primitives/message.js";

export abstract class ChatLLMOutput extends BaseLLMOutput {
abstract get messages(): readonly BaseMessage[];
}

export interface GenerateSchemaInput<T> {
template?: typeof GeneratedStructuredTemplate;
errorTemplate?: typeof GeneratedStructuredErrorTemplate;
maxRetries?: number;
options?: T;
}

export abstract class ChatLLM<
TOutput extends ChatLLMOutput,
TGenerateOptions extends GenerateOptions = GenerateOptions,
> extends BaseLLM<BaseMessage[], TOutput, TGenerateOptions> {
async generateStructured<T extends AnySchemaLike>(
schema: T,
input: BaseMessage[],
{
template = GeneratedStructuredTemplate,
errorTemplate = GeneratedStructuredErrorTemplate,
maxRetries = 3,
options,
}: GenerateSchemaInput<TGenerateOptions> = {},
): Promise<FromSchemaLike<T>> {
const jsonSchema = toJsonSchema(schema);
const validator = createSchemaValidator(jsonSchema);

const finalOptions = { ...options } as TGenerateOptions;
if (!options?.guided) {
finalOptions.guided = { json: jsonSchema };
}

const messages: BaseMessage[] = [
BaseMessage.of({
role: Role.SYSTEM,
text: template.render({
schema: JSON.stringify(jsonSchema, null, 2),
}),
}),
...input,
];

return new Retryable({
executor: async () => {
const rawResponse = await this.generate(messages, finalOptions);
const textResponse = rawResponse.getTextContent();
const jsonResponse: any = parseBrokenJson(textResponse);

const success = validator(jsonResponse);
if (!success) {
const context = {
expected: JSON.stringify(jsonSchema),
received: jsonResponse ? JSON.stringify(jsonResponse) : textResponse,
errors: JSON.stringify(validator.errors ?? []),
};

messages.push(
BaseMessage.of({
role: Role.USER,
text: errorTemplate.render(context),
}),
);
throw new LLMError(
"Failed to generate a structured response adhering to the provided schema.",
[],
{
isFatal: false,
isRetryable: true,
context,
},
);
}
return jsonResponse as FromSchemaLike<T>;
},
config: {
signal: options?.signal,
maxRetries,
},
}).get();
}
}
> extends BaseLLM<BaseMessage[], TOutput, TGenerateOptions> {}
42 changes: 0 additions & 42 deletions src/llms/prompts.ts

This file was deleted.

Loading