Skip to content

Commit

Permalink
feat(tools): add Model Context Protocol tool
Browse files Browse the repository at this point in the history
Signed-off-by: Tomas Pilar <[email protected]>
  • Loading branch information
pilartomas committed Jan 7, 2025
1 parent fb2153c commit 76c500b
Show file tree
Hide file tree
Showing 7 changed files with 383 additions and 27 deletions.
60 changes: 60 additions & 0 deletions examples/tools/mcp.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/**
* Copyright 2025 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { MCPTool } from "bee-agent-framework/tools/mcp";
import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js";
import { BeeAgent } from "bee-agent-framework/agents/bee/agent";
import { UnconstrainedMemory } from "bee-agent-framework/memory/unconstrainedMemory";
import { OllamaChatLLM } from "bee-agent-framework/adapters/ollama/chat";

// Create MCP Client
const client = new Client(
{
name: "test-client",
version: "1.0.0",
},
{
capabilities: {},
},
);

// Connect the client to any MCP server with tools capablity
await client.connect(
new StdioClientTransport({
command: "npx",
args: ["-y", "@modelcontextprotocol/server-everything"],
}),
);

try {
// Server usually supports several tools, use the factory for automatic discovery
const tools = await MCPTool.createTools(client);
const agent = new BeeAgent({
llm: new OllamaChatLLM(),
memory: new UnconstrainedMemory(),
tools,
});
// @modelcontextprotocol/server-everything contains "add" tool
await agent.run({ prompt: "Find out how much is 4 + 7" }).observe((emitter) => {
emitter.on("update", async ({ data, update, meta }) => {
console.log(`Agent (${update.key}) 🤖 : `, update.value);
});
});
} finally {
// Close the MCP connection
await client.close();
}
5 changes: 5 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@
"@ibm-generative-ai/node-sdk": "~3.2.4",
"@langchain/community": ">=0.2.28",
"@langchain/core": ">=0.2.27",
"@modelcontextprotocol/sdk": "^1.0.4",
"@zilliz/milvus2-sdk-node": "^2.4.9",
"google-auth-library": "*",
"groq-sdk": "^0.7.0",
Expand Down Expand Up @@ -246,6 +247,9 @@
"@langchain/core": {
"optional": true
},
"@modelcontextprotocol/sdk": {
"optional": true
},
"@zilliz/milvus2-sdk-node": {
"optional": true
},
Expand Down Expand Up @@ -285,6 +289,7 @@
"@ibm-generative-ai/node-sdk": "~3.2.4",
"@langchain/community": "~0.3.17",
"@langchain/core": "~0.3.22",
"@modelcontextprotocol/sdk": "^1.0.4",
"@opentelemetry/instrumentation": "^0.56.0",
"@opentelemetry/resources": "^1.29.0",
"@opentelemetry/sdk-node": "^0.56.0",
Expand Down
79 changes: 62 additions & 17 deletions src/internals/helpers/paginate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
* limitations under the License.
*/

import { paginate, PaginateInput } from "@/internals/helpers/paginate.js";
import {
paginate,
PaginateInput,
paginateWithCursor,
PaginateWithCursorInput,
} from "@/internals/helpers/paginate.js";

describe("paginate", () => {
it.each([
const mockSetup = [
{
size: 1,
chunkSize: 1,
Expand All @@ -38,23 +43,63 @@ describe("paginate", () => {
chunkSize: 1,
items: Array(20).fill(1),
},
])("Works %#", async ({ size, items, chunkSize }) => {
const fn: PaginateInput<number>["handler"] = vi.fn().mockImplementation(async ({ offset }) => {
const chunk = items.slice(offset, offset + chunkSize);
return { done: offset + chunk.length >= items.length, data: chunk };
});
] as const;

describe("paginate", () => {
it.each(mockSetup)("Works %#", async ({ size, items, chunkSize }) => {
const fn: PaginateInput<number>["handler"] = vi
.fn()
.mockImplementation(async ({ offset }) => {
const chunk = items.slice(offset, offset + chunkSize);
return { done: offset + chunk.length >= items.length, data: chunk };
});

const results = await paginate({
size,
handler: fn,
const results = await paginate({
size,
handler: fn,
});

const maxItemsToBeRetrieved = Math.min(size, items.length);
let expectedCalls = Math.ceil(maxItemsToBeRetrieved / chunkSize);
if (expectedCalls === 0 && size > 0) {
expectedCalls = 1;
}
expect(fn).toBeCalledTimes(expectedCalls);
expect(results).toHaveLength(maxItemsToBeRetrieved);
});
});

const maxItemsToBeRetrieved = Math.min(size, items.length);
let expectedCalls = Math.ceil(maxItemsToBeRetrieved / chunkSize);
if (expectedCalls === 0 && size > 0) {
expectedCalls = 1;
}
expect(fn).toBeCalledTimes(expectedCalls);
expect(results).toHaveLength(maxItemsToBeRetrieved);
describe("paginateWithCursor", () => {
it.each(mockSetup)("Works %#", async ({ size, items, chunkSize }) => {
const fn = vi
.fn<PaginateWithCursorInput<number, number>["handler"]>()
.mockImplementation(async ({ cursor = 0 }) => {
const chunk = items.slice(cursor, cursor + chunkSize);
const isDone = cursor + chunk.length >= items.length;
return isDone
? ({
done: true,
data: chunk,
} as const)
: ({
done: false,
data: chunk,
nextCursor: cursor + chunk.length,
} as const);
});

const results = await paginateWithCursor({
size,
handler: fn,
});

const maxItemsToBeRetrieved = Math.min(size, items.length);
let expectedCalls = Math.ceil(maxItemsToBeRetrieved / chunkSize);
if (expectedCalls === 0 && size > 0) {
expectedCalls = 1;
}
expect(fn).toBeCalledTimes(expectedCalls);
expect(results).toHaveLength(maxItemsToBeRetrieved);
});
});
});
32 changes: 32 additions & 0 deletions src/internals/helpers/paginate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,35 @@ export async function paginate<T>(input: PaginateInput<T>): Promise<T[]> {

return acc;
}

export interface PaginateWithCursorInput<T, C> {
size: number;
handler: (data: {
cursor: C | undefined;
limit: number;
}) => Promise<{ data: T[]; done: true } | { data: T[]; done: false; nextCursor: C }>;
}

export async function paginateWithCursor<T, C>(input: PaginateWithCursorInput<T, C>): Promise<T[]> {
const acc: T[] = [];
let cursor: C | undefined;
while (acc.length < input.size) {
const result = await input.handler({
cursor,
limit: input.size - acc.length,
});
acc.push(...result.data);

if (result.done || result.data.length === 0) {
break;
} else {
cursor = result.nextCursor;
}
}

if (acc.length > input.size) {
acc.length = input.size;
}

return acc;
}
103 changes: 103 additions & 0 deletions src/tools/mcp.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/**
* Copyright 2025 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { Server } from "@modelcontextprotocol/sdk/server/index.js";
import { CallToolRequestSchema, ListToolsRequestSchema } from "@modelcontextprotocol/sdk/types.js";
import { InMemoryTransport } from "@modelcontextprotocol/sdk/inMemory.js";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { MCPTool } from "./mcp.js";
import { entries } from "remeda";
import { zodToJsonSchema } from "zod-to-json-schema";
import { z } from "zod";

const abInputSchema = z.object({ a: z.number(), b: z.number() });
const toolDescriptions = {
add: {
description: "Adds two numbers",
inputSchema: zodToJsonSchema(abInputSchema),
handler: ({ a, b }: z.input<typeof abInputSchema>) => a + b,
},
multiply: {
description: "Multiplies two numbers",
inputSchema: zodToJsonSchema(abInputSchema),
handler: ({ a, b }: z.input<typeof abInputSchema>) => a * b,
},
} as const;

describe("MCPTool", () => {
let server: Server;
let client: Client;
let tools: MCPTool[];

beforeEach(async () => {
server = new Server(
{
name: "test-server",
version: "1.0.0",
},
{
capabilities: {
tools: {},
},
},
);
server.setRequestHandler(ListToolsRequestSchema, async () => {
return {
tools: entries(toolDescriptions).map(([name, { description, inputSchema }]) => ({
name,
description,
inputSchema,
})),
};
});
server.setRequestHandler(CallToolRequestSchema, async (request) => {
const tool = toolDescriptions[request.params.name as keyof typeof toolDescriptions];
if (!tool) {
throw new Error("Tool not found");
}
// Arguments are assumed to be valid in this mock
return {
contents: [tool.handler(request.params.arguments as any)],
};
});

client = new Client(
{
name: "test-client",
version: "1.0.0",
},
{
capabilities: {},
},
);

const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
await server.connect(serverTransport);
await client.connect(clientTransport);

tools = await MCPTool.createTools(client);
});

it("should run the tools", async () => {
const tool = tools.at(0);
expect(tool).toBeDefined();
});

afterEach(async () => {
await client.close();
await server.close();
});
});
70 changes: 70 additions & 0 deletions src/tools/mcp.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/**
* Copyright 2025 IBM Corp.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { BaseToolRunOptions, ToolEmitter, ToolInput, JSONToolOutput, Tool } from "@/tools/base.js";
import { Emitter } from "@/emitter/emitter.js";
import { GetRunContext } from "@/context.js";
import { Client as MCPClient } from "@modelcontextprotocol/sdk/client/index.js";
import { ListToolsResult } from "@modelcontextprotocol/sdk/types.js";
import { SchemaObject } from "ajv";

export interface MCPToolInput {
client: MCPClient;
tool: ListToolsResult["tools"][number];
}

export class MCPToolOutput extends JSONToolOutput<any> {}

export class MCPTool extends Tool<MCPToolOutput> {
public readonly name: string;
public readonly description: string;

public readonly client: MCPClient;
private readonly tool: ListToolsResult["tools"][number];

constructor({ client, tool, ...options }: MCPToolInput) {
super(options);
this.client = client;
this.tool = tool;
this.name = tool.name;
this.description = tool.description ?? "No description, use based on name.";
}

public readonly emitter: ToolEmitter<ToolInput<this>, MCPToolOutput> = Emitter.root.child({
namespace: ["tool", "mcp", "tool"],
creator: this,
});

inputSchema() {
return this.tool.inputSchema as SchemaObject;
}

protected async _run(
input: ToolInput<this>,
_options: BaseToolRunOptions,
run: GetRunContext<typeof this>,
) {
const result = await this.client.callTool({ name: this.name, arguments: input }, undefined, {
signal: run.signal,
});
return new MCPToolOutput(result);
}

static async createTools(client: MCPClient): Promise<MCPTool[]> {
const { tools } = await client.listTools();
return tools.map((tool) => new MCPTool({ client, tool }));
}
}
Loading

0 comments on commit 76c500b

Please sign in to comment.