Skip to content

Commit

Permalink
feat(deps): upgrade bee-agent-framework and refactor (#102)
Browse files Browse the repository at this point in the history
Signed-off-by: Tomas Dvorak <[email protected]>
  • Loading branch information
Tomas2D authored Dec 2, 2024
1 parent 989bc48 commit fc014df
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 114 deletions.
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@
"@zilliz/milvus2-sdk-node": "^2.4.4",
"ajv": "^8.17.1",
"axios": "^1.7.7",
"bee-agent-framework": "0.0.42",
"bee-observe-connector": "0.0.5",
"bee-agent-framework": "0.0.44",
"bee-observe-connector": "0.0.6",
"bullmq": "5.8.1",
"cache-manager": "^5.7.6",
"dayjs": "^1.11.11",
Expand Down
41 changes: 26 additions & 15 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

93 changes: 46 additions & 47 deletions src/files/extraction/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import { Loaded } from '@mikro-orm/core';
import mime from 'mime';
import { recursiveSplitString } from 'bee-agent-framework/internals/helpers/string';
import ibm from 'ibm-cos-sdk';

import { s3Client } from '../files.service';
import { DoclingExtraction } from '../entities/extractions/docling-extraction.entity';
Expand All @@ -34,6 +35,12 @@ import { EXTRACTION_BACKEND, S3_BUCKET_FILE_STORAGE } from '@/config';
import { ORM } from '@/database';
import { QueueName } from '@/jobs/constants';

/**
 * Runs an `ibm-cos-sdk` request as a promise and ties it to an optional `AbortSignal`.
 *
 * When the signal fires, the underlying request is aborted via `value.abort()`.
 * The listener is removed again once the request settles so long-lived signals
 * do not accumulate handlers.
 *
 * @param value - The not-yet-awaited SDK request (e.g. the result of `s3Client.getObject(...)`).
 * @param signal - Optional signal used to cancel the request.
 * @returns The promise produced by `value.promise()`.
 */
export const withAbort = <A, B>(value: ibm.Request<A, B>, signal?: AbortSignal) => {
  const handler = () => value.abort();
  // `once` lets the runtime drop the listener right after it fires.
  signal?.addEventListener('abort', handler, { once: true });
  const pending = value
    .promise()
    .finally(() => signal?.removeEventListener('abort', handler));
  // An already-aborted signal never dispatches another 'abort' event, so the
  // listener above would stay silent; abort explicitly in that case.
  if (signal?.aborted) handler();
  return pending;
};

function isNativeDoclingFormat(mimeType: string): boolean {
const extension = mime.getExtension(mimeType);
if (!extension) return false;
Expand Down Expand Up @@ -200,81 +207,72 @@ export async function scheduleExtraction(
}
}

export async function removeExtraction(file: Loaded<File>) {
/**
 * Union of the property names found on any member of the union `T`.
 *
 * `T extends T ? keyof T : never` distributes over each union member, so keys
 * that exist on only one member are kept. Keys that arrays also have
 * (`keyof unknown[]`) are then excluded.
 */
type AvailableKeys<T> = Exclude<T extends T ? keyof T : never, keyof unknown[]>;

/**
 * For every extraction backend, the names of the extraction properties that
 * `removeExtraction` reads as storage object keys.
 *
 * `as const` keeps the literal property names; `satisfies` makes the compiler
 * verify that every `ExtractionBackend` has an entry and that each listed name
 * is a key of some `File.prototype.extraction` variant.
 */
const keyByProvider = {
[ExtractionBackend.DOCLING]: ['documentStorageId', 'chunksStorageId', 'textStorageId'],
[ExtractionBackend.WDU]: ['storageId'],
[ExtractionBackend.UNSTRUCTURED_OPENSOURCE]: ['storageId'],
[ExtractionBackend.UNSTRUCTURED_API]: ['storageId']
} as const satisfies Record<ExtractionBackend, AvailableKeys<typeof File.prototype.extraction>[]>;

/**
 * Deletes the objects referenced by `file.extraction` from
 * `S3_BUCKET_FILE_STORAGE`, then clears `file.extraction` and flushes `ORM.em`.
 *
 * @param file - The file whose extraction should be removed.
 * @param signal - Optional abort signal; it is forwarded to `withAbort` for the
 *   deletions issued in the `Promise.all` section below.
 * @throws Error with message 'No extraction to remove' when the file has no extraction.
 */
export async function removeExtraction(file: Loaded<File>, signal?: AbortSignal) {
const extraction = file.extraction;
if (!extraction) throw new Error('No extraction to remove');
// NOTE(review): this switch and the `keyByProvider` loop further down both issue
// deleteObject for the same storage ids. This looks like the pre-refactor body
// left next to its replacement (diff residue) — confirm which one should remain.
// The calls in this switch run one after another and do not receive `signal`.
switch (extraction.backend) {
case ExtractionBackend.DOCLING:
if (extraction.documentStorageId)
await s3Client
.deleteObject({ Bucket: S3_BUCKET_FILE_STORAGE, Key: extraction.documentStorageId })
.promise();
if (extraction.chunksStorageId)
await s3Client
.deleteObject({ Bucket: S3_BUCKET_FILE_STORAGE, Key: extraction.chunksStorageId })
.promise();
if (extraction.textStorageId)
await s3Client
.deleteObject({ Bucket: S3_BUCKET_FILE_STORAGE, Key: extraction.textStorageId })
.promise();
break;
case ExtractionBackend.WDU:
if (extraction.storageId)
await s3Client
.deleteObject({ Bucket: S3_BUCKET_FILE_STORAGE, Key: extraction.storageId })
.promise();
break;
case ExtractionBackend.UNSTRUCTURED_OPENSOURCE:
case ExtractionBackend.UNSTRUCTURED_API:
if (extraction.storageId)
await s3Client
.deleteObject({ Bucket: S3_BUCKET_FILE_STORAGE, Key: extraction.storageId })
.promise();
break;
}

// Table-driven variant: look up the property names for this backend and delete
// every non-empty key concurrently, each request tied to `signal`.
await Promise.all(
keyByProvider[extraction.backend].map(async (property) => {
// The cast is needed because `property` is a key of only one extraction variant.
const Key = extraction[property as keyof typeof extraction];
if (Key) {
await withAbort(s3Client.deleteObject({ Bucket: S3_BUCKET_FILE_STORAGE, Key }), signal);
}
})
);

// Detach the extraction from the entity and flush the entity manager.
file.extraction = undefined;
await ORM.em.flush();
}

export async function getExtractedText(file: Loaded<File>) {
export async function getExtractedText(file: Loaded<File>, signal?: AbortSignal): Promise<string> {
const extraction = file.extraction;
if (!extraction) throw new Error('Extraction not found');
switch (extraction.backend) {
case ExtractionBackend.WDU: {
if (!extraction.storageId) throw new Error('Extraction missing');
const object = await s3Client
.getObject({
const object = await withAbort(
s3Client.getObject({
Bucket: S3_BUCKET_FILE_STORAGE,
Key: extraction.storageId
})
.promise();
}),
signal
);
const body = object.Body;
if (!body) throw new Error('Invalid Body of a file');
return body.toString('utf-8');
}
case ExtractionBackend.DOCLING: {
if (!extraction.textStorageId) throw new Error('Extraction missing');
return readTextFile(extraction.textStorageId);
return readTextFile(extraction.textStorageId, signal);
}
case ExtractionBackend.UNSTRUCTURED_OPENSOURCE:
case ExtractionBackend.UNSTRUCTURED_API: {
if (!extraction.storageId) throw new Error('Extraction missing');
const elements = JSON.parse(
await readTextFile(extraction.storageId)
await readTextFile(extraction.storageId, signal)
) as UnstructuredExtractionDocument;
return elements.map((element) => element.text).join('');
}
}
}

export async function getExtractedChunks(file: Loaded<File>) {
export async function getExtractedChunks(file: Loaded<File>, signal?: AbortSignal) {
const extraction = file.extraction;
if (!extraction) throw new Error('Extraction not found');
switch (extraction.backend) {
case ExtractionBackend.DOCLING: {
if (!extraction.chunksStorageId) {
if (!extraction.textStorageId) throw new Error('Extraction missing');
const text = await getExtractedText(file);
const text = await getExtractedText(file, signal);
const splitter = recursiveSplitString(text, {
size: 400,
overlap: 200,
Expand All @@ -283,12 +281,12 @@ export async function getExtractedChunks(file: Loaded<File>) {
return Array.from(splitter);
}
const chunks = JSON.parse(
await readTextFile(extraction.chunksStorageId)
await readTextFile(extraction.chunksStorageId, signal)
) as DoclingChunksExtraction;
return chunks.map((c) => c.text);
}
case ExtractionBackend.WDU: {
const text = await getExtractedText(file);
const text = await getExtractedText(file, signal);
const splitter = recursiveSplitString(text, {
size: 400,
overlap: 200,
Expand All @@ -300,7 +298,7 @@ export async function getExtractedChunks(file: Loaded<File>) {
case ExtractionBackend.UNSTRUCTURED_API: {
if (!extraction.storageId) throw new Error('Extraction missing');
const elements = JSON.parse(
await readTextFile(extraction.storageId)
await readTextFile(extraction.storageId, signal)
) as UnstructuredExtractionDocument;
return elements
.filter((element) => element.type === 'CompositeElement')
Expand All @@ -309,13 +307,14 @@ export async function getExtractedChunks(file: Loaded<File>) {
}
}

async function readTextFile(key: string) {
const object = await s3Client
.getObject({
async function readTextFile(key: string, signal?: AbortSignal) {
const object = await withAbort(
s3Client.getObject({
Bucket: S3_BUCKET_FILE_STORAGE,
Key: key
})
.promise();
}),
signal
);
const body = object.Body;
if (!body) throw new Error('Invalid Body of a file');
const data = body.toString('utf-8');
Expand Down
27 changes: 12 additions & 15 deletions src/runs/execution/factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,13 @@ import { WatsonXLLM } from 'bee-agent-framework/adapters/watsonx/llm';
import { ZodType } from 'zod';
import { PromptTemplate } from 'bee-agent-framework';
import { AnyTool } from 'bee-agent-framework/tools/base';
import { GraniteBeeAgent } from 'bee-agent-framework/agents/granite/agent';
import { StreamlitAgent } from 'bee-agent-framework/agents/experimental/streamlit/agent';
import { GraniteBeeSystemPrompt } from 'bee-agent-framework/agents/granite/prompts';
import { BeeAgent } from 'bee-agent-framework/agents/bee/agent';
import { BeeSystemPrompt } from 'bee-agent-framework/agents/bee/prompts';
import { ChatLLM, ChatLLMOutput } from 'bee-agent-framework/llms/chat';
import { BaseMemory } from 'bee-agent-framework/memory/base';
import { StreamlitAgentSystemPrompt } from 'bee-agent-framework/agents/experimental/streamlit/prompts';
import { GraniteBeeSystemPrompt } from 'bee-agent-framework/agents/bee/runners/granite/prompts';

import { Run } from '../entities/run.entity';

Expand Down Expand Up @@ -252,19 +251,17 @@ export function createAgentRun(
] as const;
switch (run.assistant.$.agent) {
case Agent.BEE: {
const agent = run.model.includes('granite')
? new GraniteBeeAgent({
llm,
memory,
tools,
templates: { system: getPromptTemplate(run, GraniteBeeSystemPrompt) }
})
: new BeeAgent({
llm,
memory,
tools,
templates: { system: getPromptTemplate(run, BeeSystemPrompt) }
});
const agent = new BeeAgent({
llm,
memory,
tools,
templates: {
system: getPromptTemplate(
run,
run.model.includes('granite') ? GraniteBeeSystemPrompt : BeeSystemPrompt
)
}
});
return [agent.run(...runArgs).observe(createBeeStreamingHandler(ctx)), agent];
}
case Agent.STREAMLIT: {
Expand Down
Loading

0 comments on commit fc014df

Please sign in to comment.