diff --git a/migrations/Migration20241126122701.ts b/migrations/Migration20241126122701.ts
index 45dfe0f..3a337dc 100644
--- a/migrations/Migration20241126122701.ts
+++ b/migrations/Migration20241126122701.ts
@@ -1,3 +1,19 @@
+/**
+ * Copyright 2024 IBM Corp.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 import { Migration } from '@mikro-orm/migrations-mongodb';
 
 import { User } from '@/users/entities/user.entity';
diff --git a/migrations/Migration20241206091921.ts b/migrations/Migration20241206091921.ts
index d5aa6f5..173dd5f 100644
--- a/migrations/Migration20241206091921.ts
+++ b/migrations/Migration20241206091921.ts
@@ -1,3 +1,19 @@
+/**
+ * Copyright 2024 IBM Corp.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 import { Migration } from '@mikro-orm/migrations-mongodb';
 
 import { Ref, ref } from '@mikro-orm/core';
diff --git a/src/jobs/bullmq.ts b/src/jobs/bullmq.ts
index a3ee2aa..4b1153a 100644
--- a/src/jobs/bullmq.ts
+++ b/src/jobs/bullmq.ts
@@ -20,7 +20,7 @@ import { globby } from 'globby';
 import { DefaultJobOptions, Job, Queue, Worker, WorkerOptions } from 'bullmq';
 import { isTruthy } from 'remeda';
 
-import { createClient } from '../redis.js';
+import { defaultRedisConnectionOptions } from '../redis.js';
 import { getLogger } from '../logger.js';
 import { gateway } from '../metrics.js';
 
@@ -42,10 +42,11 @@ const getQueueLogger = (queueName: string, job?: Job) =>
 
 const logger = getLogger();
 
-const connection = createClient({
+const connectionOpts = {
+  ...defaultRedisConnectionOptions,
   // https://docs.bullmq.io/guide/going-to-production#maxretriesperrequest
   maxRetriesPerRequest: null
-});
+};
 
 const defaultJobOptions = {
   removeOnComplete: true,
@@ -89,6 +90,8 @@ interface CreateQueueInput {
   jobHandler?: (job: Job) => Promise;
 }
 
+const Queues = new Map();
+
 export function createQueue({
   name,
   jobsOptions,
@@ -96,7 +99,7 @@ export function createQueue({
   jobHandler
 }: CreateQueueInput) {
   const queue = new Queue(name, {
-    connection: connection.options,
+    connection: connectionOpts,
     defaultJobOptions: jobsOptions ? { ...defaultJobOptions, ...jobsOptions } : defaultJobOptions
   });
 
@@ -113,12 +116,13 @@ export function createQueue({
         // We need to set autorun to false otherwise the worker might pick up stuff while ORM is not ready
         autorun: false,
         ...workerOptions,
-        connection: connection.options
+        connection: connectionOpts
       }
     );
     addCallbacks(worker, queue);
     Workers.set(name, worker);
   }
+  Queues.set(name, queue);
 
   return { queue };
 }
@@ -133,9 +137,25 @@ export async function runWorkers(queueNames: QueueName[]) {
   logger.info({ queueNames }, `Workers started successfully`);
 }
 
+export async function closeAllQueues() {
+  await Promise.all(
+    [...Queues.values()].map(async (queue) => {
+      if (!(await queue.isPaused())) {
+        await queue.close();
+      }
+    })
+  );
+  logger.info('Queues shutdown successfully');
+}
+
 export async function closeAllWorkers() {
-  await Promise.all([...Workers.values()].map((worker) => worker.close()));
-  connection.quit();
+  await Promise.all(
+    [...Workers.values()].map(async (worker) => {
+      if (!worker.isPaused()) {
+        await worker.close();
+      }
+    })
+  );
   logger.info('Workers shutdown successfully');
 }
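
The BullMQ module above now hands plain connection options to each Queue and Worker instead of sharing a single ioredis client, which is why closeAllWorkers no longer calls connection.quit(); every queue is also tracked in Queues so that closeAllQueues can shut them down at process exit. A minimal consumer sketch, assuming an illustrative queue name and handler body (only createQueue, runWorkers, closeAllQueues and closeAllWorkers come from the file above):

import { Job } from 'bullmq';

import { closeAllQueues, closeAllWorkers, createQueue, runWorkers } from '@/jobs/bullmq.js';
import { QueueName } from '@/jobs/constants.js';

// Register a queue with an attached worker; the Redis connection options are supplied
// internally via connectionOpts, so no client has to be passed in here.
const { queue } = createQueue({
  name: QueueName.RUNS,
  workerOptions: { concurrency: 5 },
  jobHandler: async (job: Job) => {
    // illustrative handler body; real handlers live in the *.queue.ts modules
  }
});

export async function startJobs() {
  await runWorkers([QueueName.RUNS]); // workers are created with autorun: false and started here
}

export async function stopJobs() {
  await closeAllWorkers(); // stop pulling new jobs first
  await closeAllQueues(); // then close the queue handles registered in Queues
}
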
diff --git a/src/rate-limit.ts b/src/rate-limit.ts
index c6f48a3..57baa62 100644
--- a/src/rate-limit.ts
+++ b/src/rate-limit.ts
@@ -18,21 +18,21 @@ import { FastifyPluginAsync, FastifyRequest } from 'fastify';
 import fp from 'fastify-plugin';
 import rateLimit, { errorResponseBuilderContext } from '@fastify/rate-limit';
 
-import { createClient } from './redis.js';
+import { closeClient, createRedisClient } from './redis.js';
 import { AuthSecret, determineAuthType, scryptSecret } from './auth/utils.js';
 import { toErrorResponseDto } from './errors/plugin.js';
 import { APIError, APIErrorCode } from './errors/error.entity.js';
 
-export const rateLimitPlugin: FastifyPluginAsync = fp.default(async (app) => {
-  const redis = createClient({
-    /**
-     * "The default parameters of a redis connection are not the fastest to provide a rate-limit. We suggest to customize the connectTimeout and maxRetriesPerRequest.
-     * Source: https://github.com/fastify/fastify-rate-limit
-     */
-    connectTimeout: 1000, // 500 was too low, getting ETIMEDOUT
-    maxRetriesPerRequest: 1
-  });
+const redis = createRedisClient({
+  /**
+   * "The default parameters of a redis connection are not the fastest to provide a rate-limit. We suggest to customize the connectTimeout and maxRetriesPerRequest.
+   * Source: https://github.com/fastify/fastify-rate-limit
+   */
+  connectTimeout: 1000, // 500 was too low, getting ETIMEDOUT
+  maxRetriesPerRequest: 1
+});
 
+export const rateLimitPlugin: FastifyPluginAsync = fp.default(async (app) => {
   await app.register(rateLimit, {
     global: true,
     max: 25,
@@ -67,4 +67,5 @@ export const rateLimitPlugin: FastifyPluginAsync = fp.default(async (app) => {
       }
     }
   });
+  app.addHook('onClose', () => closeClient(redis));
 });
diff --git a/src/redis.ts b/src/redis.ts
index db9a858..4c69eed 100644
--- a/src/redis.ts
+++ b/src/redis.ts
@@ -15,31 +15,57 @@
  */
 
 import { Redis, RedisOptions } from 'ioredis';
+import { parseURL } from 'ioredis/built/utils';
 
 import { REDIS_CA_CERT, REDIS_CACHE_CA_CERT, REDIS_CACHE_URL, REDIS_URL } from './config.js';
 
-export function createClient(opts?: Partial): Redis {
-  const client = new Redis(REDIS_URL, {
-    tls:
-      REDIS_URL.startsWith('rediss') && REDIS_CA_CERT
-        ? {
-            ca: Buffer.from(REDIS_CA_CERT)
-          }
-        : undefined,
-    ...opts
-  });
-  return client;
+export const defaultRedisConnectionOptions: RedisOptions = {
+  ...parseURL(REDIS_URL),
+  tls:
+    REDIS_URL.startsWith('rediss') && REDIS_CA_CERT ? { ca: Buffer.from(REDIS_CA_CERT) } : undefined
+};
+
+export const defaultRedisCacheConnectionOptions: RedisOptions = {
+  ...parseURL(REDIS_CACHE_URL),
+  tls:
+    REDIS_CACHE_URL.startsWith('rediss') && REDIS_CACHE_CA_CERT
+      ? { ca: Buffer.from(REDIS_CACHE_CA_CERT) }
+      : undefined,
+  connectTimeout: 1000,
+  maxRetriesPerRequest: 1
+};
+
+export const sharedRedisClient = new Redis(defaultRedisConnectionOptions);
+export const sharedRedisCacheClient = new Redis(defaultRedisCacheConnectionOptions);
+
+const CLIENTS: Redis[] = [sharedRedisClient, sharedRedisCacheClient];
+
+export async function withRedisClient(
+  asyncCallback: (redis: Redis) => Promise,
+  opts?: Partial
+) {
+  const client = new Redis(REDIS_URL, { ...defaultRedisConnectionOptions, ...opts });
+  try {
+    return await asyncCallback(client);
+  } finally {
+    await closeClient(client);
+  }
 }
 
-export function createCacheClient(opts?: Partial): Redis {
-  const client = new Redis(REDIS_CACHE_URL, {
-    tls:
-      REDIS_URL.startsWith('rediss') && REDIS_CACHE_CA_CERT
-        ? {
-            ca: Buffer.from(REDIS_CACHE_CA_CERT)
-          }
-        : undefined,
-    ...opts
-  });
+export function createRedisClient(opts?: Partial) {
+  const client = new Redis({ ...defaultRedisConnectionOptions, ...opts });
+  CLIENTS.push(client);
   return client;
 }
+
+export async function closeClient(client: Redis) {
+  if (client.status !== 'end') {
+    await new Promise((resolve) => {
+      client.quit(() => resolve());
+    });
+  }
+}
+
+export async function closeAllClients() {
+  await Promise.all(CLIENTS.map((client) => closeClient(client)));
+}
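
src/redis.ts now centralizes connection handling: the parsed URLs are exposed as default options, two shared clients are created up front, createRedisClient registers every additional client in CLIENTS so closeAllClients can reach it, and withRedisClient scopes a short-lived client to a callback and always quits it. A usage sketch under those assumptions (the keys and values are illustrative; only the imported helpers come from the file above):

import { closeClient, createRedisClient, sharedRedisClient, withRedisClient } from '@/redis.js';

async function demo() {
  // Scoped client: quit in the finally block of withRedisClient, even if the callback throws.
  const value = await withRedisClient(async (redis) => redis.get('example:key'));

  // Shared, long-lived client: reused across the process and closed by closeAllClients()
  // during shutdown (see terminus.ts below).
  await sharedRedisClient.set('example:key', value ?? 'fallback');

  // Dedicated long-lived client (the pattern rate-limit.ts uses above): tracked in CLIENTS,
  // but it can also be closed explicitly.
  const dedicated = createRedisClient({ maxRetriesPerRequest: 1 });
  await closeClient(dedicated);
}
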
diff --git a/src/runs/execution/event-handlers/streaming.ts b/src/runs/execution/event-handlers/streaming.ts
index 54fe1c8..13dccfb 100644
--- a/src/runs/execution/event-handlers/streaming.ts
+++ b/src/runs/execution/event-handlers/streaming.ts
@@ -15,7 +15,7 @@
  */
 
 import { FrameworkError, Version } from 'bee-agent-framework';
-import { EventMeta, Emitter, Callback } from 'bee-agent-framework/emitter/emitter';
+import { Callback, Emitter, EventMeta } from 'bee-agent-framework/emitter/emitter';
 import { ref } from '@mikro-orm/core';
 import { Role } from 'bee-agent-framework/llms/primitives/message';
 import { BeeCallbacks } from 'bee-agent-framework/agents/bee/types';
@@ -43,11 +43,11 @@ import { RunStatus } from '@/runs/entities/run.entity.js';
 import { APIError } from '@/errors/error.entity.js';
 import { jobRegistry } from '@/metrics.js';
 import { EmitterEvent } from '@/run-steps/entities/emitter-event.entity';
-import { createClient } from '@/redis.js';
 import { createApproveChannel, toRunDto } from '@/runs/runs.service';
 import { RequiredToolApprove } from '@/runs/entities/requiredToolApprove.entity';
 import { ToolApprovalType } from '@/runs/entities/toolApproval.entity';
 import { ToolType } from '@/tools/entities/tool/tool.entity';
+import { withRedisClient } from '@/redis.js';
 
 const agentToolExecutionTime = new Summary({
   name: 'agent_tool_execution_time_seconds',
@@ -111,50 +111,52 @@ export function createBeeStreamingHandler(ctx: AgentContext) {
               : toolCall.type)
         )?.requireApproval === ToolApprovalType.ALWAYS
       ) {
-        const client = createClient();
-        await new Promise((resolve, reject) => {
-          client.subscribe(createApproveChannel(ctx.run, toolCall), async (err) => {
-            try {
-              if (err) {
-                reject(err);
-              } else {
-                ctx.run.requireAction(
-                  new RequiredToolApprove({
-                    toolCalls: [...(ctx.run.requiredAction?.toolCalls ?? []), toolCall]
-                  })
-                );
-                await ORM.em.flush();
-                await ctx.publish({
-                  event: 'thread.run.requires_action',
-                  data: toRunDto(ctx.run)
-                });
-                await ctx.publish({
-                  event: 'done',
-                  data: '[DONE]'
-                });
-              }
-            } catch (err) {
-              reject(err);
-            }
-          });
-          client.on('message', async (_, approval) => {
-            try {
-              ctx.run.submitAction();
-              await ORM.em.flush();
-              if (approval !== 'true') {
-                reject(
-                  new ToolError('User has not approved this tool to run.', [], {
-                    isFatal: false,
-                    isRetryable: false
-                  })
-                );
-              }
-              resolve(true);
-            } catch (err) {
-              reject(err);
-            }
-          });
-        });
+        await withRedisClient(
+          (client) =>
+            new Promise((resolve, reject) => {
+              client.subscribe(createApproveChannel(ctx.run, toolCall), async (err) => {
+                try {
+                  if (err) {
+                    reject(err);
+                  } else {
+                    ctx.run.requireAction(
+                      new RequiredToolApprove({
+                        toolCalls: [...(ctx.run.requiredAction?.toolCalls ?? []), toolCall]
+                      })
+                    );
+                    await ORM.em.flush();
+                    await ctx.publish({
+                      event: 'thread.run.requires_action',
+                      data: toRunDto(ctx.run)
+                    });
+                    await ctx.publish({
+                      event: 'done',
+                      data: '[DONE]'
+                    });
+                  }
+                } catch (err) {
+                  reject(err);
+                }
+              });
+              client.on('message', async (_, approval) => {
+                try {
+                  ctx.run.submitAction();
+                  await ORM.em.flush();
+                  if (approval !== 'true') {
+                    reject(
+                      new ToolError('User has not approved this tool to run.', [], {
+                        isFatal: false,
+                        isRetryable: false
+                      })
+                    );
+                  }
+                  resolve(true);
+                } catch (err) {
+                  reject(err);
+                }
+              });
+            })
+        );
       }
     }
diff --git a/src/runs/execution/execute.ts b/src/runs/execution/execute.ts
index e097df3..977eee9 100644
--- a/src/runs/execution/execute.ts
+++ b/src/runs/execution/execute.ts
@@ -37,7 +37,7 @@ import { createAgentRun, createChatLLM } from './factory.js';
 
 import { ORM } from '@/database.js';
 import { getLogger } from '@/logger.js';
-import { createPublisher } from '@/streaming/pubsub.js';
+import { Publisher, withPublisher } from '@/streaming/pubsub.js';
 import { APIError } from '@/errors/error.entity.js';
 import { jobRegistry } from '@/metrics.js';
 import { getTraceLogger } from '@/observe/utils.js';
@@ -61,7 +61,7 @@ const agentExecutionTime = new Summary({
 
 export type AgentContext = {
   run: Loaded;
-  publish: ReturnType;
+  publish: Publisher;
   runStep?: Loaded;
   message?: Loaded;
   toolCall?: Loaded;
@@ -70,8 +70,6 @@ export type AgentContext = {
 export async function executeRun(run: LoadedRun) {
   const runLogger = getLogger().child({ runId: run.id });
 
-  const publish = createPublisher(run);
-
   // create messages and add file ids to message content
   const messages = run.thread.$.messages.$.filter((message) => !message.deletedAt).map(
     (message) =>
@@ -123,85 +121,89 @@ export async function executeRun(run: LoadedRun) {
   run.start();
   await ORM.em.flush();
-  await publish({ event: 'thread.run.in_progress', data: toRunDto(run) });
-  const context = { run, publish } as AgentContext;
-
-  const tools = await getTools(run, context);
-  const llm = createChatLLM(run);
-  const memory = new TokenMemory({ llm });
-  await memory.addMany(messages);
-
-  const cancellationController = new AbortController();
-  const unsub = watchForCancellation(Run, run, () => cancellationController.abort());
-
-  const expirationSignal = AbortSignal.timeout(dayjs(run.expiresAt).diff(dayjs(), 'milliseconds'));
-
-  try {
-    const endAgentExecutionTimer = agentExecutionTime.labels({ framework: Version }).startTimer();
-    const [agentRun, agent] = createAgentRun(
-      run,
-      { llm, tools, memory },
-      {
-        signal: AbortSignal.any([cancellationController.signal, expirationSignal]),
-        ctx: context
-      }
+  await withPublisher(run, async (publish) => {
+    await publish({ event: 'thread.run.in_progress', data: toRunDto(run) });
+    const context = { run, publish } as AgentContext;
+
+    const tools = await getTools(run, context);
+    const llm = createChatLLM(run);
+    const memory = new TokenMemory({ llm });
+    await memory.addMany(messages);
+
+    const cancellationController = new AbortController();
+    const unsub = watchForCancellation(Run, run, () => cancellationController.abort());
+
+    const expirationSignal = AbortSignal.timeout(
+      dayjs(run.expiresAt).diff(dayjs(), 'milliseconds')
    );
-    // apply observe middleware only when the observe API is enabled
-    if (BEE_OBSERVE_API_URL && BEE_OBSERVE_API_AUTH_KEY && agent instanceof BeeAgent) {
-      (agentRun as ReturnType).middleware(
-        createObserveConnector({
-          api: {
-            baseUrl: BEE_OBSERVE_API_URL,
-            apiAuthKey: BEE_OBSERVE_API_AUTH_KEY,
-            ignored_keys: [
-              'apiToken',
-              'apiKey',
-              'cseId',
-              'accessToken',
-              'proxy',
-              'username',
-              'password'
-            ]
-          },
-          cb: async (err, traceResponse) => {
-            if (err) {
-              getTraceLogger().warn({ err }, 'bee-observe API error');
-            } else if (traceResponse) {
-              run.trace = new Trace({ id: traceResponse.result.id });
-            }
-          }
-        })
+    try {
+      const endAgentExecutionTimer = agentExecutionTime.labels({ framework: Version }).startTimer();
+      const [agentRun, agent] = createAgentRun(
+        run,
+        { llm, tools, memory },
+        {
+          signal: AbortSignal.any([cancellationController.signal, expirationSignal]),
+          ctx: context
+        }
      );
-    }
-
-    await agentRun;
+      // apply observe middleware only when the observe API is enabled
+      if (BEE_OBSERVE_API_URL && BEE_OBSERVE_API_AUTH_KEY && agent instanceof BeeAgent) {
+        (agentRun as ReturnType).middleware(
+          createObserveConnector({
+            api: {
+              baseUrl: BEE_OBSERVE_API_URL,
+              apiAuthKey: BEE_OBSERVE_API_AUTH_KEY,
+              ignored_keys: [
+                'apiToken',
+                'apiKey',
+                'cseId',
+                'accessToken',
+                'proxy',
+                'username',
+                'password'
+              ]
+            },
+            cb: async (err, traceResponse) => {
+              if (err) {
+                getTraceLogger().warn({ err }, 'bee-observe API error');
+              } else if (traceResponse) {
+                run.trace = new Trace({ id: traceResponse.result.id });
+              }
+            }
+          })
+        );
+      }
 
-    endAgentExecutionTimer();
-    run.complete();
+      await agentRun;
+
+      endAgentExecutionTimer();
+      run.complete();
 
-    await ORM.em.flush();
-    await publish({ event: 'thread.run.completed', data: toRunDto(run) });
-  } catch (err) {
-    if (expirationSignal.aborted) {
-      run.expire();
       await ORM.em.flush();
-      await publish({ event: 'thread.run.expired', data: toRunDto(run) });
-      return;
-    } else if (cancellationController.signal.aborted) {
-      run.cancel();
+      await publish({ event: 'thread.run.completed', data: toRunDto(run) });
+    } catch (err) {
+      if (expirationSignal.aborted) {
+        run.expire();
+        await ORM.em.flush();
+        await publish({ event: 'thread.run.expired', data: toRunDto(run) });
+        return;
+      } else if (cancellationController.signal.aborted) {
+        run.cancel();
+        await ORM.em.flush();
+        await publish({ event: 'thread.run.cancelled', data: toRunDto(run) });
+        return;
+      }
+
+      runLogger.error({ err }, 'Run execution failed');
+      run.fail(APIError.from(err));
       await ORM.em.flush();
-      await publish({ event: 'thread.run.cancelled', data: toRunDto(run) });
+      await publish({ event: 'thread.run.failed', data: toRunDto(run) });
       return;
+    } finally {
+      await publish({ event: 'done', data: '[DONE]' });
+      unsub();
    }
-
-    runLogger.error({ err }, 'Run execution failed');
-    run.fail(APIError.from(err));
-    await ORM.em.flush();
-    await publish({ event: 'thread.run.failed', data: toRunDto(run) });
-    return;
-  } finally {
-    await publish({ event: 'done', data: '[DONE]' });
-    unsub();
-  }
+  });
 }
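
executeRun above now routes every streaming event through withPublisher, so the Redis connection behind the publisher lives exactly as long as the run, and the added finally block guarantees a terminating 'done' event on success, cancellation, expiry and failure alike. A stripped-down sketch of that pattern; the helper and work names are illustrative, the event payloads are the ones used above:

import { withPublisher } from '@/streaming/pubsub.js';
import { toRunDto } from '@/runs/runs.service.js';

type RunEntity = Parameters<typeof withPublisher>[0];

async function runWithLifecycleEvents(run: RunEntity, work: () => Promise<void>) {
  await withPublisher(run, async (publish) => {
    try {
      await publish({ event: 'thread.run.in_progress', data: toRunDto(run) });
      await work();
      await publish({ event: 'thread.run.completed', data: toRunDto(run) });
    } catch {
      await publish({ event: 'thread.run.failed', data: toRunDto(run) });
    } finally {
      // Always close the stream for subscribers; the scoped Redis client is quit
      // by withRedisClient once this callback settles.
      await publish({ event: 'done', data: '[DONE]' });
    }
  });
}
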
diff --git a/src/runs/execution/tools/function.ts b/src/runs/execution/tools/function.ts
index 93951e6..be8ba47 100644
--- a/src/runs/execution/tools/function.ts
+++ b/src/runs/execution/tools/function.ts
@@ -26,10 +26,11 @@ import { SchemaObject } from 'ajv';
 import { Loaded } from '@mikro-orm/core';
 import { GetRunContext } from 'bee-agent-framework/context';
 import { Emitter } from 'bee-agent-framework/emitter/emitter';
+import { Redis } from 'ioredis';
 
 import { AgentContext } from '../execute.js';
 
-import { createClient } from '@/redis.js';
+import { withRedisClient } from '@/redis.js';
 import { Run } from '@/runs/entities/run.entity.js';
 import { ORM } from '@/database.js';
 import { toRunDto } from '@/runs/runs.service.js';
@@ -81,46 +82,52 @@ export class FunctionTool extends Tool
     const toolCall = this.options.context.toolCall;
     if (!(toolCall instanceof FunctionCall)) throw new Error('Invalid tool call');
 
-    const client = createClient();
-    return await new Promise((resolve, reject) => {
-      client.subscribe(
-        FunctionTool.createChannel(this.options.context.run, toolCall.id),
-        async (err) => {
-          if (err) {
-            reject(err);
-          } else {
-            this.options.context.run.requireAction(
-              new RequiredToolOutput({
-                toolCalls: [...(this.options.context.run.requiredAction?.toolCalls ?? []), toolCall]
-              })
-            );
+    return await withRedisClient(
+      (client) =>
+        new Promise((resolve, reject) => {
+          client.subscribe(
+            FunctionTool.createChannel(this.options.context.run, toolCall.id),
+            async (err) => {
+              if (err) {
+                reject(err);
+              } else {
+                this.options.context.run.requireAction(
+                  new RequiredToolOutput({
+                    toolCalls: [
+                      ...(this.options.context.run.requiredAction?.toolCalls ?? []),
+                      toolCall
+                    ]
+                  })
+                );
+                await ORM.em.flush();
+                await this.options.context.publish({
+                  event: 'thread.run.requires_action',
+                  data: toRunDto(this.options.context.run)
+                });
+                await this.options.context.publish({
+                  event: 'done',
+                  data: '[DONE]'
+                });
+              }
+            }
+          );
+          client.on('message', async (_, output) => {
+            this.options.context.run.submitAction();
            await ORM.em.flush();
-            await this.options.context.publish({
-              event: 'thread.run.requires_action',
-              data: toRunDto(this.options.context.run)
-            });
-            await this.options.context.publish({
-              event: 'done',
-              data: '[DONE]'
-            });
-          }
-        }
-      );
-      client.on('message', async (_, output) => {
-        this.options.context.run.submitAction();
-        await ORM.em.flush();
-        resolve(new FunctionToolOutput(output));
-      });
-      run.signal.addEventListener('abort', () => {
-        reject(run.signal.reason);
-      });
-    });
+            resolve(new FunctionToolOutput(output));
+          });
+          run.signal.addEventListener('abort', () => {
+            reject(run.signal.reason);
+          });
+        })
+    );
   }
 
   static async submit(
+    client: Redis,
     output: string,
     { run, toolCallId }: { run: Loaded; toolCallId: string }
   ) {
-    await createClient().publish(FunctionTool.createChannel(run, toolCallId), output);
+    await client.publish(FunctionTool.createChannel(run, toolCallId), output);
   }
 }
diff --git a/src/runs/execution/tools/helpers.ts b/src/runs/execution/tools/helpers.ts
index 5ee0784..7bce44f 100644
--- a/src/runs/execution/tools/helpers.ts
+++ b/src/runs/execution/tools/helpers.ts
@@ -75,13 +75,10 @@ import { File } from '@/files/entities/file.entity.js';
 import { Attachment } from '@/messages/attachment.entity.js';
 import { SystemResource } from '@/tools/entities/tool-resources/system-resource.entity.js';
 import { createSearchTool } from '@/runs/execution/tools/search-tool';
-import { createCacheClient } from '@/redis.js';
+import { sharedRedisCacheClient } from '@/redis.js';
 
 const searchCache: SearchToolOptions['cache'] = new RedisCache({
-  client: createCacheClient({
-    connectTimeout: 1000,
-    maxRetriesPerRequest: 1
-  }),
+  client: sharedRedisCacheClient,
   keyPrefix: 'search:',
   ttlSeconds: 60 * 60
 });
diff --git a/src/runs/jobs/runs.queue.ts b/src/runs/jobs/runs.queue.ts
index 95933a8..3a364e1 100644
--- a/src/runs/jobs/runs.queue.ts
+++ b/src/runs/jobs/runs.queue.ts
@@ -27,7 +27,7 @@ import { createQueue } from '@/jobs/bullmq.js';
 import { QueueName } from '@/jobs/constants.js';
 import { LoadedRun } from '@/runs/execution/types.js';
 import { waitForExtraction } from '@/files/utils/wait-for-extraction.js';
-import { createPublisher } from '@/streaming/pubsub.js';
+import { withPublisher } from '@/streaming/pubsub.js';
 import { APIError, APIErrorCode } from '@/errors/error.entity.js';
 
 const MAX_ACTIVE_RUNS_PER_USER = 5;
@@ -80,9 +80,10 @@ async function jobHandler(job: Job<{ runId: string }>) {
       new APIError({ message: 'Internal server error', code: APIErrorCode.INTERNAL_SERVER_ERROR })
     );
     await ORM.em.flush();
-    const publish = createPublisher(run);
-    await publish({ event: 'thread.run.failed', data: toRunDto(run) });
-    await publish({ event: 'done', data: '[DONE]' });
+    await withPublisher(run, async (publish) => {
+      await publish({ event: 'thread.run.failed', data: toRunDto(run) });
+      await publish({ event: 'done', data: '[DONE]' });
+    });
     throw err;
   }
 });
diff --git a/src/runs/runs.service.ts b/src/runs/runs.service.ts
index 098b869..79f1b64 100644
--- a/src/runs/runs.service.ts
+++ b/src/runs/runs.service.ts
@@ -15,6 +15,7 @@
  */
 
 import { FilterQuery, Loaded, QueryOrder, ref } from '@mikro-orm/core';
+import { Redis } from 'ioredis';
 
 import { RunCreateBody, RunCreateParams, RunCreateResponse } from './dtos/run-create.js';
 import { Run, RunStatus } from './entities/run.entity.js';
@@ -48,7 +49,7 @@ import { Assistant } from '@/assistants/assistant.entity.js';
 import { getServiceLogger } from '@/logger.js';
 import { createThread } from '@/threads/threads.service.js';
 import { createPaginatedResponse, getListCursor } from '@/utils/pagination.js';
-import { createPublisher, subscribeAndForward } from '@/streaming/pubsub.js';
+import { subscribeAndForward, withPublisher } from '@/streaming/pubsub.js';
 import { listenToSocketClose } from '@/utils/networking.js';
 import { APIError, APIErrorCode } from '@/errors/error.entity.js';
 import { toErrorDto } from '@/errors/plugin.js';
@@ -61,7 +62,7 @@ import { getRunVectorStores } from '@/runs/execution/helpers.js';
 import { getUpdatedValue } from '@/utils/update.js';
 import { RunStep } from '@/run-steps/entities/run-step.entity.js';
 import { RunStepDetailsType } from '@/run-steps/entities/details/run-step-details.entity.js';
-import { createClient } from '@/redis.js';
+import { withRedisClient } from '@/redis.js';
 import { ToolCall } from '@/tools/entities/tool-calls/tool-call.entity.js';
 import { SystemTools } from '@/tools/entities/tool-calls/system-call.entity.js';
 import { ensureRequestContextData } from '@/context.js';
@@ -231,23 +232,24 @@ export async function createRun({
   await ORM.em.flush();
 
   const queueAndPublish = async () => {
-    const publish = createPublisher(run);
-    try {
-      await publish({ event: 'thread.run.created', data: toRunDto(run) });
-      await queue.add(QueueName.RUNS, { runId: run.id }, { jobId: run.id });
-      await publish({ event: 'thread.run.queued', data: toRunDto(run) });
-    } catch (err) {
-      getRunsLogger(run.id).error({ err }, 'Failed to create run job');
-      run.fail(
-        new APIError({
-          message: 'Failed to create run job',
-          code: APIErrorCode.INTERNAL_SERVER_ERROR
-        })
-      );
-      await ORM.em.flush();
-      await publish({ event: 'thread.run.failed', data: toRunDto(run) });
-      await publish({ event: 'done', data: '[DONE]' });
-    }
+    await withPublisher(run, async (publish) => {
+      try {
+        await publish({ event: 'thread.run.created', data: toRunDto(run) });
+        await queue.add(QueueName.RUNS, { runId: run.id }, { jobId: run.id });
+        await publish({ event: 'thread.run.queued', data: toRunDto(run) });
+      } catch (err) {
+        getRunsLogger(run.id).error({ err }, 'Failed to create run job');
+        run.fail(
+          new APIError({
+            message: 'Failed to create run job',
+            code: APIErrorCode.INTERNAL_SERVER_ERROR
+          })
+        );
+        await ORM.em.flush();
+        await publish({ event: 'thread.run.failed', data: toRunDto(run) });
+        await publish({ event: 'done', data: '[DONE]' });
+      }
+    });
   };
 
   if (stream) {
@@ -383,7 +385,9 @@ export async function cancelRun({
   run.startCancel();
   await ORM.em.flush();
 
-  await createPublisher(run)({ event: 'thread.run.cancelling', data: toRunDto(run) });
+  await withPublisher(run, (publish) =>
+    publish({ event: 'thread.run.cancelling', data: toRunDto(run) })
+  );
 
   return toRunDto(run);
 }
@@ -421,12 +425,12 @@ export async function submitToolOutput({
       code: APIErrorCode.INVALID_INPUT
     });
 
-  const submit = async () => {
+  const submit = async (client: Redis) => {
     await Promise.all(
       suppliedToolCalls.map(({ toolCall, output }) => {
         switch (toolCall.type) {
           case ToolType.FUNCTION:
-            return FunctionTool.submit(output, { run, toolCallId: toolCall.id });
+            return FunctionTool.submit(client, output, { run, toolCallId: toolCall.id });
           default:
            throw new Error('Unexpected tool call type');
         }
       })
@@ -475,11 +479,10 @@ export async function submitToolApproval({
       code: APIErrorCode.INVALID_INPUT
     });
 
-  const redisClient = createClient();
-  const submit = async () => {
+  const submit = async (client: Redis) => {
     await Promise.all(
       suppliedToolCalls.map(({ toolCall, approve }) => {
-        redisClient.publish(createApproveChannel(run, toolCall), approve.toString());
+        client.publish(createApproveChannel(run, toolCall), approve.toString());
       })
     );
   };
@@ -494,24 +497,26 @@ async function continueRun({
 }: {
   stream: boolean | null | undefined;
   run: Run;
-  submit: () => Promise;
+  submit: (client: Redis) => Promise;
 }) {
-  if (stream) {
-    const req = ensureRequestContextData('req');
-    const res = ensureRequestContextData('res');
-    const controller = new AbortController();
-    const unsub = listenToSocketClose(req.socket, () => controller.abort());
-    try {
-      await subscribeAndForward(run, res, {
-        signal: controller.signal,
-        onReady: submit,
-        onFailed: submit
-      });
-    } finally {
-      unsub();
+  return await withRedisClient(async (client) => {
+    if (stream) {
+      const req = ensureRequestContextData('req');
+      const res = ensureRequestContextData('res');
+      const controller = new AbortController();
+      const unsub = listenToSocketClose(req.socket, () => controller.abort());
+      try {
+        await subscribeAndForward(run, res, {
+          signal: controller.signal,
+          onReady: () => submit(client),
+          onFailed: () => submit(client)
+        });
+      } finally {
+        unsub();
+      }
+    } else {
+      await submit(client);
+      return toRunDto(run);
    }
-  } else {
-    await submit();
-    return toRunDto(run);
-  }
+  });
 }
diff --git a/src/streaming/pubsub.ts b/src/streaming/pubsub.ts
index 011291c..1903022 100644
--- a/src/streaming/pubsub.ts
+++ b/src/streaming/pubsub.ts
@@ -16,11 +16,12 @@
 
 import { FastifyReply } from 'fastify';
 import { Loaded } from '@mikro-orm/core';
+import { Redis } from 'ioredis';
 
 import { Event } from './dtos/event.js';
 import * as sse from './sse.js';
-import { createClient } from '@/redis.js';
+import { withRedisClient } from '@/redis.js';
 import { Run } from '@/runs/entities/run.entity.js';
 import { getLogger } from '@/logger.js';
 
@@ -28,13 +29,23 @@ function createChannel(run: Loaded) {
   return `run:${run.id}`;
 }
 
-export function createPublisher(run: Loaded) {
-  const client = createClient();
+function createPublisherFn(client: Redis, run: Loaded) {
   return async function publish(event: Event) {
-    await client.publish(createChannel(run), JSON.stringify(event));
+    return client.publish(createChannel(run), JSON.stringify(event));
   };
 }
 
+export type Publisher = ReturnType;
+
+export async function withPublisher(
+  run: Loaded,
+  asyncCallback: (publish: ReturnType) => Promise
+) {
+  return await withRedisClient((client) => {
+    return asyncCallback(createPublisherFn(client, run));
+  });
+}
+
 export async function subscribeAndForward(
   run: Loaded,
   res: FastifyReply,
@@ -48,35 +59,36 @@
     onFailed: () => Promise;
   }
 ) {
-  const client = createClient();
-  const channel = createChannel(run);
-  sse.init(res);
-  try {
-    await new Promise((resolve, reject) => {
-      client.subscribe(channel, (err) => {
-        if (err) {
-          getLogger().error({ err, channel }, 'Subscription failed');
-          onFailed().catch(reject);
-          reject(err);
-        } else {
-          getLogger().trace({ channel }, 'Subscribed');
-          onReady().catch(reject);
-        }
-      });
-      client.on('message', (_, message) => {
-        const event = JSON.parse(message) as Event;
-        sse.send(res, event);
-        if (event.event === 'done' || event.event === 'error') {
-          resolve();
-        }
+  return withRedisClient(async (client: Redis) => {
+    const channel = createChannel(run);
+    sse.init(res);
+    try {
+      await new Promise((resolve, reject) => {
+        client.subscribe(channel, (err) => {
+          if (err) {
+            getLogger().error({ err, channel }, 'Subscription failed');
+            onFailed().catch(reject);
+            reject(err);
+          } else {
+            getLogger().trace({ channel }, 'Subscribed');
+            onReady().catch(reject);
+          }
+        });
+        client.on('message', (_, message) => {
+          const event = JSON.parse(message) as Event;
+          sse.send(res, event);
+          if (event.event === 'done' || event.event === 'error') {
+            resolve();
+          }
+        });
+        signal.addEventListener('abort', () => {
+          reject(signal.reason);
+        });
       });
-      signal.addEventListener('abort', () => {
-        reject(signal.reason);
-      });
-    });
-  } catch (err) {
-    sse.send(res, { event: 'error', data: err });
-  } finally {
-    sse.end(res);
-  }
+    } catch (err) {
+      sse.send(res, { event: 'error', data: err });
+    } finally {
+      sse.end(res);
+    }
+  });
 }
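
Both halves of the streaming pipeline now borrow their Redis connection from withRedisClient: withPublisher on the producing side and subscribeAndForward on the consuming side. A hedged sketch of the consuming side, modelled on continueRun in runs.service.ts above; the handler name and the callback bodies are illustrative:

import { FastifyReply, FastifyRequest } from 'fastify';

import { subscribeAndForward } from '@/streaming/pubsub.js';
import { listenToSocketClose } from '@/utils/networking.js';

type RunEntity = Parameters<typeof subscribeAndForward>[0];

async function streamRunEvents(run: RunEntity, req: FastifyRequest, res: FastifyReply) {
  const controller = new AbortController();
  const unsub = listenToSocketClose(req.socket, () => controller.abort());
  try {
    await subscribeAndForward(run, res, {
      signal: controller.signal,
      // onReady fires once the channel subscription is live, onFailed if it could not be set up.
      onReady: async () => {},
      onFailed: async () => {}
    });
  } finally {
    unsub(); // the subscriber client itself is quit by withRedisClient inside subscribeAndForward
  }
}
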
diff --git a/src/terminus.ts b/src/terminus.ts
index 6dd76b0..170d7e6 100644
--- a/src/terminus.ts
+++ b/src/terminus.ts
@@ -21,12 +21,20 @@ import { RawServerBase } from 'fastify/types/utils.js';
 
 import { getLogger } from './logger.js';
 import { SHUTDOWN_GRACEFUL_PERIOD } from './config.js';
-import { closeAllWorkers } from '@/jobs/bullmq.js';
+import { closeAllQueues, closeAllWorkers } from '@/jobs/bullmq.js';
+import { closeAllClients } from '@/redis';
 
 export function createTerminus(server: RawServerBase) {
   async function beforeShutdown() {
     getLogger().info('Server shutdown started...');
-    closeAllWorkers();
+
+    const cleanupBullMqAndRedis = async () => {
+      await closeAllWorkers();
+      await closeAllQueues();
+      await closeAllClients();
+    };
+    cleanupBullMqAndRedis();
+
     return new Promise((resolve) => {
       setTimeout(resolve, SHUTDOWN_GRACEFUL_PERIOD);
     });
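
Taken together, beforeShutdown above drains the job system and the Redis layer in a fixed order: workers stop picking up jobs, then the queue handles close, then every tracked client is quit, while the graceful period keeps overall shutdown bounded. A condensed sketch of that sequence under the same assumptions (the function name and comments are illustrative):

import { closeAllQueues, closeAllWorkers } from '@/jobs/bullmq.js';
import { closeAllClients } from '@/redis';
import { SHUTDOWN_GRACEFUL_PERIOD } from './config.js';

async function beforeShutdownSketch() {
  const cleanupBullMqAndRedis = async () => {
    await closeAllWorkers(); // stop pulling new jobs first
    await closeAllQueues(); // close the queue handles registered in Queues
    await closeAllClients(); // quit every tracked ioredis client, shared ones included
  };
  // Started but not awaited, mirroring terminus.ts above: the graceful period below bounds
  // how long shutdown can take even if cleanup hangs.
  void cleanupBullMqAndRedis();

  return new Promise<void>((resolve) => {
    setTimeout(resolve, SHUTDOWN_GRACEFUL_PERIOD);
  });
}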