From ccfcfe5322fd05d2ea72bb511145fc109a210ba3 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 20 Jan 2025 17:01:34 +0100 Subject: [PATCH 01/10] AmazonBedrockChatGenerator - add tools support --- integrations/amazon_bedrock/pyproject.toml | 1 + .../amazon_bedrock/chat/chat_generator.py | 162 +++++++++-- .../tests/test_chat_generator.py | 274 +++++++++++++----- 3 files changed, 338 insertions(+), 99 deletions(-) diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index ad9754a33..a119a14cb 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "pytest", "pytest-rerunfailures", "haystack-pydoc-tools", + "jsonschema", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index bcf11414c..162b9a6ce 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -1,12 +1,13 @@ import json import logging -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple from botocore.config import Config from botocore.eventstream import EventStream from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict -from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall +from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.utils.auth import Secret, deserialize_secrets_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable @@ -19,6 +20,81 @@ logger = logging.getLogger(__name__) +def _convert_tools_to_bedrock_format(tools: Optional[List[Tool]] = None) -> Optional[Dict[str, Any]]: + """ + Convert Haystack Tool(s) to Amazon Bedrock toolConfig format. + + :param tools: List of Tool objects to convert + :return: Dictionary in Bedrock toolConfig format or None if no tools + """ + if not tools: + return None + + tool_specs = [] + for tool in tools: + tool_specs.append( + {"toolSpec": {"name": tool.name, "description": tool.description, "inputSchema": {"json": tool.parameters}}} + ) + + return {"tools": tool_specs} if tool_specs else None + + +def _convert_to_bedrock_format(messages: List[ChatMessage]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """ + Convert a list of ChatMessages to the format expected by Bedrock API. + Separates system messages and handles tool results and tool calls. 
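+    For example (a sketch of the mapping, grounded in the code below), a plain user
+    message is mapped to `{"role": "user", "content": [{"text": "..."}]}`, while
+    system messages are collected separately as `{"text": "..."}` entries for
+    Bedrock's `system` field.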
+
+    :param messages: List of ChatMessages to convert
+    :return: Tuple of (system_prompts, non_system_messages) in Bedrock format
+    """
+    system_prompts = []
+    non_system_messages = []
+
+    for msg in messages:
+        if msg.is_from(ChatRole.SYSTEM):
+            system_prompts.append({"text": msg.text})
+            continue
+
+        # Handle tool results - these must be sent back as user messages
+        if msg.tool_call_results:
+            tool_results = []
+            for result in msg.tool_call_results:
+                try:
+                    json_result = json.loads(result.result)
+                    content = [{"json": json_result}]
+                except json.JSONDecodeError:
+                    content = [{"text": result.result}]
+
+                tool_results.append(
+                    {
+                        "toolResult": {
+                            "toolUseId": result.origin.id,
+                            "content": content,
+                            **({"status": "error"} if result.error else {}),
+                        }
+                    }
+                )
+            non_system_messages.append({"role": "user", "content": tool_results})
+            continue
+
+        content = []
+        # Handle text content
+        if msg.text:
+            content.append({"text": msg.text})
+
+        # Handle tool calls
+        if msg.tool_calls:
+            for tool_call in msg.tool_calls:
+                content.append(
+                    {"toolUse": {"toolUseId": tool_call.id, "name": tool_call.tool_name, "input": tool_call.arguments}}
+                )
+
+        if content:  # Only add message if it has content
+            non_system_messages.append({"role": msg.role.value, "content": content})
+
+    return system_prompts, non_system_messages
+
+
 @component
 class AmazonBedrockChatGenerator:
     """
@@ -70,6 +146,7 @@ def __init__(
         stop_words: Optional[List[str]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         boto3_config: Optional[Dict[str, Any]] = None,
+        tools: Optional[List[Tool]] = None,
     ):
         """
         Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the
@@ -103,6 +180,7 @@ def __init__(
             [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and
             switches the streaming mode on.
         :param boto3_config: The configuration for the boto3 client.
+        :param tools: A list of Tool objects that the model can use. Each tool should have a unique name.
 
         :raises ValueError: If the model name is empty or None.
         :raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is
@@ -120,6 +198,8 @@ def __init__(
         self.stop_words = stop_words or []
         self.streaming_callback = streaming_callback
         self.boto3_config = boto3_config
+        _check_duplicate_tool_names(tools)
+        self.tools = tools
 
         def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
             return secret.resolve_value() if secret else None
@@ -155,6 +235,7 @@ def to_dict(self) -> Dict[str, Any]:
             Dictionary with serialized data.
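+            Tools are serialized with `Tool.to_dict` and restored again via
+            `deserialize_tools_inplace` in `from_dict`.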
""" callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None return default_to_dict( self, aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None, @@ -167,6 +248,7 @@ def to_dict(self) -> Dict[str, Any]: generation_kwargs=self.generation_kwargs, streaming_callback=callback_name, boto3_config=self.boto3_config, + tools=serialized_tools, ) @classmethod @@ -186,6 +268,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator": data["init_parameters"], ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], ) + deserialize_tools_inplace(data["init_parameters"], key="tools") return default_from_dict(cls, data) @component.output_types(replies=List[ChatMessage]) @@ -194,6 +277,7 @@ def run( messages: List[ChatMessage], streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, ): generation_kwargs = generation_kwargs or {} @@ -209,20 +293,19 @@ def run( if key in merged_kwargs } - # Extract tool configuration if present - # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html + # Handle tools - either toolConfig or Haystack Tool objects but not both + tools = tools or self.tools + _check_duplicate_tool_names(tools) tool_config = merged_kwargs.pop("toolConfig", None) + if tools: + # Convert Haystack tools to Bedrock format + tool_config = _convert_tools_to_bedrock_format(tools) # Any remaining kwargs go to additionalModelRequestFields additional_fields = merged_kwargs if merged_kwargs else None - # Prepare system prompts and messages - system_prompts = [] - if messages and messages[0].is_from(ChatRole.SYSTEM): - system_prompts = [{"text": messages[0].text}] - messages = messages[1:] - - messages_list = [{"role": msg.role.value, "content": [{"text": msg.text}]} for msg in messages] + # Convert messages to Bedrock format + system_prompts, messages_list = _convert_to_bedrock_format(messages) # Build API parameters params = { @@ -256,6 +339,12 @@ def run( return {"replies": replies} def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: + """ + Extract ChatMessage replies from a Bedrock response. + + :param response_body: Raw response from Bedrock API + :return: List of ChatMessage objects + """ replies = [] if "output" in response_body and "message" in response_body["output"]: message = response_body["output"]["message"] @@ -280,17 +369,30 @@ def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[C if "text" in content_block: replies.append(ChatMessage.from_assistant(content_block["text"], meta=base_meta.copy())) elif "toolUse" in content_block: - replies.append( - ChatMessage.from_assistant(json.dumps(content_block["toolUse"]), meta=base_meta.copy()) + # Convert tool use to ToolCall + tool_use = content_block["toolUse"] + tool_call = ToolCall( + id=tool_use.get("toolUseId"), + tool_name=tool_use.get("name"), + arguments=tool_use.get("input", {}), ) + replies.append(ChatMessage.from_assistant("", tool_calls=[tool_call], meta=base_meta.copy())) + return replies def process_streaming_response( self, response_stream: EventStream, streaming_callback: Callable[[StreamingChunk], None] ) -> List[ChatMessage]: + """ + Process a streaming response from Bedrock. 
+ + :param response_stream: EventStream from Bedrock API + :param streaming_callback: Callback for streaming chunks + :return: List of ChatMessage objects + """ replies = [] current_content = "" - current_tool_use = None + current_tool_call: Optional[Dict[str, Any]] = None base_meta = { "model": self.model, "index": 0, @@ -300,14 +402,14 @@ def process_streaming_response( if "contentBlockStart" in event: # Reset accumulators for new message current_content = "" - current_tool_use = None + current_tool_call = None block_start = event["contentBlockStart"] if "start" in block_start and "toolUse" in block_start["start"]: tool_start = block_start["start"]["toolUse"] - current_tool_use = { - "toolUseId": tool_start["toolUseId"], + current_tool_call = { + "id": tool_start["toolUseId"], "name": tool_start["name"], - "input": "", # Will accumulate deltas as string + "arguments": "", # Will accumulate deltas as string } elif "contentBlockDelta" in event: @@ -316,34 +418,38 @@ def process_streaming_response( delta_text = delta["text"] current_content += delta_text streaming_chunk = StreamingChunk(content=delta_text, meta=None) - # it only makes sense to call callback on text deltas streaming_callback(streaming_chunk) - elif "toolUse" in delta and current_tool_use: + elif "toolUse" in delta and current_tool_call: # Accumulate tool use input deltas - current_tool_use["input"] += delta["toolUse"].get("input", "") + current_tool_call["arguments"] += delta["toolUse"].get("input", "") + elif "contentBlockStop" in event: - if current_tool_use: + if current_tool_call: # Parse accumulated input if it's a JSON string try: - input_json = json.loads(current_tool_use["input"]) - current_tool_use["input"] = input_json + input_json = json.loads(current_tool_call["arguments"]) + current_tool_call["arguments"] = input_json except json.JSONDecodeError: # Keep as string if not valid JSON pass - tool_content = json.dumps(current_tool_use) - replies.append(ChatMessage.from_assistant(tool_content, meta=base_meta.copy())) + tool_call = ToolCall( + id=current_tool_call["id"], + tool_name=current_tool_call["name"], + arguments=current_tool_call["arguments"], + ) + replies.append(ChatMessage.from_assistant("", tool_calls=[tool_call], meta=base_meta.copy())) elif current_content: replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy())) elif "messageStop" in event: - # not 100% correct for multiple messages but no way around it + # Update finish reason for all replies for reply in replies: reply.meta["finish_reason"] = event["messageStop"].get("stopReason") elif "metadata" in event: metadata = event["metadata"] - # not 100% correct for multiple messages but no way around it + # Update usage stats for all replies for reply in replies: if "usage" in metadata: usage = metadata["usage"] diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index c2122163c..069e24d43 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -1,9 +1,9 @@ -import json from typing import Any, Dict, Optional import pytest from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk +from haystack.tools import Tool from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator @@ -32,6 +32,18 @@ def chat_messages(): return messages +@pytest.fixture +def tools(): + 
tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=lambda x: x, + ) + return [tool] + + @pytest.mark.parametrize( "boto3_config", [ @@ -64,6 +76,7 @@ def test_to_dict(mock_boto3_session, boto3_config): "stop_words": [], "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "boto3_config": boto3_config, + "tools": None, }, } @@ -96,6 +109,7 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any "generation_kwargs": {"temperature": 0.7}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "boto3_config": boto3_config, + "tools": None, }, } ) @@ -206,7 +220,9 @@ def streaming_callback(chunk: StreamingChunk): @pytest.mark.integration def test_tools_use(self, model_name): """ - Test function calling with AWS Bedrock Anthropic adapter + Test tools use with passing the generation_kwargs={"toolConfig": tool_config} + and not the tools parameter. We support this because some users might want to use the toolConfig + parameter to pass the tool configuration to the model. """ # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html tool_config = { @@ -244,43 +260,33 @@ def test_tools_use(self, model_name): assert isinstance(replies, list), "Replies is not a list" assert len(replies) > 0, "No replies received" - first_reply = replies[0] - assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.text, "First reply has no content" - assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert first_reply.meta, "First reply has no metadata" - - # Some models return thinking message as first and the second one as the tool call - if len(replies) > 1: - second_reply = replies[1] - assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" - assert second_reply.text, "Second reply has no content" - assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" - tool_call = json.loads(second_reply.text) - assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" - assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" - assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert ( - tool_call["input"]["sign"] == "WZPZ" - ), f"Tool call {tool_call} does not contain the correct 'input' value" - else: - # case where the model returns the tool call as the first message - # double check that the tool call is correct - tool_call = json.loads(first_reply.text) - assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" - assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" - assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert ( - tool_call["input"]["sign"] == "WZPZ" - ), f"Tool call {tool_call} does not contain the correct 'input' value" + # Find the message with tool calls as in some models it is the first message, in some second + tool_message = None + for message in replies: + if message.tool_call: # Using tool_call instead of tool_calls to match existing code + tool_message = message + break + + assert 
tool_message is not None, "No message with tool call found" + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_call = tool_message.tool_call + assert tool_call.id, "Tool call does not contain value for 'id' key" + assert tool_call.tool_name == "top_song", f"{tool_call} does not contain the correct 'tool_name' value" + assert tool_call.arguments, f"Tool call {tool_call} does not contain 'arguments' value" + assert ( + tool_call.arguments["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'arguments' value" @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS) @pytest.mark.integration def test_tools_use_with_streaming(self, model_name): """ - Test function calling with AWS Bedrock Anthropic adapter + Test tools use with streaming but with passing the generation_kwargs={"toolConfig": tool_config} + and not the tools parameter. We support this because some users might want to use the toolConfig + parameter to pass the tool configuration to the model. """ - # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html tool_config = { "tools": [ { @@ -304,12 +310,10 @@ def test_tools_use_with_streaming(self, model_name): } } ], - # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html "toolChoice": {"auto": {}}, } - messages = [] - messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?")) + messages = [ChatMessage.from_user("What is the most popular song on WZPZ?")] client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=print_streaming_chunk) response = client.run(messages=messages, generation_kwargs={"toolConfig": tool_config}) replies = response["replies"] @@ -322,29 +326,22 @@ def test_tools_use_with_streaming(self, model_name): assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" assert first_reply.meta, "First reply has no metadata" - # Some models return thinking message as first and the second one as the tool call - if len(replies) > 1: - second_reply = replies[1] - assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" - assert second_reply.text, "Second reply has no content" - assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" - tool_call = json.loads(second_reply.text) - assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" - assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" - assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert ( - tool_call["input"]["sign"] == "WZPZ" - ), f"Tool call {tool_call} does not contain the correct 'input' value" - else: - # case where the model returns the tool call as the first message - # double check that the tool call is correct - tool_call = json.loads(first_reply.text) - assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" - assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" - assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert ( - tool_call["input"]["sign"] == "WZPZ" - ), f"Tool call {tool_call} does not contain the correct 'input' value" + # Find the message containing the tool call + 
tool_message = None
+        for message in replies:
+            if message.tool_call:
+                tool_message = message
+                break
+
+        assert tool_message is not None, "No message with tool call found"
+        assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance"
+        assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant"
+
+        tool_call = tool_message.tool_call
+        assert tool_call.id, "Tool call does not contain value for 'id' key"
+        assert tool_call.tool_name == "top_song", f"{tool_call} does not contain the correct 'tool_name' value"
+        assert tool_call.arguments, f"{tool_call} does not contain 'arguments' value"
+        assert tool_call.arguments["sign"] == "WZPZ", f"{tool_call} does not contain the correct 'arguments' value"
 
     def test_extract_replies_from_response(self, mock_boto3_session):
         """
@@ -381,10 +378,10 @@ def test_extract_replies_from_response(self, mock_boto3_session):
         replies = generator.extract_replies_from_response(tool_response)
         assert len(replies) == 1
-        tool_content = json.loads(replies[0].text)
-        assert tool_content["toolUseId"] == "123"
-        assert tool_content["name"] == "test_tool"
-        assert tool_content["input"] == {"key": "value"}
+        tool_content = replies[0].tool_call
+        assert tool_content.id == "123"
+        assert tool_content.tool_name == "test_tool"
+        assert tool_content.arguments == {"key": "value"}
         assert replies[0].meta["finish_reason"] == "tool_call"
         assert replies[0].meta["usage"] == {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40}
 
@@ -406,10 +403,10 @@ def test_extract_replies_from_response(self, mock_boto3_session):
         replies = generator.extract_replies_from_response(mixed_response)
         assert len(replies) == 2
         assert replies[0].text == "Let me help you with that. I'll use the search tool to find the answer."
-        tool_content = json.loads(replies[1].text)
-        assert tool_content["toolUseId"] == "456"
-        assert tool_content["name"] == "search_tool"
-        assert tool_content["input"] == {"query": "test"}
+        tool_content = replies[1].tool_call
+        assert tool_content.id == "456"
+        assert tool_content.tool_name == "search_tool"
+        assert tool_content.arguments == {"query": "test"}
 
     def test_process_streaming_response(self, mock_boto3_session):
         """
@@ -452,7 +449,142 @@ def test_callback(chunk: StreamingChunk):
         assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
 
         # Check tool use reply
-        tool_content = json.loads(replies[1].text)
-        assert tool_content["toolUseId"] == "123"
-        assert tool_content["name"] == "search_tool"
-        assert tool_content["input"] == {"query": "test"}
+        tool_content = replies[1].tool_call
+        assert tool_content.id == "123"
+        assert tool_content.tool_name == "search_tool"
+        assert tool_content.arguments == {"query": "test"}
+
+    @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS)
+    @pytest.mark.integration
+    def test_live_run_with_tools(self, model_name, tools):
+        """
+        Integration test that the AmazonBedrockChatGenerator component can run with tools. Here we are using the
+        Haystack tools parameter to pass the tool configuration to the model.
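+        The expected flow: the model first replies with a `weather` tool call, the
+        test feeds back a `ChatMessage.from_tool` result, and the model then answers
+        with final text mentioning Paris.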
+        """
+        initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
+        component = AmazonBedrockChatGenerator(model=model_name, tools=tools)
+        results = component.run(messages=initial_messages)
+
+        assert len(results["replies"]) > 0, "No replies received"
+
+        # Find the message with tool calls
+        tool_message = None
+        for message in results["replies"]:
+            if message.tool_call:
+                tool_message = message
+                break
+
+        assert tool_message is not None, "No message with tool call found"
+        assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance"
+        assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant"
+
+        tool_call = tool_message.tool_call
+        assert tool_call.id, "Tool call does not contain value for 'id' key"
+        assert tool_call.tool_name == "weather"
+        assert tool_call.arguments == {"city": "Paris"}
+        assert tool_message.meta["finish_reason"] == "tool_use"
+
+        new_messages = [
+            initial_messages[0],
+            tool_message,
+            ChatMessage.from_tool(tool_result="22° C", origin=tool_call),
+        ]
+        # Pass the tool result to the model to get the final response
+        results = component.run(new_messages)
+
+        assert len(results["replies"]) == 1
+        final_message = results["replies"][0]
+        assert not final_message.tool_call
+        assert len(final_message.text) > 0
+        assert "paris" in final_message.text.lower()
+
+    @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS)
+    @pytest.mark.integration
+    def test_live_run_with_tools_streaming(self, model_name, tools):
+        """
+        Integration test that the AmazonBedrockChatGenerator component can run with the Haystack tools parameter
+        and the streaming_callback parameter to get a streaming response.
+        """
+        initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
+        component = AmazonBedrockChatGenerator(model=model_name, tools=tools, streaming_callback=print_streaming_chunk)
+        results = component.run(messages=initial_messages)
+
+        assert len(results["replies"]) > 0, "No replies received"
+
+        # Find the message with tool calls
+        tool_message = None
+        for message in results["replies"]:
+            if message.tool_call:
+                tool_message = message
+                break
+
+        assert tool_message is not None, "No message with tool call found"
+        assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance"
+        assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant"
+
+        tool_call = tool_message.tool_call
+        assert tool_call.id, "Tool call does not contain value for 'id' key"
+        assert tool_call.tool_name == "weather"
+        assert tool_call.arguments == {"city": "Paris"}
+        assert tool_message.meta["finish_reason"] == "tool_use"
+
+        new_messages = [
+            initial_messages[0],
+            tool_message,
+            ChatMessage.from_tool(tool_result="22° C", origin=tool_call),
+        ]
+        # Pass the tool result to the model to get the final response
+        results = component.run(new_messages)
+
+        assert len(results["replies"]) == 1
+        final_message = results["replies"][0]
+        assert not final_message.tool_call
+        assert len(final_message.text) > 0
+        assert "paris" in final_message.text.lower()
+
+    # @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS)
+    # @pytest.mark.integration
+    # def test_live_run_with_parallel_tools(self, model_name, tools):
+    #     """
+    #     Integration test that the AmazonBedrockChatGenerator component can run with parallel tools.
+ # """ + # initial_messages = [ChatMessage.from_user("What's the weather like in Paris and Berlin?")] + # component = AmazonBedrockChatGenerator(model=model_name, tools=tools) + # results = component.run(messages=initial_messages) + + # assert len(results["replies"]) == 1 + # message = results["replies"][0] + + # # Check tool calls + # assert len(message.tool_calls) == 2 + # tool_call_paris = message.tool_calls[0] + # assert isinstance(tool_call_paris, ToolCall) + # assert tool_call_paris.id is not None + # assert tool_call_paris.tool_name == "weather" + # assert tool_call_paris.arguments["city"] in {"Paris", "Berlin"} + # assert message.meta["finish_reason"] == "tool_use" + + # tool_call_berlin = message.tool_calls[1] + # assert isinstance(tool_call_berlin, ToolCall) + # assert tool_call_berlin.id is not None + # assert tool_call_berlin.tool_name == "weather" + # assert tool_call_berlin.arguments["city"] in {"Berlin", "Paris"} + + # # Send results from both tools + # new_messages = [ + # *initial_messages, + # message, + # ChatMessage.from_tool(tool_result="22° C", origin=tool_call_paris, error=False), + # ChatMessage.from_tool(tool_result="12° C", origin=tool_call_berlin, error=False), + # ] + + # # Get final response + # results = component.run(new_messages) + # message = results["replies"][0] + # assert not message.tool_calls + # assert len(message.text) > 0 + # assert "paris" in message.text.lower() + # assert "berlin" in message.text.lower() + # assert "22°" in message.text + # assert "12°" in message.text + # assert message.meta["finish_reason"] == "end_turn" From 85ca82f59cd61e69b33a3fcbb5762a9dea29bb0e Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 20 Jan 2025 17:07:39 +0100 Subject: [PATCH 02/10] Remove test not needed --- .../tests/test_chat_generator.py | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 069e24d43..7d799c7b5 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -541,50 +541,3 @@ def test_live_run_with_tools_streaming(self, model_name, tools): assert not final_message.tool_call assert len(final_message.text) > 0 assert "paris" in final_message.text.lower() - - # @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS) - # @pytest.mark.integration - # def test_live_run_with_parallel_tools(self, model_name, tools): - # """ - # Integration test that the AmazonBedrockChatGenerator component can run with parallel tools. 
- # """ - # initial_messages = [ChatMessage.from_user("What's the weather like in Paris and Berlin?")] - # component = AmazonBedrockChatGenerator(model=model_name, tools=tools) - # results = component.run(messages=initial_messages) - - # assert len(results["replies"]) == 1 - # message = results["replies"][0] - - # # Check tool calls - # assert len(message.tool_calls) == 2 - # tool_call_paris = message.tool_calls[0] - # assert isinstance(tool_call_paris, ToolCall) - # assert tool_call_paris.id is not None - # assert tool_call_paris.tool_name == "weather" - # assert tool_call_paris.arguments["city"] in {"Paris", "Berlin"} - # assert message.meta["finish_reason"] == "tool_use" - - # tool_call_berlin = message.tool_calls[1] - # assert isinstance(tool_call_berlin, ToolCall) - # assert tool_call_berlin.id is not None - # assert tool_call_berlin.tool_name == "weather" - # assert tool_call_berlin.arguments["city"] in {"Berlin", "Paris"} - - # # Send results from both tools - # new_messages = [ - # *initial_messages, - # message, - # ChatMessage.from_tool(tool_result="22° C", origin=tool_call_paris, error=False), - # ChatMessage.from_tool(tool_result="12° C", origin=tool_call_berlin, error=False), - # ] - - # # Get final response - # results = component.run(new_messages) - # message = results["replies"][0] - # assert not message.tool_calls - # assert len(message.text) > 0 - # assert "paris" in message.text.lower() - # assert "berlin" in message.text.lower() - # assert "22°" in message.text - # assert "12°" in message.text - # assert message.meta["finish_reason"] == "end_turn" From d29509ffe0cee3f73f17005b1ab57923ee8fe343 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 20 Jan 2025 18:01:34 +0100 Subject: [PATCH 03/10] Add actual pipeline integration test with tools --- .../tests/test_chat_generator.py | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 7d799c7b5..6b853ae4b 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -1,7 +1,9 @@ from typing import Any, Dict, Optional import pytest +from haystack import Pipeline from haystack.components.generators.utils import print_streaming_chunk +from haystack.components.tools import ToolInvoker from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack.tools import Tool @@ -34,12 +36,16 @@ def chat_messages(): @pytest.fixture def tools(): + def weather(city: str): + """Get weather for a given city.""" + return f"The weather in {city} is sunny and 32°C" + tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} tool = Tool( name="weather", description="useful to determine the weather in a given location", parameters=tool_parameters, - function=lambda x: x, + function=weather, ) return [tool] @@ -541,3 +547,25 @@ def test_live_run_with_tools_streaming(self, model_name, tools): assert not final_message.tool_call assert len(final_message.text) > 0 assert "paris" in final_message.text.lower() + + @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS) + @pytest.mark.integration + def test_pipeline_with_amazon_bedrock_chat_generator(self, model_name, tools): + """ + Test that the AmazonBedrockChatGenerator component can be used in a pipeline + """ + + pipeline = Pipeline() + pipeline.add_component("generator", 
AmazonBedrockChatGenerator(model=model_name, tools=tools)) + pipeline.add_component("tool_invoker", ToolInvoker(tools=tools)) + + pipeline.connect("generator", "tool_invoker") + + results = pipeline.run( + data={"generator": {"messages": [ChatMessage.from_user("What's the weather like in Paris?")]}} + ) + + assert ( + "The weather in Paris is sunny and 32°C" + == results["tool_invoker"]["tool_messages"][0].tool_call_result.result + ) From 6fc8be2a19a5f2b506388bd46ded184014bc9469 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 21 Jan 2025 10:00:42 +0100 Subject: [PATCH 04/10] Extract instance functions to free standing --- .../amazon_bedrock/chat/chat_generator.py | 276 +++++++++--------- .../tests/test_chat_generator.py | 21 +- 2 files changed, 153 insertions(+), 144 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 162b9a6ce..aac352159 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -20,11 +20,11 @@ logger = logging.getLogger(__name__) -def _convert_tools_to_bedrock_format(tools: Optional[List[Tool]] = None) -> Optional[Dict[str, Any]]: +def _format_tools_for_bedrock(tools: Optional[List[Tool]] = None) -> Optional[Dict[str, Any]]: """ - Convert Haystack Tool(s) to Amazon Bedrock toolConfig format. + Format Haystack Tool(s) to Amazon Bedrock toolConfig format. - :param tools: List of Tool objects to convert + :param tools: List of Tool objects to format :return: Dictionary in Bedrock toolConfig format or None if no tools """ if not tools: @@ -39,12 +39,12 @@ def _convert_tools_to_bedrock_format(tools: Optional[List[Tool]] = None) -> Opti return {"tools": tool_specs} if tool_specs else None -def _convert_to_bedrock_format(messages: List[ChatMessage]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: +def _format_messages_for_bedrock(messages: List[ChatMessage]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: """ - Convert a list of ChatMessages to the format expected by Bedrock API. + Format a list of ChatMessages to the format expected by Bedrock API. Separates system messages and handles tool results and tool calls. - :param messages: List of ChatMessages to convert + :param messages: List of ChatMessages to format :return: Tuple of (system_prompts, non_system_messages) in Bedrock format """ system_prompts = [] @@ -95,6 +95,135 @@ def _convert_to_bedrock_format(messages: List[ChatMessage]) -> Tuple[List[Dict[s return system_prompts, non_system_messages +def _parse_bedrock_completion_response(response_body: Dict[str, Any], model: str) -> List[ChatMessage]: + """ + Parse a Bedrock response to a list of ChatMessage objects. 
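+    Text content blocks are collected into assistant replies, while `toolUse`
+    blocks are mapped to `ToolCall` objects carrying the tool-use id, name and input.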
+ + :param response_body: Raw response from Bedrock API + :param model: The model ID used for generation + :return: List of ChatMessage objects + """ + replies = [] + if "output" in response_body and "message" in response_body["output"]: + message = response_body["output"]["message"] + if message["role"] == "assistant": + content_blocks = message["content"] + + # Common meta information + base_meta = { + "model": model, + "index": 0, + "finish_reason": response_body.get("stopReason"), + "usage": { + # OpenAI's format for usage for cross ChatGenerator compatibility + "prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0), + "completion_tokens": response_body.get("usage", {}).get("outputTokens", 0), + "total_tokens": response_body.get("usage", {}).get("totalTokens", 0), + }, + } + + # Process each content block separately + for content_block in content_blocks: + if "text" in content_block: + replies.append(ChatMessage.from_assistant(content_block["text"], meta=base_meta.copy())) + elif "toolUse" in content_block: + # Convert tool use to ToolCall + tool_use = content_block["toolUse"] + tool_call = ToolCall( + id=tool_use.get("toolUseId"), + tool_name=tool_use.get("name"), + arguments=tool_use.get("input", {}), + ) + replies.append(ChatMessage.from_assistant("", tool_calls=[tool_call], meta=base_meta.copy())) + + return replies + + +def _parse_bedrock_streaming_chunks( + response_stream: EventStream, + streaming_callback: Callable[[StreamingChunk], None], + model: str, +) -> List[ChatMessage]: + """ + Parse a streaming response from Bedrock. + + :param response_stream: EventStream from Bedrock API + :param streaming_callback: Callback for streaming chunks + :param model: The model ID used for generation + :return: List of ChatMessage objects + """ + replies = [] + current_content = "" + current_tool_call: Optional[Dict[str, Any]] = None + base_meta = { + "model": model, + "index": 0, + } + + for event in response_stream: + if "contentBlockStart" in event: + # Reset accumulators for new message + current_content = "" + current_tool_call = None + block_start = event["contentBlockStart"] + if "start" in block_start and "toolUse" in block_start["start"]: + tool_start = block_start["start"]["toolUse"] + current_tool_call = { + "id": tool_start["toolUseId"], + "name": tool_start["name"], + "arguments": "", # Will accumulate deltas as string + } + + elif "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + delta_text = delta["text"] + current_content += delta_text + streaming_chunk = StreamingChunk(content=delta_text, meta=None) + streaming_callback(streaming_chunk) + elif "toolUse" in delta and current_tool_call: + # Accumulate tool use input deltas + current_tool_call["arguments"] += delta["toolUse"].get("input", "") + + elif "contentBlockStop" in event: + if current_tool_call: + # Parse accumulated input if it's a JSON string + try: + input_json = json.loads(current_tool_call["arguments"]) + current_tool_call["arguments"] = input_json + except json.JSONDecodeError: + # Keep as string if not valid JSON + pass + + tool_call = ToolCall( + id=current_tool_call["id"], + tool_name=current_tool_call["name"], + arguments=current_tool_call["arguments"], + ) + replies.append(ChatMessage.from_assistant("", tool_calls=[tool_call], meta=base_meta.copy())) + elif current_content: + replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy())) + + elif "messageStop" in event: + # Update finish reason for all replies + for reply in 
replies: + reply.meta["finish_reason"] = event["messageStop"].get("stopReason") + + elif "metadata" in event: + metadata = event["metadata"] + # Update usage stats for all replies + for reply in replies: + if "usage" in metadata: + usage = metadata["usage"] + reply.meta["usage"] = { + "prompt_tokens": usage.get("inputTokens", 0), + "completion_tokens": usage.get("outputTokens", 0), + "total_tokens": usage.get("totalTokens", 0), + } + + return replies + + @component class AmazonBedrockChatGenerator: """ @@ -298,14 +427,14 @@ def run( _check_duplicate_tool_names(tools) tool_config = merged_kwargs.pop("toolConfig", None) if tools: - # Convert Haystack tools to Bedrock format - tool_config = _convert_tools_to_bedrock_format(tools) + # Format Haystack tools to Bedrock format + tool_config = _format_tools_for_bedrock(tools) # Any remaining kwargs go to additionalModelRequestFields additional_fields = merged_kwargs if merged_kwargs else None - # Convert messages to Bedrock format - system_prompts, messages_list = _convert_to_bedrock_format(messages) + # Format messages to Bedrock format + system_prompts, messages_list = _format_messages_for_bedrock(messages) # Build API parameters params = { @@ -328,135 +457,12 @@ def run( if not response_stream: msg = "No stream found in the response." raise AmazonBedrockInferenceError(msg) - replies = self.process_streaming_response(response_stream, callback) + replies = _parse_bedrock_streaming_chunks(response_stream, callback, self.model) else: response = self.client.converse(**params) - replies = self.extract_replies_from_response(response) + replies = _parse_bedrock_completion_response(response, self.model) except ClientError as exception: msg = f"Could not generate inference for Amazon Bedrock model {self.model} due: {exception}" raise AmazonBedrockInferenceError(msg) from exception return {"replies": replies} - - def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extract ChatMessage replies from a Bedrock response. - - :param response_body: Raw response from Bedrock API - :return: List of ChatMessage objects - """ - replies = [] - if "output" in response_body and "message" in response_body["output"]: - message = response_body["output"]["message"] - if message["role"] == "assistant": - content_blocks = message["content"] - - # Common meta information - base_meta = { - "model": self.model, - "index": 0, - "finish_reason": response_body.get("stopReason"), - "usage": { - # OpenAI's format for usage for cross ChatGenerator compatibility - "prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0), - "completion_tokens": response_body.get("usage", {}).get("outputTokens", 0), - "total_tokens": response_body.get("usage", {}).get("totalTokens", 0), - }, - } - - # Process each content block separately - for content_block in content_blocks: - if "text" in content_block: - replies.append(ChatMessage.from_assistant(content_block["text"], meta=base_meta.copy())) - elif "toolUse" in content_block: - # Convert tool use to ToolCall - tool_use = content_block["toolUse"] - tool_call = ToolCall( - id=tool_use.get("toolUseId"), - tool_name=tool_use.get("name"), - arguments=tool_use.get("input", {}), - ) - replies.append(ChatMessage.from_assistant("", tool_calls=[tool_call], meta=base_meta.copy())) - - return replies - - def process_streaming_response( - self, response_stream: EventStream, streaming_callback: Callable[[StreamingChunk], None] - ) -> List[ChatMessage]: - """ - Process a streaming response from Bedrock. 
- - :param response_stream: EventStream from Bedrock API - :param streaming_callback: Callback for streaming chunks - :return: List of ChatMessage objects - """ - replies = [] - current_content = "" - current_tool_call: Optional[Dict[str, Any]] = None - base_meta = { - "model": self.model, - "index": 0, - } - - for event in response_stream: - if "contentBlockStart" in event: - # Reset accumulators for new message - current_content = "" - current_tool_call = None - block_start = event["contentBlockStart"] - if "start" in block_start and "toolUse" in block_start["start"]: - tool_start = block_start["start"]["toolUse"] - current_tool_call = { - "id": tool_start["toolUseId"], - "name": tool_start["name"], - "arguments": "", # Will accumulate deltas as string - } - - elif "contentBlockDelta" in event: - delta = event["contentBlockDelta"]["delta"] - if "text" in delta: - delta_text = delta["text"] - current_content += delta_text - streaming_chunk = StreamingChunk(content=delta_text, meta=None) - streaming_callback(streaming_chunk) - elif "toolUse" in delta and current_tool_call: - # Accumulate tool use input deltas - current_tool_call["arguments"] += delta["toolUse"].get("input", "") - - elif "contentBlockStop" in event: - if current_tool_call: - # Parse accumulated input if it's a JSON string - try: - input_json = json.loads(current_tool_call["arguments"]) - current_tool_call["arguments"] = input_json - except json.JSONDecodeError: - # Keep as string if not valid JSON - pass - - tool_call = ToolCall( - id=current_tool_call["id"], - tool_name=current_tool_call["name"], - arguments=current_tool_call["arguments"], - ) - replies.append(ChatMessage.from_assistant("", tool_calls=[tool_call], meta=base_meta.copy())) - elif current_content: - replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy())) - - elif "messageStop" in event: - # Update finish reason for all replies - for reply in replies: - reply.meta["finish_reason"] = event["messageStop"].get("stopReason") - - elif "metadata" in event: - metadata = event["metadata"] - # Update usage stats for all replies - for reply in replies: - if "usage" in metadata: - usage = metadata["usage"] - reply.meta["usage"] = { - "prompt_tokens": usage.get("inputTokens", 0), - "completion_tokens": usage.get("outputTokens", 0), - "total_tokens": usage.get("totalTokens", 0), - } - - return replies diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 6b853ae4b..3fa60e7bf 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -8,6 +8,10 @@ from haystack.tools import Tool from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator +from haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator import ( + _parse_bedrock_completion_response, + _parse_bedrock_streaming_chunks, +) KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" MODELS_TO_TEST = [ @@ -353,8 +357,7 @@ def test_extract_replies_from_response(self, mock_boto3_session): """ Test that extract_replies_from_response correctly processes both text and tool use responses """ - generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0") - + model = "anthropic.claude-3-5-sonnet-20240620-v1:0" # Test case 1: Simple text response text_response = { "output": {"message": {"role": 
"assistant", "content": [{"text": "This is a test response"}]}}, @@ -362,11 +365,11 @@ def test_extract_replies_from_response(self, mock_boto3_session): "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, } - replies = generator.extract_replies_from_response(text_response) + replies = _parse_bedrock_completion_response(text_response, model) assert len(replies) == 1 assert replies[0].text == "This is a test response" assert replies[0].role == ChatRole.ASSISTANT - assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" + assert replies[0].meta["model"] == model assert replies[0].meta["finish_reason"] == "complete" assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} @@ -382,7 +385,7 @@ def test_extract_replies_from_response(self, mock_boto3_session): "usage": {"inputTokens": 15, "outputTokens": 25, "totalTokens": 40}, } - replies = generator.extract_replies_from_response(tool_response) + replies = _parse_bedrock_completion_response(tool_response, model) assert len(replies) == 1 tool_content = replies[0].tool_call assert tool_content.id == "123" @@ -406,7 +409,7 @@ def test_extract_replies_from_response(self, mock_boto3_session): "usage": {"inputTokens": 25, "outputTokens": 35, "totalTokens": 60}, } - replies = generator.extract_replies_from_response(mixed_response) + replies = _parse_bedrock_completion_response(mixed_response, model) assert len(replies) == 2 assert replies[0].text == "Let me help you with that. I'll use the search tool to find the answer." tool_content = replies[1].tool_call @@ -418,7 +421,7 @@ def test_process_streaming_response(self, mock_boto3_session): """ Test that process_streaming_response correctly handles streaming events and accumulates responses """ - generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0") + model = "anthropic.claude-3-5-sonnet-20240620-v1:0" streaming_chunks = [] @@ -439,7 +442,7 @@ def test_callback(chunk: StreamingChunk): {"metadata": {"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}}}, ] - replies = generator.process_streaming_response(events, test_callback) + replies = _parse_bedrock_streaming_chunks(events, test_callback, model) # Verify streaming chunks were received for text content assert len(streaming_chunks) == 2 @@ -450,7 +453,7 @@ def test_callback(chunk: StreamingChunk): assert len(replies) == 2 # Check text reply assert replies[0].text == "Let me help you." 
- assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" + assert replies[0].meta["model"] == model assert replies[0].meta["finish_reason"] == "complete" assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} From fae3723275cac33cc1f34a09c15926f5eb5ff8ab Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 22 Jan 2025 13:53:10 +0100 Subject: [PATCH 05/10] No need to test serde on all models --- integrations/amazon_bedrock/tests/test_chat_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 3fa60e7bf..29effdf8e 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -551,7 +551,7 @@ def test_live_run_with_tools_streaming(self, model_name, tools): assert len(final_message.text) > 0 assert "paris" in final_message.text.lower() - @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS) + @pytest.mark.parametrize("model_name", [MODELS_TO_TEST_WITH_TOOLS[0]]) # just one model is enough @pytest.mark.integration def test_pipeline_with_amazon_bedrock_chat_generator(self, model_name, tools): """ From b9db10c53db4bd97c2253562763a94c51eb9fc59 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 22 Jan 2025 13:55:44 +0100 Subject: [PATCH 06/10] Add serde test --- .../tests/test_chat_generator.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 29effdf8e..54d0894c9 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -572,3 +572,51 @@ def test_pipeline_with_amazon_bedrock_chat_generator(self, model_name, tools): "The weather in Paris is sunny and 32°C" == results["tool_invoker"]["tool_messages"][0].tool_call_result.result ) + + @pytest.mark.parametrize("model_name", [MODELS_TO_TEST_WITH_TOOLS[0]]) # just one model is enough + @pytest.mark.integration + def test_pipeline_with_amazon_bedrock_chat_generator_serde(self, model_name, tools): + """ + Test that the AmazonBedrockChatGenerator component can be serialized and deserialized in a pipeline + """ + # Create original pipeline + pipeline = Pipeline() + pipeline.add_component("generator", AmazonBedrockChatGenerator(model=model_name, tools=tools)) + pipeline.add_component("tool_invoker", ToolInvoker(tools=tools)) + pipeline.connect("generator", "tool_invoker") + + # Serialize and deserialize + pipeline_dict = pipeline.to_dict() + + # Verify tools in serialized dict + generator_tools = pipeline_dict["components"]["generator"]["init_parameters"]["tools"] + tool_invoker_tools = pipeline_dict["components"]["tool_invoker"]["init_parameters"]["tools"] + + # Both components should have the same tool configuration + assert generator_tools == tool_invoker_tools + assert len(generator_tools) == 1 + + # Verify tool details + tool_dict = generator_tools[0] + assert tool_dict["type"] == "haystack.tools.tool.Tool" + assert tool_dict["data"]["name"] == "weather" + assert tool_dict["data"]["description"] == "useful to determine the weather in a given location" + assert tool_dict["data"]["parameters"] == { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } + assert 
tool_dict["data"]["function"] == "tests.test_chat_generator.weather" + + # Load pipeline and verify it works + loaded_pipeline = Pipeline.from_dict(pipeline_dict) + + # Run the deserialized pipeline + results = loaded_pipeline.run( + data={"generator": {"messages": [ChatMessage.from_user("What's the weather like in Paris?")]}} + ) + + assert ( + "The weather in Paris is sunny and 32°C" + == results["tool_invoker"]["tool_messages"][0].tool_call_result.result + ) From fbee1c5a0995db6e2c2c74c24b5e574e69dbd077 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 22 Jan 2025 15:05:38 +0100 Subject: [PATCH 07/10] Fix serde test --- .../tests/test_chat_generator.py | 9 +++--- .../generators/cohere/chat/chat_generator.py | 28 +++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 54d0894c9..1c8dcbf08 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -29,6 +29,11 @@ STREAMING_TOOL_MODELS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0"] +def weather(city: str): + """Get weather for a given city.""" + return f"The weather in {city} is sunny and 32°C" + + @pytest.fixture def chat_messages(): messages = [ @@ -40,10 +45,6 @@ def chat_messages(): @pytest.fixture def tools(): - def weather(city: str): - """Get weather for a given city.""" - return f"The weather in {city} is sunny and 32°C" - tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} tool = Tool( name="weather", diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 33e7c98f6..ea5408ea8 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -6,6 +6,7 @@ from haystack.lazy_imports import LazyImport from haystack.utils import Secret, deserialize_secrets_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable +from haystack.tools import Tool with LazyImport(message="Run 'pip install cohere'") as cohere_import: import cohere @@ -42,6 +43,7 @@ def __init__( streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, api_base_url: Optional[str] = None, generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, **kwargs, ): """ @@ -77,6 +79,8 @@ def __init__( `accurate` results or `fast` results. - 'temperature': A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. + :param tools: List of Tool instances that the model can use. Each tool should have a name, + description and parameters schema. 
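+            Tools are converted to Cohere's function format (a `type: function` wrapper
+            around the tool's name, description and parameters) when the request is made.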
""" cohere_import.check() @@ -89,6 +93,7 @@ def __init__( self.streaming_callback = streaming_callback self.api_base_url = api_base_url self.generation_kwargs = generation_kwargs + self.tools = tools or [] self.model_parameters = kwargs self.client = cohere.Client( api_key=self.api_key.resolve_value(), base_url=self.api_base_url, client_name="haystack" @@ -115,6 +120,7 @@ def to_dict(self) -> Dict[str, Any]: api_base_url=self.api_base_url, api_key=self.api_key.to_dict(), generation_kwargs=self.generation_kwargs, + tools=self.tools, ) @classmethod @@ -154,6 +160,11 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, """ # update generation kwargs by merging with the generation kwargs passed to the run method generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + + # Add tools to generation kwargs if we have any + if self.tools: + generation_kwargs["tools"] = self._convert_tools_to_cohere_format() + chat_history = [self._message_to_dict(m) for m in messages[:-1]] if self.streaming_callback: response = self.client.chat_stream( @@ -234,3 +245,20 @@ def _build_message(self, cohere_response): } ) return message + + def _convert_tools_to_cohere_format(self) -> List[Dict[str, Any]]: + """ + Converts Haystack Tool instances to Cohere's tool format + """ + cohere_tools = [] + for tool in self.tools: + cohere_tool = { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters + } + } + cohere_tools.append(cohere_tool) + return cohere_tools From b2acc10df68f35fb00d4206605415fb5d0dd6831 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 22 Jan 2025 16:10:41 +0100 Subject: [PATCH 08/10] Lint --- .../generators/cohere/chat/chat_generator.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index ea5408ea8..915b44021 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -4,9 +4,9 @@ from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack.lazy_imports import LazyImport +from haystack.tools import Tool from haystack.utils import Secret, deserialize_secrets_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable -from haystack.tools import Tool with LazyImport(message="Run 'pip install cohere'") as cohere_import: import cohere @@ -160,11 +160,11 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, """ # update generation kwargs by merging with the generation kwargs passed to the run method generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - + # Add tools to generation kwargs if we have any if self.tools: generation_kwargs["tools"] = self._convert_tools_to_cohere_format() - + chat_history = [self._message_to_dict(m) for m in messages[:-1]] if self.streaming_callback: response = self.client.chat_stream( @@ -254,11 +254,7 @@ def _convert_tools_to_cohere_format(self) -> List[Dict[str, Any]]: for tool in self.tools: cohere_tool = { "type": "function", - "function": { - "name": tool.name, - "description": 
tool.description, - "parameters": tool.parameters - } + "function": {"name": tool.name, "description": tool.description, "parameters": tool.parameters}, } cohere_tools.append(cohere_tool) return cohere_tools From 2319409597782fb4417971d39030fd349e79fa35 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 22 Jan 2025 16:23:49 +0100 Subject: [PATCH 09/10] Always pack thinking + tool call into single ChatMessage --- .../generators/amazon_bedrock/chat/chat_generator.py | 11 ++++++++--- .../amazon_bedrock/tests/test_chat_generator.py | 4 ++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index aac352159..2fcfdb3b5 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -122,10 +122,12 @@ def _parse_bedrock_completion_response(response_body: Dict[str, Any], model: str }, } - # Process each content block separately + # Process all content blocks and combine them into a single message + text_content = [] + tool_calls = [] for content_block in content_blocks: if "text" in content_block: - replies.append(ChatMessage.from_assistant(content_block["text"], meta=base_meta.copy())) + text_content.append(content_block["text"]) elif "toolUse" in content_block: # Convert tool use to ToolCall tool_use = content_block["toolUse"] @@ -134,7 +136,10 @@ def _parse_bedrock_completion_response(response_body: Dict[str, Any], model: str tool_name=tool_use.get("name"), arguments=tool_use.get("input", {}), ) - replies.append(ChatMessage.from_assistant("", tool_calls=[tool_call], meta=base_meta.copy())) + tool_calls.append(tool_call) + + # Create a single ChatMessage with combined text and tool calls + replies.append(ChatMessage.from_assistant(" ".join(text_content), tool_calls=tool_calls, meta=base_meta)) return replies diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 1c8dcbf08..6b19a9e21 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -411,9 +411,9 @@ def test_extract_replies_from_response(self, mock_boto3_session): } replies = _parse_bedrock_completion_response(mixed_response, model) - assert len(replies) == 2 + assert len(replies) == 1 assert replies[0].text == "Let me help you with that. I'll use the search tool to find the answer." 
- tool_content = replies[1].tool_call + tool_content = replies[0].tool_call assert tool_content.id == "456" assert tool_content.tool_name == "search_tool" assert tool_content.arguments == {"query": "test"} From 39436da3c75fd65601cbd39edc0cf024bcef8614 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 22 Jan 2025 16:28:08 +0100 Subject: [PATCH 10/10] Revert accidental changes --- .../generators/cohere/chat/chat_generator.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 915b44021..33e7c98f6 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -4,7 +4,6 @@ from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack.lazy_imports import LazyImport -from haystack.tools import Tool from haystack.utils import Secret, deserialize_secrets_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable @@ -43,7 +42,6 @@ def __init__( streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, api_base_url: Optional[str] = None, generation_kwargs: Optional[Dict[str, Any]] = None, - tools: Optional[List[Tool]] = None, **kwargs, ): """ @@ -79,8 +77,6 @@ def __init__( `accurate` results or `fast` results. - 'temperature': A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. - :param tools: List of Tool instances that the model can use. Each tool should have a name, - description and parameters schema. """ cohere_import.check() @@ -93,7 +89,6 @@ def __init__( self.streaming_callback = streaming_callback self.api_base_url = api_base_url self.generation_kwargs = generation_kwargs - self.tools = tools or [] self.model_parameters = kwargs self.client = cohere.Client( api_key=self.api_key.resolve_value(), base_url=self.api_base_url, client_name="haystack" @@ -120,7 +115,6 @@ def to_dict(self) -> Dict[str, Any]: api_base_url=self.api_base_url, api_key=self.api_key.to_dict(), generation_kwargs=self.generation_kwargs, - tools=self.tools, ) @classmethod @@ -160,11 +154,6 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, """ # update generation kwargs by merging with the generation kwargs passed to the run method generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - - # Add tools to generation kwargs if we have any - if self.tools: - generation_kwargs["tools"] = self._convert_tools_to_cohere_format() - chat_history = [self._message_to_dict(m) for m in messages[:-1]] if self.streaming_callback: response = self.client.chat_stream( @@ -245,16 +234,3 @@ def _build_message(self, cohere_response): } ) return message - - def _convert_tools_to_cohere_format(self) -> List[Dict[str, Any]]: - """ - Converts Haystack Tool instances to Cohere's tool format - """ - cohere_tools = [] - for tool in self.tools: - cohere_tool = { - "type": "function", - "function": {"name": tool.name, "description": tool.description, "parameters": tool.parameters}, - } - cohere_tools.append(cohere_tool) - return cohere_tools
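
For reference, a minimal usage sketch of the tools support this series adds to
AmazonBedrockChatGenerator. This is not part of the patch: the weather tool and
its parameters schema mirror the test fixture above, the model id is one of the
STREAMING_TOOL_MODELS listed in the tests, and AWS credentials are assumed to be
resolvable from the environment.

from haystack.dataclasses import ChatMessage
from haystack.tools import Tool

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator


def weather(city: str):
    """Get weather for a given city."""
    return f"The weather in {city} is sunny and 32°C"


# Parameters schema copied from the test fixture in this series.
tool = Tool(
    name="weather",
    description="Get weather for a given city.",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=weather,
)

# Assumes AWS credentials are configured in the environment.
generator = AmazonBedrockChatGenerator(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    tools=[tool],
)

results = generator.run(messages=[ChatMessage.from_user("What's the weather like in Paris?")])

# Per PATCH 09/10, text and tool calls arrive packed into a single ChatMessage.
reply = results["replies"][0]
if reply.tool_call:
    print(reply.tool_call.tool_name, reply.tool_call.arguments)
else:
    print(reply.text)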
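
And a sketch of the Pipeline round trip that the serde test above asserts
against. The component names ("generator", "tool_invoker") and the run/assert
shape come from that test; the connect() wiring itself is an assumption
reconstructed from them, using the stock haystack ToolInvoker.

from haystack import Pipeline
from haystack.components.tools import ToolInvoker
from haystack.dataclasses import ChatMessage

pipeline = Pipeline()
pipeline.add_component("generator", generator)  # generator and tool from the sketch above
pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))
# Assumed wiring: feed the assistant replies into the tool invoker.
pipeline.connect("generator.replies", "tool_invoker.messages")

results = pipeline.run(
    data={"generator": {"messages": [ChatMessage.from_user("What's the weather like in Paris?")]}}
)
# Matches the assertion in the serde test: the invoked weather tool's result.
print(results["tool_invoker"]["tool_messages"][0].tool_call_result.result)
# -> The weather in Paris is sunny and 32°C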