From 13604c63405feaaa85fb3d14734781174c0e04c9 Mon Sep 17 00:00:00 2001 From: Tomas Pilar Date: Thu, 16 Jan 2025 10:41:39 +0100 Subject: [PATCH] fixup! Signed-off-by: Tomas Pilar --- src/chat/chat.module.ts | 10 +++++-- src/chat/chat.service.ts | 35 +++++++++---------------- src/chat/dtos/chat-completion-create.ts | 6 +++++ src/server.ts | 2 ++ src/streaming/pubsub.ts | 8 ++---- src/streaming/sse.ts | 5 ++-- 6 files changed, 34 insertions(+), 32 deletions(-) diff --git a/src/chat/chat.module.ts b/src/chat/chat.module.ts index 4aa8a4c..1a6bca3 100644 --- a/src/chat/chat.module.ts +++ b/src/chat/chat.module.ts @@ -20,7 +20,8 @@ import { StatusCodes } from 'http-status-codes'; import { ChatCompletionCreateBody, chatCompletionCreateBodySchema, - chatCompletionCreateResponseSchema + chatCompletionCreateResponseSchema, + chatCompletionCreateResponseStreamSchema } from './dtos/chat-completion-create.js'; import { createChatCompletion } from './chat.service.js'; @@ -39,7 +40,12 @@ export const chatModule: FastifyPluginAsyncJsonSchemaToTs = async (app) => { schema: { body: chatCompletionCreateBodySchema, response: { - [StatusCodes.OK]: chatCompletionCreateResponseSchema + [StatusCodes.OK]: { + content: { + 'application/json': { schema: chatCompletionCreateResponseSchema }, + 'text/event-stream': { schema: chatCompletionCreateResponseStreamSchema } + } + } }, tags: [Tag.OPENAI_API] } diff --git a/src/chat/chat.service.ts b/src/chat/chat.service.ts index 3f124ed..b7091ca 100644 --- a/src/chat/chat.service.ts +++ b/src/chat/chat.service.ts @@ -87,9 +87,8 @@ export async function createChatCompletion({ sse.init(res); try { for await (const output of llm.stream(...args)) { - sse.send( - res, - JSON.stringify({ + sse.send(res, { + data: { id: chat.id, object: 'chat.completion.chunk', model, @@ -98,37 +97,29 @@ export async function createChatCompletion({ index, delta: { role: message.role, content: message.text } })) - } as ChatCompletionChunk) - ); + } as ChatCompletionChunk + }); chat.output = chat.output?.merge(output) ?? output; } } catch (err) { - getChatLogger().error({ err }, 'LLM generation failed'); - chat.error = err.toString(); - sse.send(res, JSON.stringify(chat.error ?? 'Internal server error')); // TODO + sse.send(res, { data: chat.error ?? 'Internal server error' }); // TODO + throw err; } finally { sse.end(res); unsub(); - await ORM.em.flush(); } } else { - try { - chat.output = await llm.generate(...args); - return toChatDto(chat); - } catch (err) { - getChatLogger().error({ err }, 'LLM generation failed'); - chat.error = err.toString(); - if (err instanceof LLMError) { - throw new APIError({ code: APIErrorCode.SERVICE_ERROR, message: err.message }); - } - throw err; - } finally { - await ORM.em.flush(); - } + chat.output = await llm.generate(...args); + return toChatDto(chat); } } catch (err) { getChatLogger().error({ err }, 'LLM generation failed'); chat.error = err.toString(); + if (err instanceof LLMError) { + throw new APIError({ code: APIErrorCode.SERVICE_ERROR, message: err.message }); + } else { + throw err; + } } finally { await ORM.em.flush(); } diff --git a/src/chat/dtos/chat-completion-create.ts b/src/chat/dtos/chat-completion-create.ts index 2b5b078..05a0be6 100644 --- a/src/chat/dtos/chat-completion-create.ts +++ b/src/chat/dtos/chat-completion-create.ts @@ -19,6 +19,7 @@ import { FromSchema, JSONSchema } from 'json-schema-to-ts'; import { ChatMessageRole } from '../constants'; import { chatCompletionSchema } from './chat-completion'; +import { chatCompletionChunkSchema } from './chat-completion-chunk'; export const chatCompletionCreateBodySchema = { type: 'object', @@ -69,3 +70,8 @@ export type ChatCompletionCreateBody = FromSchema; + +export const chatCompletionCreateResponseStreamSchema = chatCompletionChunkSchema; +export type ChatCompletionCreateResponseStream = FromSchema< + typeof chatCompletionCreateResponseStreamSchema +>; diff --git a/src/server.ts b/src/server.ts index 6f974fb..36de2d9 100644 --- a/src/server.ts +++ b/src/server.ts @@ -51,6 +51,7 @@ import { organizationUsersModule } from './administration/organization-users.mod import { apiKeysModule } from './administration/api-keys.module.js'; import { artifactsModule } from './artifacts/artifacts.module.js'; import { chatModule } from './chat/chat.module.js'; +import { embeddingsModule } from './embeddings/embeddings.module.js'; const app = fastify({ logger: fastifyLogger, @@ -97,6 +98,7 @@ try { app.register(organizationUsersModule, { prefix: '/v1' }); app.register(artifactsModule, { prefix: '/v1' }); app.register(chatModule, { prefix: '/v1' }); + app.register(embeddingsModule, { prefix: '/v1' }); app.register(uiModule, { prefix: '/v1' }); diff --git a/src/streaming/pubsub.ts b/src/streaming/pubsub.ts index 52a563f..487d3ff 100644 --- a/src/streaming/pubsub.ts +++ b/src/streaming/pubsub.ts @@ -76,7 +76,7 @@ export async function subscribeAndForward( }); client.on('message', (_, message) => { const event = JSON.parse(message) as Event; - sse.send(res, createMessage(event)); + sse.send(res, { event: event.event, data: event.data }); if (event.event === 'done' || event.event === 'error') { resolve(); } @@ -86,13 +86,9 @@ export async function subscribeAndForward( }); }); } catch (err) { - sse.send(res, createMessage({ event: 'error', data: err })); + sse.send(res, { event: 'error', data: err }); } finally { sse.end(res); } }); } - -function createMessage(event: Event): string { - return `event: ${event.event}\ndata: ${JSON.stringify(event.data)}\n\n`; -} diff --git a/src/streaming/sse.ts b/src/streaming/sse.ts index 11bcd74..5ccdb47 100644 --- a/src/streaming/sse.ts +++ b/src/streaming/sse.ts @@ -31,8 +31,9 @@ export const init = (res: FastifyReply) => { } }; -export const send = (res: FastifyReply, data: string) => { - res.raw.write(data); +export const send = (res: FastifyReply, { event, data }: { event?: string; data: any }) => { + if (event) res.raw.write(`event: ${event}\n`); + res.raw.write(`data: ${JSON.stringify(data)}\n\n`); }; export const end = (res: FastifyReply) => {