From 429a73f89b427e0b9477fa611b5c11140f63f662 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 12 Nov 2024 15:02:39 +0100 Subject: [PATCH] Initial ComponentTool --- haystack/tools/__init__.py | 4 +- haystack/tools/component_tool.py | 298 ++++++++++++++++ pyproject.toml | 1 + test/tools/test_component_tool.py | 549 ++++++++++++++++++++++++++++++ 4 files changed, 851 insertions(+), 1 deletion(-) create mode 100644 haystack/tools/component_tool.py create mode 100644 test/tools/test_component_tool.py diff --git a/haystack/tools/__init__.py b/haystack/tools/__init__.py index 9cd887f4e2..a1871c0f53 100644 --- a/haystack/tools/__init__.py +++ b/haystack/tools/__init__.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 + +from haystack.tools.component_tool import ComponentTool from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace -__all__ = ["Tool", "_check_duplicate_tool_names", "deserialize_tools_inplace"] +__all__ = ["ComponentTool", "Tool", "_check_duplicate_tool_names", "deserialize_tools_inplace"] diff --git a/haystack/tools/component_tool.py b/haystack/tools/component_tool.py new file mode 100644 index 0000000000..a7f4f336b0 --- /dev/null +++ b/haystack/tools/component_tool.py @@ -0,0 +1,298 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import fields, is_dataclass +from inspect import getdoc +from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin + +from docstring_parser import parse +from pydantic import TypeAdapter + +from haystack import logging +from haystack.core.component import Component +from haystack.tools import Tool + +logger = logging.getLogger(__name__) + + +def is_pydantic_v2_model(instance: Any) -> bool: + """ + Checks if the instance is a Pydantic v2 model. + + :param instance: The instance to check. + :returns: True if the instance is a Pydantic v2 model, False otherwise. + """ + return hasattr(instance, "model_validate") + + +class ComponentTool(Tool): + """ + A Tool that wraps Haystack components, allowing them to be used as tools by LLMs. + + ComponentTool automatically generates OpenAI-compatible function schemas from Component input sockets, + which are derived from the component's `run` method signature and type hints. + + + Key features: + - Automatic LLM tool calling schema generation from Component input sockets + - Type conversion and validation for Component inputs + - Support for complex types (dataclasses, Pydantic models, lists) + - Automatic name generation from Component class name + - Description extraction from Component docstrings + + Let's assume we already have (or want to create) a Haystack component that we want to use as a tool. + We can create a ComponentTool from the component by passing the component to the ComponentTool constructor. + + ```python + from haystack import component + from haystack.tools import ComponentTool + + @component + class WeatherComponent: + '''Gets weather information for a location.''' + + def run(self, city: str, units: str = "celsius"): + ''' + :param city: The city to get weather for + :param units: Temperature units (celsius/fahrenheit) + ''' + return f"Weather in {city}: 20°{units}" + + # Create a tool from the component + weather = WeatherComponent() + tool = ComponentTool( + component=weather, + name="get_weather", # Optional: defaults to snake_case of class name + description="Get current weather for a city" # Optional: defaults to component run method docstring + ) + ``` + + """ + + def __init__(self, component: Component, name: Optional[str] = None, description: Optional[str] = None): + """ + Create a Tool instance from a Haystack component. + + :param component: The Haystack Component to wrap as a tool + :param name: Optional name for the tool (defaults to snake_case of Component class name) + :param description: Optional description (defaults to Component's docstring) + :raises ValueError: If the component is invalid or schema generation fails + """ + if not isinstance(component, Component): + message = ( + f"Object {component!r} is not a Haystack component. " + "Use this method to create a Tool only with Haystack component instances." + ) + raise ValueError(message) + + # Create the tools schema from the component run method parameters + tool_schema = self._create_tool_parameters_schema(component) + + def component_invoker(**kwargs): + """ + Invokes the component using keyword arguments provided by the LLM function calling/tool generated response. + + :param kwargs: The keyword arguments to invoke the component with. + :returns: The result of the component invocation. + """ + converted_kwargs = {} + input_sockets = component.__haystack_input__._sockets_dict + for param_name, param_value in kwargs.items(): + param_type = input_sockets[param_name].type + + # Check if the type (or list element type) has from_dict + target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type + if hasattr(target_type, "from_dict"): + if isinstance(param_value, list): + param_value = [target_type.from_dict(item) for item in param_value if isinstance(item, dict)] + elif isinstance(param_value, dict): + param_value = target_type.from_dict(param_value) + else: + # Let TypeAdapter handle both single values and lists + type_adapter = TypeAdapter(param_type) + param_value = type_adapter.validate_python(param_value) + + converted_kwargs[param_name] = param_value + logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}") + return component.run(**converted_kwargs) + + # Generate a name for the tool if not provided + if not name: + class_name = component.__class__.__name__ + # Convert camelCase/PascalCase to snake_case + name = "".join( + [ + "_" + c.lower() if c.isupper() and i > 0 and not class_name[i - 1].isupper() else c.lower() + for i, c in enumerate(class_name) + ] + ).lstrip("_") + + # Generate a description for the tool if not provided and truncate to 512 characters + # as most LLMs have a limit for the description length + description = (description or component.__doc__ or name)[:512] + + # Create the Tool instance with the component invoker as the function to be called and the schema + super().__init__(name, description, tool_schema, component_invoker) + self.component = component + + def _create_tool_parameters_schema(self, component: Component) -> Dict[str, Any]: + """ + Creates an OpenAI tools schema from a component's run method parameters. + + :param component: The component to create the schema from. + :returns: OpenAI tools schema for the component's run method parameters. + """ + properties = {} + required = [] + + param_descriptions = self._get_param_descriptions(component.run) + + for input_name, socket in component.__haystack_input__._sockets_dict.items(): + input_type = socket.type + description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.") + + try: + property_schema = self._create_property_schema(input_type, description) + except ValueError as e: + raise ValueError(f"Error processing input '{input_name}': {e}") + + properties[input_name] = property_schema + + # Use socket.is_mandatory to check if the input is required + if socket.is_mandatory: + required.append(input_name) + + parameters_schema = {"type": "object", "properties": properties} + + if required: + parameters_schema["required"] = required + + return parameters_schema + + def _get_param_descriptions(self, method: Callable) -> Dict[str, str]: + """ + Extracts parameter descriptions from the method's docstring using docstring_parser. + + :param method: The method to extract parameter descriptions from. + :returns: A dictionary mapping parameter names to their descriptions. + """ + docstring = getdoc(method) + if not docstring: + return {} + + parsed_doc = parse(docstring) + param_descriptions = {} + for param in parsed_doc.params: + if not param.description: + logger.warning( + "Missing description for parameter '%s'. Please add a description in the component's " + "run() method docstring using the format ':param %s: '. " + "This description helps the LLM understand how to use this parameter.", + param.arg_name, + param.arg_name, + ) + param_descriptions[param.arg_name] = param.description.strip() if param.description else "" + return param_descriptions + + def _is_nullable_type(self, python_type: Any) -> bool: + """ + Checks if the type is a Union with NoneType (i.e., Optional). + + :param python_type: The Python type to check. + :returns: True if the type is a Union with NoneType, False otherwise. + """ + origin = get_origin(python_type) + if origin is Union: + return type(None) in get_args(python_type) + return False + + def _create_list_schema(self, item_type: Any, description: str) -> Dict[str, Any]: + """ + Creates a schema for a list type. + + :param item_type: The type of items in the list. + :param description: The description of the list. + :returns: A dictionary representing the list schema. + """ + items_schema = self._create_property_schema(item_type, "") + items_schema.pop("description", None) + return {"type": "array", "description": description, "items": items_schema} + + def _create_dataclass_schema(self, python_type: Any, description: str) -> Dict[str, Any]: + """ + Creates a schema for a dataclass. + + :param python_type: The dataclass type. + :param description: The description of the dataclass. + :returns: A dictionary representing the dataclass schema. + """ + schema = {"type": "object", "description": description, "properties": {}} + cls = python_type if isinstance(python_type, type) else python_type.__class__ + for field in fields(cls): + field_description = f"Field '{field.name}' of '{cls.__name__}'." + if isinstance(schema["properties"], dict): + schema["properties"][field.name] = self._create_property_schema(field.type, field_description) + return schema + + def _create_pydantic_schema(self, python_type: Any, description: str) -> Dict[str, Any]: + """ + Creates a schema for a Pydantic model. + + :param python_type: The Pydantic model type. + :param description: The description of the model. + :returns: A dictionary representing the Pydantic model schema. + """ + schema = {"type": "object", "description": description, "properties": {}} + required_fields = [] + + for m_name, m_field in python_type.model_fields.items(): + field_description = f"Field '{m_name}' of '{python_type.__name__}'." + if isinstance(schema["properties"], dict): + schema["properties"][m_name] = self._create_property_schema(m_field.annotation, field_description) + if m_field.is_required(): + required_fields.append(m_name) + + if required_fields: + schema["required"] = required_fields + return schema + + def _create_basic_type_schema(self, python_type: Any, description: str) -> Dict[str, Any]: + """ + Creates a schema for a basic Python type. + + :param python_type: The Python type. + :param description: The description of the type. + :returns: A dictionary representing the basic type schema. + """ + type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"} + return {"type": type_mapping.get(python_type, "string"), "description": description} + + def _create_property_schema(self, python_type: Any, description: str, default: Any = None) -> Dict[str, Any]: + """ + Creates a property schema for a given Python type, recursively if necessary. + + :param python_type: The Python type to create a property schema for. + :param description: The description of the property. + :param default: The default value of the property. + :returns: A dictionary representing the property schema. + """ + nullable = self._is_nullable_type(python_type) + if nullable: + non_none_types = [t for t in get_args(python_type) if t is not type(None)] + python_type = non_none_types[0] if non_none_types else str + + origin = get_origin(python_type) + if origin is list: + schema = self._create_list_schema(get_args(python_type)[0] if get_args(python_type) else Any, description) + elif is_dataclass(python_type): + schema = self._create_dataclass_schema(python_type, description) + elif is_pydantic_v2_model(python_type): + schema = self._create_pydantic_schema(python_type, description) + else: + schema = self._create_basic_type_schema(python_type, description) + + if default is not None: + schema["default"] = default + + return schema diff --git a/pyproject.toml b/pyproject.toml index 73031b8130..2339c4e00d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ dependencies = [ "requests", "numpy", "python-dateutil", + "docstring-parser", "haystack-experimental", ] diff --git a/test/tools/test_component_tool.py b/test/tools/test_component_tool.py new file mode 100644 index 0000000000..e91c27541a --- /dev/null +++ b/test/tools/test_component_tool.py @@ -0,0 +1,549 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +import os +import pytest +from typing import Dict, List, Optional, Any +from dataclasses import dataclass +from haystack import component +from pydantic import BaseModel +from haystack import Pipeline +from haystack.dataclasses import Document +from haystack.dataclasses import ChatMessage, ChatRole +from haystack.components.tools.tool_invoker import ToolInvoker +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.tools import ComponentTool + + +### Component and Model Definitions + + +@component +class SimpleComponent: + """A simple component that generates text.""" + + @component.output_types(reply=str) + def run(self, text: str) -> Dict[str, str]: + """ + A simple component that generates text. + + :param text: user's name + :return: A dictionary with the generated text. + """ + return {"reply": f"Hello, {text}!"} + + +class Product(BaseModel): + """A product model.""" + + name: str + price: float + + +@dataclass +class User: + """A simple user dataclass.""" + + name: str = "Anonymous" + age: int = 0 + + +@component +class UserGreeter: + """A simple component that processes a User.""" + + @component.output_types(message=str) + def run(self, user: User) -> Dict[str, str]: + """ + A simple component that processes a User. + + :param user: The User object to process. + :return: A dictionary with a message about the user. + """ + return {"message": f"User {user.name} is {user.age} years old"} + + +@component +class ListProcessor: + """A component that processes a list of strings.""" + + @component.output_types(concatenated=str) + def run(self, texts: List[str]) -> Dict[str, str]: + """ + Concatenates a list of strings into a single string. + + :param texts: The list of strings to concatenate. + :return: A dictionary with the concatenated string. + """ + return {"concatenated": " ".join(texts)} + + +@component +class ProductProcessor: + """A component that processes a Product.""" + + @component.output_types(description=str) + def run(self, product: Product) -> Dict[str, str]: + """ + Creates a description for the product. + + :param product: The Product to process. + :return: A dictionary with the product description. + """ + return {"description": f"The product {product.name} costs ${product.price:.2f}."} + + +@dataclass +class Address: + """A dataclass representing a physical address.""" + + street: str + city: str + + +@dataclass +class Person: + """A person with an address.""" + + name: str + address: Address + + +@component +class PersonProcessor: + """A component that processes a Person with nested Address.""" + + @component.output_types(info=str) + def run(self, person: Person) -> Dict[str, str]: + """ + Creates information about the person. + + :param person: The Person to process. + :return: A dictionary with the person's information. + """ + return {"info": f"{person.name} lives at {person.address.street}, {person.address.city}."} + + +@component +class DocumentProcessor: + """A component that processes a list of Documents.""" + + @component.output_types(concatenated=str) + def run(self, documents: List[Document], top_k: int = 5) -> Dict[str, str]: + """ + Concatenates the content of multiple documents with newlines. + + :param documents: List of Documents whose content will be concatenated + :param top_k: The number of top documents to concatenate + :returns: Dictionary containing the concatenated document contents + """ + return {"concatenated": "\n".join(doc.content for doc in documents[:top_k])} + + +## Unit tests +class TestToolComponent: + def test_from_component_basic(self): + component = SimpleComponent() + + tool = ComponentTool(component=component) + + assert tool.name == "simple_component" + assert tool.description == "A simple component that generates text." + assert tool.parameters == { + "type": "object", + "properties": {"text": {"type": "string", "description": "user's name"}}, + "required": ["text"], + } + + # Test tool invocation + result = tool.invoke(text="world") + assert isinstance(result, dict) + assert "reply" in result + assert result["reply"] == "Hello, world!" + + def test_from_component_with_dataclass(self): + component = UserGreeter() + + tool = ComponentTool(component=component) + assert tool.parameters == { + "type": "object", + "properties": { + "user": { + "type": "object", + "description": "The User object to process.", + "properties": { + "name": {"type": "string", "description": "Field 'name' of 'User'."}, + "age": {"type": "integer", "description": "Field 'age' of 'User'."}, + }, + } + }, + "required": ["user"], + } + + assert tool.name == "user_greeter" + assert tool.description == "A simple component that processes a User." + + # Test tool invocation + result = tool.invoke(user={"name": "Alice", "age": 30}) + assert isinstance(result, dict) + assert "message" in result + assert result["message"] == "User Alice is 30 years old" + + def test_from_component_with_list_input(self): + component = ListProcessor() + + tool = ComponentTool( + component=component, name="list_processing_tool", description="A tool that concatenates strings" + ) + + assert tool.parameters == { + "type": "object", + "properties": { + "texts": { + "type": "array", + "description": "The list of strings to concatenate.", + "items": {"type": "string"}, + } + }, + "required": ["texts"], + } + + # Test tool invocation + result = tool.invoke(texts=["hello", "world"]) + assert isinstance(result, dict) + assert "concatenated" in result + assert result["concatenated"] == "hello world" + + def test_from_component_with_pydantic_model(self): + component = ProductProcessor() + + tool = ComponentTool(component=component, name="product_tool", description="A tool that processes products") + + assert tool.parameters == { + "type": "object", + "properties": { + "product": { + "type": "object", + "description": "The Product to process.", + "properties": { + "name": {"type": "string", "description": "Field 'name' of 'Product'."}, + "price": {"type": "number", "description": "Field 'price' of 'Product'."}, + }, + "required": ["name", "price"], + } + }, + "required": ["product"], + } + + # Test tool invocation + result = tool.invoke(product={"name": "Widget", "price": 19.99}) + assert isinstance(result, dict) + assert "description" in result + assert result["description"] == "The product Widget costs $19.99." + + def test_from_component_with_nested_dataclass(self): + component = PersonProcessor() + + tool = ComponentTool(component=component, name="person_tool", description="A tool that processes people") + + assert tool.parameters == { + "type": "object", + "properties": { + "person": { + "type": "object", + "description": "The Person to process.", + "properties": { + "name": {"type": "string", "description": "Field 'name' of 'Person'."}, + "address": { + "type": "object", + "description": "Field 'address' of 'Person'.", + "properties": { + "street": {"type": "string", "description": "Field 'street' of 'Address'."}, + "city": {"type": "string", "description": "Field 'city' of 'Address'."}, + }, + }, + }, + } + }, + "required": ["person"], + } + + # Test tool invocation + result = tool.invoke(person={"name": "Diana", "address": {"street": "123 Elm Street", "city": "Metropolis"}}) + assert isinstance(result, dict) + assert "info" in result + assert result["info"] == "Diana lives at 123 Elm Street, Metropolis." + + def test_from_component_with_document_list(self): + component = DocumentProcessor() + + tool = ComponentTool( + component=component, name="document_processor", description="A tool that concatenates document contents" + ) + + assert tool.parameters == { + "type": "object", + "properties": { + "documents": { + "type": "array", + "description": "List of Documents whose content will be concatenated", + "items": { + "type": "object", + "properties": { + "id": {"type": "string", "description": "Field 'id' of 'Document'."}, + "content": {"type": "string", "description": "Field 'content' of 'Document'."}, + "dataframe": {"type": "string", "description": "Field 'dataframe' of 'Document'."}, + "blob": { + "type": "object", + "description": "Field 'blob' of 'Document'.", + "properties": { + "data": {"type": "string", "description": "Field 'data' of 'ByteStream'."}, + "meta": {"type": "string", "description": "Field 'meta' of 'ByteStream'."}, + "mime_type": { + "type": "string", + "description": "Field 'mime_type' of 'ByteStream'.", + }, + }, + }, + "meta": {"type": "string", "description": "Field 'meta' of 'Document'."}, + "score": {"type": "number", "description": "Field 'score' of 'Document'."}, + "embedding": { + "type": "array", + "description": "Field 'embedding' of 'Document'.", + "items": {"type": "number"}, + }, + "sparse_embedding": { + "type": "object", + "description": "Field 'sparse_embedding' of 'Document'.", + "properties": { + "indices": { + "type": "array", + "description": "Field 'indices' of 'SparseEmbedding'.", + "items": {"type": "integer"}, + }, + "values": { + "type": "array", + "description": "Field 'values' of 'SparseEmbedding'.", + "items": {"type": "number"}, + }, + }, + }, + }, + }, + }, + "top_k": {"description": "The number of top documents to concatenate", "type": "integer"}, + }, + "required": ["documents"], + } + + # Test tool invocation + result = tool.invoke(documents=[{"content": "First document"}, {"content": "Second document"}]) + assert isinstance(result, dict) + assert "concatenated" in result + assert result["concatenated"] == "First document\nSecond document" + + def test_from_component_with_non_component(self): + class NotAComponent: + def foo(self, text: str): + return {"reply": f"Hello, {text}!"} + + not_a_component = NotAComponent() + + with pytest.raises(ValueError): + ComponentTool(component=not_a_component, name="invalid_tool", description="This should fail") + + +## Integration tests +class TestToolComponentInPipelineWithOpenAI: + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_component_tool_in_pipeline(self): + # Create component and convert it to tool + component = SimpleComponent() + tool = ComponentTool( + component=component, name="hello_tool", description="A tool that generates a greeting message for the user" + ) + + # Create pipeline with OpenAIChatGenerator and ToolInvoker + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + + # Connect components + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Vladimir") + + # Run pipeline + result = pipeline.run({"llm": {"messages": [message]}}) + + # Check results + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert "Vladimir" in tool_message.tool_call_result.result + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_user_greeter_in_pipeline(self): + component = UserGreeter() + tool = ComponentTool( + component=component, name="user_greeter", description="A tool that greets users with their name and age" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="I am Alice and I'm 30 years old") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert tool_message.tool_call_result.result == str({"message": "User Alice is 30 years old"}) + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_list_processor_in_pipeline(self): + component = ListProcessor() + tool = ComponentTool( + component=component, name="list_processor", description="A tool that concatenates a list of strings" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Can you join these words: hello, beautiful, world") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert tool_message.tool_call_result.result == str({"concatenated": "hello beautiful world"}) + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_product_processor_in_pipeline(self): + component = ProductProcessor() + tool = ComponentTool( + component=component, + name="product_processor", + description="A tool that creates a description for a product with its name and price", + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Can you describe a product called Widget that costs $19.99?") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert tool_message.tool_call_result.result == str({"description": "The product Widget costs $19.99."}) + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_person_processor_in_pipeline(self): + component = PersonProcessor() + tool = ComponentTool( + component=component, + name="person_processor", + description="A tool that processes information about a person and their address", + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user(text="Diana lives at 123 Elm Street in Metropolis") + + result = pipeline.run({"llm": {"messages": [message]}}) + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + assert "Diana" in tool_message.tool_call_result.result and "Metropolis" in tool_message.tool_call_result.result + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_document_processor_in_pipeline(self): + component = DocumentProcessor() + tool = ComponentTool( + component=component, + name="document_processor", + description="A tool that concatenates the content of multiple documents", + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool], convert_result_to_json_string=True)) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user( + text="Concatenate these documents: First one says 'Hello world' and second one says 'Goodbye world' and third one says 'Hello again', but use top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, dataframe, blob fields." + ) + + result = pipeline.run({"llm": {"messages": [message]}}) + + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + result = json.loads(tool_message.tool_call_result.result) + assert "concatenated" in result + assert "Hello world" in result["concatenated"] + assert "Goodbye world" in result["concatenated"] + assert not tool_message.tool_call_result.error + + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_lost_in_middle_ranker_in_pipeline(self): + from haystack.components.rankers import LostInTheMiddleRanker + + component = LostInTheMiddleRanker() + tool = ComponentTool( + component=component, + name="lost_in_middle_ranker", + description="A tool that ranks documents using the Lost in the Middle algorithm and returns top k results", + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool])) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user( + text="I have three documents with content: 'First doc', 'Middle doc', and 'Last doc'. Rank them top_k=2. Set only content field of the document only. Do not set id, meta, score, embedding, sparse_embedding, dataframe, blob fields." + ) + + result = pipeline.run({"llm": {"messages": [message]}}) + + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL)