-
Notifications
You must be signed in to change notification settings - Fork 162
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: implement structured generators
- Loading branch information
1 parent
d819562
commit a9038e8
Showing
9 changed files
with
276 additions
and
134 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import { | ||
AnySchemaLike, | ||
FromSchemaLike, | ||
createSchemaValidator, | ||
toJsonSchema, | ||
} from "@/internals/helpers/schema.js"; | ||
import { GenerateOptions, LLMError } from "@/llms/base.js"; | ||
import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js"; | ||
import { BaseMessage, Role } from "@/llms/primitives/message.js"; | ||
import { Retryable } from "@/internals/helpers/retryable.js"; | ||
import { PromptTemplate } from "@/template.js"; | ||
import { SchemaObject } from "ajv"; | ||
import { z } from "zod"; | ||
|
||
export interface GenerateSchemaInput<T> { | ||
maxRetries?: number; | ||
options?: T; | ||
} | ||
|
||
const templateSchema = z.object({ schema: z.string() }); | ||
|
||
export abstract class BaseDriver<TGenerateOptions extends GenerateOptions = GenerateOptions> { | ||
protected abstract template: PromptTemplate<typeof templateSchema>; | ||
protected errorTemplate = new PromptTemplate({ | ||
schema: z.object({ | ||
errors: z.string(), | ||
expected: z.string(), | ||
received: z.string(), | ||
}), | ||
template: `Generated response does not match the expected schema! | ||
Validation Errors: "{{errors}}"`, | ||
}); | ||
|
||
constructor(protected llm: ChatLLM<ChatLLMOutput, TGenerateOptions>) {} | ||
|
||
protected abstract parseResponse(textResponse: string): unknown; | ||
protected abstract schemaToString(schema: SchemaObject): Promise<string> | string; | ||
protected guided(schema: SchemaObject): GenerateOptions["guided"] | undefined { | ||
return undefined; | ||
} | ||
|
||
async generate<T extends AnySchemaLike>( | ||
schema: T, | ||
input: BaseMessage[], | ||
{ maxRetries = 3, options }: GenerateSchemaInput<TGenerateOptions> = {}, | ||
): Promise<FromSchemaLike<T>> { | ||
const jsonSchema = toJsonSchema(schema); | ||
const validator = createSchemaValidator(jsonSchema); | ||
const schemaString = await this.schemaToString(jsonSchema); | ||
|
||
const messages: BaseMessage[] = [ | ||
BaseMessage.of({ | ||
role: Role.SYSTEM, | ||
text: this.template.render({ schema: schemaString }), | ||
}), | ||
...input, | ||
]; | ||
|
||
return new Retryable({ | ||
executor: async () => { | ||
const rawResponse = await this.llm.generate(messages, { | ||
guided: this.guided(jsonSchema), | ||
...options, | ||
} as TGenerateOptions); | ||
const textResponse = rawResponse.getTextContent(); | ||
let parsedResponse: any; | ||
|
||
try { | ||
parsedResponse = this.parseResponse(textResponse); | ||
} catch (error) { | ||
throw new LLMError(`Failed to parse the generated response.`, [], { | ||
isFatal: false, | ||
isRetryable: true, | ||
context: { error: (error as Error).message, received: textResponse }, | ||
}); | ||
} | ||
|
||
const success = validator(parsedResponse); | ||
if (!success) { | ||
const context = { | ||
expected: schemaString, | ||
received: textResponse, | ||
errors: JSON.stringify(validator.errors ?? []), | ||
}; | ||
|
||
messages.push( | ||
BaseMessage.of({ | ||
role: Role.USER, | ||
text: this.errorTemplate.render(context), | ||
}), | ||
); | ||
throw new LLMError( | ||
"Failed to generate a structured response adhering to the provided schema.", | ||
[], | ||
{ | ||
isFatal: false, | ||
isRetryable: true, | ||
context, | ||
}, | ||
); | ||
} | ||
return parsedResponse as FromSchemaLike<T>; | ||
}, | ||
config: { | ||
signal: options?.signal, | ||
maxRetries, | ||
}, | ||
}).get(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import { parseBrokenJson } from "@/internals/helpers/schema.js"; | ||
import { GenerateOptions } from "@/llms/base.js"; | ||
import { PromptTemplate } from "@/template.js"; | ||
import { BaseDriver } from "./base.js"; | ||
import { SchemaObject } from "ajv"; | ||
import { z } from "zod"; | ||
|
||
export class JsonDriver< | ||
TGenerateOptions extends GenerateOptions = GenerateOptions, | ||
> extends BaseDriver<TGenerateOptions> { | ||
protected template = new PromptTemplate({ | ||
schema: z.object({ | ||
schema: z.string(), | ||
}), | ||
template: `You are a helpful assistant that generates only valid JSON adhering to the following JSON Schema. | ||
\`\`\` | ||
{{schema}} | ||
\`\`\` | ||
IMPORTANT: Every message must be a parsable JSON string without additional output. | ||
`, | ||
}); | ||
|
||
protected parseResponse(textResponse: string): unknown { | ||
return parseBrokenJson(textResponse); | ||
} | ||
|
||
protected schemaToString(schema: SchemaObject): string { | ||
return JSON.stringify(schema, null, 2); | ||
} | ||
|
||
protected guided(schema: SchemaObject) { | ||
return { json: schema } as const; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import { parseBrokenJson } from "@/internals/helpers/schema.js"; | ||
import { GenerateOptions } from "@/llms/base.js"; | ||
import { PromptTemplate } from "@/template.js"; | ||
import { BaseDriver } from "./base.js"; | ||
import * as jsonSchemaToTypescript from "json-schema-to-typescript"; | ||
import { SchemaObject } from "ajv"; | ||
import { z } from "zod"; | ||
|
||
export class TypescriptDriver< | ||
TGenerateOptions extends GenerateOptions = GenerateOptions, | ||
> extends BaseDriver<TGenerateOptions> { | ||
protected template = new PromptTemplate({ | ||
schema: z.object({ | ||
schema: z.string(), | ||
}), | ||
template: `You are a helpful assistant that generates only valid JSON adhering to the following TypeScript type. | ||
\`\`\` | ||
{{schema}} | ||
\`\`\` | ||
IMPORTANT: Every message must be a parsable JSON string without additional output. | ||
`, | ||
}); | ||
|
||
protected parseResponse(textResponse: string): unknown { | ||
return parseBrokenJson(textResponse); | ||
} | ||
|
||
protected async schemaToString(schema: SchemaObject): Promise<string> { | ||
return await jsonSchemaToTypescript.compile(schema, "Output"); | ||
} | ||
|
||
protected guided(schema: SchemaObject) { | ||
return { json: schema } as const; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import { GenerateOptions } from "@/llms/base.js"; | ||
import { PromptTemplate } from "@/template.js"; | ||
import { BaseDriver } from "./base.js"; | ||
import yaml from "js-yaml"; | ||
import { SchemaObject } from "ajv"; | ||
import { z } from "zod"; | ||
|
||
export class YamlDriver< | ||
TGenerateOptions extends GenerateOptions = GenerateOptions, | ||
> extends BaseDriver<TGenerateOptions> { | ||
protected template = new PromptTemplate({ | ||
schema: z.object({ | ||
schema: z.string(), | ||
}), | ||
template: `You are a helpful assistant that generates only valid YAML adhering to the following schema. | ||
\`\`\` | ||
{{schema}} | ||
\`\`\` | ||
IMPORTANT: Every message must be a parsable YAML string without additional output. | ||
`, | ||
}); | ||
|
||
protected parseResponse(textResponse: string): unknown { | ||
return yaml.load(textResponse); | ||
} | ||
|
||
protected schemaToString(schema: SchemaObject): string { | ||
return yaml.dump(schema); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.