diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml
index 7044eb453..07f8db679 100644
--- a/integrations/amazon_bedrock/pyproject.toml
+++ b/integrations/amazon_bedrock/pyproject.toml
@@ -48,6 +48,7 @@ dependencies = [
   "pytest",
   "pytest-rerunfailures",
   "haystack-pydoc-tools",
+  "jsonschema",
 ]
 
 [tool.hatch.envs.default.scripts]
 test = "pytest {args:tests}"
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
index bcf11414c..2fcfdb3b5 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
@@ -1,12 +1,13 @@
 import json
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 from botocore.config import Config
 from botocore.eventstream import EventStream
 from botocore.exceptions import ClientError
 from haystack import component, default_from_dict, default_to_dict
-from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
+from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall
+from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
 from haystack.utils.auth import Secret, deserialize_secrets_inplace
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 
@@ -19,6 +20,215 @@
 logger = logging.getLogger(__name__)
 
 
+def _format_tools_for_bedrock(tools: Optional[List[Tool]] = None) -> Optional[Dict[str, Any]]:
+    """
+    Format Haystack Tool(s) to Amazon Bedrock toolConfig format.
+
+    :param tools: List of Tool objects to format
+    :return: Dictionary in Bedrock toolConfig format or None if no tools
+    """
+    if not tools:
+        return None
+
+    tool_specs = []
+    for tool in tools:
+        tool_specs.append(
+            {"toolSpec": {"name": tool.name, "description": tool.description, "inputSchema": {"json": tool.parameters}}}
+        )
+
+    return {"tools": tool_specs} if tool_specs else None
+
+
+def _format_messages_for_bedrock(messages: List[ChatMessage]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+    """
+    Format a list of ChatMessages to the format expected by the Bedrock API.
+    Separates system messages and handles tool results and tool calls.
+
+    :param messages: List of ChatMessages to format
+    :return: Tuple of (system_prompts, non_system_messages) in Bedrock format
+    """
+    system_prompts = []
+    non_system_messages = []
+
+    for msg in messages:
+        if msg.is_from(ChatRole.SYSTEM):
+            system_prompts.append({"text": msg.text})
+            continue
+
+        # Handle tool results - Bedrock expects these to be sent as user messages
+        if msg.tool_call_results:
+            tool_results = []
+            for result in msg.tool_call_results:
+                try:
+                    json_result = json.loads(result.result)
+                    content = [{"json": json_result}]
+                except json.JSONDecodeError:
+                    content = [{"text": result.result}]
+
+                tool_results.append(
+                    {
+                        "toolResult": {
+                            "toolUseId": result.origin.id,
+                            "content": content,
+                            **({"status": "error"} if result.error else {}),
+                        }
+                    }
+                )
+            non_system_messages.append({"role": "user", "content": tool_results})
+            continue
+
+        content = []
+        # Handle text content
+        if msg.text:
+            content.append({"text": msg.text})
+
+        # Handle tool calls
+        if msg.tool_calls:
+            for tool_call in msg.tool_calls:
+                content.append(
+                    {"toolUse": {"toolUseId": tool_call.id, "name": tool_call.tool_name, "input": tool_call.arguments}}
+                )
+
+        if content:  # Only add message if it has content
+            non_system_messages.append({"role": msg.role.value, "content": content})
+
+    return system_prompts, non_system_messages
+
+
+def _parse_bedrock_completion_response(response_body: Dict[str, Any], model: str) -> List[ChatMessage]:
+    """
+    Parse a Bedrock response to a list of ChatMessage objects.
+
+    :param response_body: Raw response from Bedrock API
+    :param model: The model ID used for generation
+    :return: List of ChatMessage objects
+    """
+    replies = []
+    if "output" in response_body and "message" in response_body["output"]:
+        message = response_body["output"]["message"]
+        if message["role"] == "assistant":
+            content_blocks = message["content"]
+
+            # Common meta information
+            base_meta = {
+                "model": model,
+                "index": 0,
+                "finish_reason": response_body.get("stopReason"),
+                "usage": {
+                    # OpenAI-style usage format, for compatibility across ChatGenerators
+                    "prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0),
+                    "completion_tokens": response_body.get("usage", {}).get("outputTokens", 0),
+                    "total_tokens": response_body.get("usage", {}).get("totalTokens", 0),
+                },
+            }
+
+            # Process all content blocks and combine them into a single message
+            text_content = []
+            tool_calls = []
+            for content_block in content_blocks:
+                if "text" in content_block:
+                    text_content.append(content_block["text"])
+                elif "toolUse" in content_block:
+                    # Convert tool use to ToolCall
+                    tool_use = content_block["toolUse"]
+                    tool_call = ToolCall(
+                        id=tool_use.get("toolUseId"),
+                        tool_name=tool_use.get("name"),
+                        arguments=tool_use.get("input", {}),
+                    )
+                    tool_calls.append(tool_call)
+
+            # Create a single ChatMessage with combined text and tool calls
+            replies.append(ChatMessage.from_assistant(" ".join(text_content), tool_calls=tool_calls, meta=base_meta))
+
+    return replies
+
+
+def _parse_bedrock_streaming_chunks(
+    response_stream: EventStream,
+    streaming_callback: Callable[[StreamingChunk], None],
+    model: str,
+) -> List[ChatMessage]:
+    """
+    Parse a streaming response from Bedrock.
+
+    :param response_stream: EventStream from Bedrock API
+    :param streaming_callback: Callback for streaming chunks
+    :param model: The model ID used for generation
+    :return: List of ChatMessage objects
+    """
+    replies = []
+    current_content = ""
+    current_tool_call: Optional[Dict[str, Any]] = None
+    base_meta = {
+        "model": model,
+        "index": 0,
+    }
+
+    for event in response_stream:
+        if "contentBlockStart" in event:
+            # Reset accumulators for new message
+            current_content = ""
+            current_tool_call = None
+            block_start = event["contentBlockStart"]
+            if "start" in block_start and "toolUse" in block_start["start"]:
+                tool_start = block_start["start"]["toolUse"]
+                current_tool_call = {
+                    "id": tool_start["toolUseId"],
+                    "name": tool_start["name"],
+                    "arguments": "",  # Will accumulate deltas as string
+                }
+
+        elif "contentBlockDelta" in event:
+            delta = event["contentBlockDelta"]["delta"]
+            if "text" in delta:
+                delta_text = delta["text"]
+                current_content += delta_text
+                streaming_chunk = StreamingChunk(content=delta_text, meta=None)
+                streaming_callback(streaming_chunk)
+            elif "toolUse" in delta and current_tool_call:
+                # Accumulate tool use input deltas
+                current_tool_call["arguments"] += delta["toolUse"].get("input", "")
+
+        elif "contentBlockStop" in event:
+            if current_tool_call:
+                # Parse accumulated input if it's a JSON string
+                try:
+                    input_json = json.loads(current_tool_call["arguments"])
+                    current_tool_call["arguments"] = input_json
+                except json.JSONDecodeError:
+                    # Keep as string if not valid JSON
+                    pass
+
+                tool_call = ToolCall(
+                    id=current_tool_call["id"],
+                    tool_name=current_tool_call["name"],
+                    arguments=current_tool_call["arguments"],
+                )
+                replies.append(ChatMessage.from_assistant("", tool_calls=[tool_call], meta=base_meta.copy()))
+            elif current_content:
+                replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy()))
+
+        elif "messageStop" in event:
+            # Update finish reason for all replies
+            for reply in replies:
+                reply.meta["finish_reason"] = event["messageStop"].get("stopReason")
+
+        elif "metadata" in event:
+            metadata = event["metadata"]
+            # Update usage stats for all replies
+            for reply in replies:
+                if "usage" in metadata:
+                    usage = metadata["usage"]
+                    reply.meta["usage"] = {
+                        "prompt_tokens": usage.get("inputTokens", 0),
+                        "completion_tokens": usage.get("outputTokens", 0),
+                        "total_tokens": usage.get("totalTokens", 0),
+                    }
+
+    return replies
+
+
 @component
 class AmazonBedrockChatGenerator:
     """
@@ -70,6 +280,7 @@ def __init__(
         stop_words: Optional[List[str]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         boto3_config: Optional[Dict[str, Any]] = None,
+        tools: Optional[List[Tool]] = None,
     ):
         """
         Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the
@@ -103,6 +314,7 @@ def __init__(
             [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and switches
             the streaming mode on.
         :param boto3_config: The configuration for the boto3 client.
+        :param tools: A list of Tool objects that the model can use. Each tool should have a unique name.
         :raises ValueError: If the model name is empty or None.
         :raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is
@@ -120,6 +332,8 @@
         self.stop_words = stop_words or []
         self.streaming_callback = streaming_callback
         self.boto3_config = boto3_config
+        _check_duplicate_tool_names(tools)
+        self.tools = tools
 
         def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
             return secret.resolve_value() if secret else None
@@ -155,6 +369,7 @@ def to_dict(self) -> Dict[str, Any]:
             Dictionary with serialized data.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
+        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
         return default_to_dict(
             self,
             aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
@@ -167,6 +382,7 @@ def to_dict(self) -> Dict[str, Any]:
             generation_kwargs=self.generation_kwargs,
             streaming_callback=callback_name,
             boto3_config=self.boto3_config,
+            tools=serialized_tools,
         )
 
     @classmethod
@@ -186,6 +402,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator":
             data["init_parameters"],
             ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
         )
+        deserialize_tools_inplace(data["init_parameters"], key="tools")
         return default_from_dict(cls, data)
 
     @component.output_types(replies=List[ChatMessage])
@@ -194,6 +411,7 @@ def run(
         self,
         messages: List[ChatMessage],
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
+        tools: Optional[List[Tool]] = None,
     ):
         generation_kwargs = generation_kwargs or {}
 
@@ -209,20 +427,19 @@ def run(
             if key in merged_kwargs
         }
 
-        # Extract tool configuration if present
-        # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html
+        # Handle tools: Haystack Tool objects take precedence over a raw toolConfig in generation kwargs
+        tools = tools or self.tools
+        _check_duplicate_tool_names(tools)
         tool_config = merged_kwargs.pop("toolConfig", None)
+        if tools:
+            # Format Haystack tools to Bedrock format
+            tool_config = _format_tools_for_bedrock(tools)
 
         # Any remaining kwargs go to additionalModelRequestFields
         additional_fields = merged_kwargs if merged_kwargs else None
 
-        # Prepare system prompts and messages
-        system_prompts = []
-        if messages and messages[0].is_from(ChatRole.SYSTEM):
-            system_prompts = [{"text": messages[0].text}]
-            messages = messages[1:]
-
-        messages_list = [{"role": msg.role.value, "content": [{"text": msg.text}]} for msg in messages]
+        # Format messages to Bedrock format
+        system_prompts, messages_list = _format_messages_for_bedrock(messages)
 
         # Build API parameters
         params = {
@@ -245,112 +462,12 @@ def run(
                 if not response_stream:
                     msg = "No stream found in the response."
                    raise AmazonBedrockInferenceError(msg)
-                replies = self.process_streaming_response(response_stream, callback)
+                replies = _parse_bedrock_streaming_chunks(response_stream, callback, self.model)
             else:
                 response = self.client.converse(**params)
-                replies = self.extract_replies_from_response(response)
+                replies = _parse_bedrock_completion_response(response, self.model)
         except ClientError as exception:
             msg = f"Could not generate inference for Amazon Bedrock model {self.model} due: {exception}"
             raise AmazonBedrockInferenceError(msg) from exception
 
         return {"replies": replies}
-
-    def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
-        replies = []
-        if "output" in response_body and "message" in response_body["output"]:
-            message = response_body["output"]["message"]
-            if message["role"] == "assistant":
-                content_blocks = message["content"]
-
-                # Common meta information
-                base_meta = {
-                    "model": self.model,
-                    "index": 0,
-                    "finish_reason": response_body.get("stopReason"),
-                    "usage": {
-                        # OpenAI's format for usage for cross ChatGenerator compatibility
-                        "prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0),
-                        "completion_tokens": response_body.get("usage", {}).get("outputTokens", 0),
-                        "total_tokens": response_body.get("usage", {}).get("totalTokens", 0),
-                    },
-                }
-
-                # Process each content block separately
-                for content_block in content_blocks:
-                    if "text" in content_block:
-                        replies.append(ChatMessage.from_assistant(content_block["text"], meta=base_meta.copy()))
-                    elif "toolUse" in content_block:
-                        replies.append(
-                            ChatMessage.from_assistant(json.dumps(content_block["toolUse"]), meta=base_meta.copy())
-                        )
-        return replies
-
-    def process_streaming_response(
-        self, response_stream: EventStream, streaming_callback: Callable[[StreamingChunk], None]
-    ) -> List[ChatMessage]:
-        replies = []
-        current_content = ""
-        current_tool_use = None
-        base_meta = {
-            "model": self.model,
-            "index": 0,
-        }
-
-        for event in response_stream:
-            if "contentBlockStart" in event:
-                # Reset accumulators for new message
-                current_content = ""
-                current_tool_use = None
-                block_start = event["contentBlockStart"]
-                if "start" in block_start and "toolUse" in block_start["start"]:
-                    tool_start = block_start["start"]["toolUse"]
-                    current_tool_use = {
-                        "toolUseId": tool_start["toolUseId"],
-                        "name": tool_start["name"],
-                        "input": "",  # Will accumulate deltas as string
-                    }
-
-            elif "contentBlockDelta" in event:
-                delta = event["contentBlockDelta"]["delta"]
-                if "text" in delta:
-                    delta_text = delta["text"]
-                    current_content += delta_text
-                    streaming_chunk = StreamingChunk(content=delta_text, meta=None)
-                    # it only makes sense to call callback on text deltas
-                    streaming_callback(streaming_chunk)
-                elif "toolUse" in delta and current_tool_use:
-                    # Accumulate tool use input deltas
-                    current_tool_use["input"] += delta["toolUse"].get("input", "")
-            elif "contentBlockStop" in event:
-                if current_tool_use:
-                    # Parse accumulated input if it's a JSON string
-                    try:
-                        input_json = json.loads(current_tool_use["input"])
-                        current_tool_use["input"] = input_json
-                    except json.JSONDecodeError:
-                        # Keep as string if not valid JSON
-                        pass
-
-                    tool_content = json.dumps(current_tool_use)
-                    replies.append(ChatMessage.from_assistant(tool_content, meta=base_meta.copy()))
-                elif current_content:
-                    replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy()))
-
-            elif "messageStop" in event:
-                # not 100% correct for multiple messages but no way around it
-                for reply in replies:
-                    reply.meta["finish_reason"] = event["messageStop"].get("stopReason")
-
-            elif "metadata" in event:
-                metadata = event["metadata"]
-                # not 100% correct for multiple messages but no way around it
-                for reply in replies:
-                    if "usage" in metadata:
-                        usage = metadata["usage"]
-                        reply.meta["usage"] = {
-                            "prompt_tokens": usage.get("inputTokens", 0),
-                            "completion_tokens": usage.get("outputTokens", 0),
-                            "total_tokens": usage.get("totalTokens", 0),
-                        }
-
-        return replies
diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py
index c2122163c..6b19a9e21 100644
--- a/integrations/amazon_bedrock/tests/test_chat_generator.py
+++ b/integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -1,11 +1,17 @@
-import json
 from typing import Any, Dict, Optional
 
 import pytest
+from haystack import Pipeline
 from haystack.components.generators.utils import print_streaming_chunk
+from haystack.components.tools import ToolInvoker
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
+from haystack.tools import Tool
 
 from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
+from haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator import (
+    _parse_bedrock_completion_response,
+    _parse_bedrock_streaming_chunks,
+)
 
 KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator"
 MODELS_TO_TEST = [
@@ -23,6 +29,11 @@
 STREAMING_TOOL_MODELS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0"]
 
 
+def weather(city: str):
+    """Get weather for a given city."""
+    return f"The weather in {city} is sunny and 32°C"
+
+
 @pytest.fixture
 def chat_messages():
     messages = [
@@ -32,6 +43,18 @@ def chat_messages():
     return messages
 
 
+@pytest.fixture
+def tools():
+    tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
+    tool = Tool(
+        name="weather",
+        description="useful to determine the weather in a given location",
+        parameters=tool_parameters,
+        function=weather,
+    )
+    return [tool]
+
+
 @pytest.mark.parametrize(
     "boto3_config",
     [
@@ -64,6 +87,7 @@ def test_to_dict(mock_boto3_session, boto3_config):
            "stop_words": [],
            "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
            "boto3_config": boto3_config,
+           "tools": None,
        },
    }
 
@@ -96,6 +120,7 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any
                "generation_kwargs": {"temperature": 0.7},
                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                "boto3_config": boto3_config,
+               "tools": None,
            },
        }
    )
@@ -206,7 +231,9 @@ def streaming_callback(chunk: StreamingChunk):
     @pytest.mark.integration
     def test_tools_use(self, model_name):
         """
-        Test function calling with AWS Bedrock Anthropic adapter
+        Test tool use by passing generation_kwargs={"toolConfig": tool_config}
+        instead of the tools parameter. We support this because some users may prefer
+        to pass the tool configuration to the model via toolConfig.
""" # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html tool_config = { @@ -244,43 +271,33 @@ def test_tools_use(self, model_name): assert isinstance(replies, list), "Replies is not a list" assert len(replies) > 0, "No replies received" - first_reply = replies[0] - assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.text, "First reply has no content" - assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert first_reply.meta, "First reply has no metadata" - - # Some models return thinking message as first and the second one as the tool call - if len(replies) > 1: - second_reply = replies[1] - assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" - assert second_reply.text, "Second reply has no content" - assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" - tool_call = json.loads(second_reply.text) - assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" - assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" - assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert ( - tool_call["input"]["sign"] == "WZPZ" - ), f"Tool call {tool_call} does not contain the correct 'input' value" - else: - # case where the model returns the tool call as the first message - # double check that the tool call is correct - tool_call = json.loads(first_reply.text) - assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" - assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" - assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert ( - tool_call["input"]["sign"] == "WZPZ" - ), f"Tool call {tool_call} does not contain the correct 'input' value" + # Find the message with tool calls as in some models it is the first message, in some second + tool_message = None + for message in replies: + if message.tool_call: # Using tool_call instead of tool_calls to match existing code + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_call = tool_message.tool_call + assert tool_call.id, "Tool call does not contain value for 'id' key" + assert tool_call.tool_name == "top_song", f"{tool_call} does not contain the correct 'tool_name' value" + assert tool_call.arguments, f"Tool call {tool_call} does not contain 'arguments' value" + assert ( + tool_call.arguments["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'arguments' value" @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS) @pytest.mark.integration def test_tools_use_with_streaming(self, model_name): """ - Test function calling with AWS Bedrock Anthropic adapter + Test tools use with streaming but with passing the generation_kwargs={"toolConfig": tool_config} + and not the tools parameter. We support this because some users might want to use the toolConfig + parameter to pass the tool configuration to the model. 
""" - # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html tool_config = { "tools": [ { @@ -304,12 +321,10 @@ def test_tools_use_with_streaming(self, model_name): } } ], - # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html "toolChoice": {"auto": {}}, } - messages = [] - messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?")) + messages = [ChatMessage.from_user("What is the most popular song on WZPZ?")] client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=print_streaming_chunk) response = client.run(messages=messages, generation_kwargs={"toolConfig": tool_config}) replies = response["replies"] @@ -322,36 +337,28 @@ def test_tools_use_with_streaming(self, model_name): assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" assert first_reply.meta, "First reply has no metadata" - # Some models return thinking message as first and the second one as the tool call - if len(replies) > 1: - second_reply = replies[1] - assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" - assert second_reply.text, "Second reply has no content" - assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" - tool_call = json.loads(second_reply.text) - assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" - assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" - assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert ( - tool_call["input"]["sign"] == "WZPZ" - ), f"Tool call {tool_call} does not contain the correct 'input' value" - else: - # case where the model returns the tool call as the first message - # double check that the tool call is correct - tool_call = json.loads(first_reply.text) - assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" - assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" - assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert ( - tool_call["input"]["sign"] == "WZPZ" - ), f"Tool call {tool_call} does not contain the correct 'input' value" + # Find the message containing the tool call + tool_message = None + for message in replies: + if message.tool_call: + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_call = tool_message.tool_call + assert tool_call.id, "Tool call does not contain value for 'id' key" + assert tool_call.tool_name == "top_song", f"{tool_call} does not contain the correct 'tool_name' value" + assert tool_call.arguments, f"{tool_call} does not contain 'arguments' value" + assert tool_call.arguments["sign"] == "WZPZ", f"{tool_call} does not contain the correct 'input' value" def test_extract_replies_from_response(self, mock_boto3_session): """ Test that extract_replies_from_response correctly processes both text and tool use responses """ - generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0") - + model = "anthropic.claude-3-5-sonnet-20240620-v1:0" # Test case 1: Simple text response text_response = { 
"output": {"message": {"role": "assistant", "content": [{"text": "This is a test response"}]}}, @@ -359,11 +366,11 @@ def test_extract_replies_from_response(self, mock_boto3_session): "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, } - replies = generator.extract_replies_from_response(text_response) + replies = _parse_bedrock_completion_response(text_response, model) assert len(replies) == 1 assert replies[0].text == "This is a test response" assert replies[0].role == ChatRole.ASSISTANT - assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" + assert replies[0].meta["model"] == model assert replies[0].meta["finish_reason"] == "complete" assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} @@ -379,12 +386,12 @@ def test_extract_replies_from_response(self, mock_boto3_session): "usage": {"inputTokens": 15, "outputTokens": 25, "totalTokens": 40}, } - replies = generator.extract_replies_from_response(tool_response) + replies = _parse_bedrock_completion_response(tool_response, model) assert len(replies) == 1 - tool_content = json.loads(replies[0].text) - assert tool_content["toolUseId"] == "123" - assert tool_content["name"] == "test_tool" - assert tool_content["input"] == {"key": "value"} + tool_content = replies[0].tool_call + assert tool_content.id == "123" + assert tool_content.tool_name == "test_tool" + assert tool_content.arguments == {"key": "value"} assert replies[0].meta["finish_reason"] == "tool_call" assert replies[0].meta["usage"] == {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40} @@ -403,19 +410,19 @@ def test_extract_replies_from_response(self, mock_boto3_session): "usage": {"inputTokens": 25, "outputTokens": 35, "totalTokens": 60}, } - replies = generator.extract_replies_from_response(mixed_response) - assert len(replies) == 2 + replies = _parse_bedrock_completion_response(mixed_response, model) + assert len(replies) == 1 assert replies[0].text == "Let me help you with that. I'll use the search tool to find the answer." - tool_content = json.loads(replies[1].text) - assert tool_content["toolUseId"] == "456" - assert tool_content["name"] == "search_tool" - assert tool_content["input"] == {"query": "test"} + tool_content = replies[0].tool_call + assert tool_content.id == "456" + assert tool_content.tool_name == "search_tool" + assert tool_content.arguments == {"query": "test"} def test_process_streaming_response(self, mock_boto3_session): """ Test that process_streaming_response correctly handles streaming events and accumulates responses """ - generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0") + model = "anthropic.claude-3-5-sonnet-20240620-v1:0" streaming_chunks = [] @@ -436,7 +443,7 @@ def test_callback(chunk: StreamingChunk): {"metadata": {"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}}}, ] - replies = generator.process_streaming_response(events, test_callback) + replies = _parse_bedrock_streaming_chunks(events, test_callback, model) # Verify streaming chunks were received for text content assert len(streaming_chunks) == 2 @@ -447,12 +454,170 @@ def test_callback(chunk: StreamingChunk): assert len(replies) == 2 # Check text reply assert replies[0].text == "Let me help you." 
- assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" + assert replies[0].meta["model"] == model assert replies[0].meta["finish_reason"] == "complete" assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} # Check tool use reply - tool_content = json.loads(replies[1].text) - assert tool_content["toolUseId"] == "123" - assert tool_content["name"] == "search_tool" - assert tool_content["input"] == {"query": "test"} + tool_content = replies[1].tool_call + assert tool_content.id == "123" + assert tool_content.tool_name == "search_tool" + assert tool_content.arguments == {"query": "test"} + + @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS) + @pytest.mark.integration + def test_live_run_with_tools(self, model_name, tools): + """ + Integration test that the AmazonBedrockChatGenerator component can run with tools. Here we are using the + Haystack tools parameter to pass the tool configuration to the model. + """ + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = AmazonBedrockChatGenerator(model=model_name, tools=tools) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_message = None + for message in results["replies"]: + if message.tool_call: + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_call = tool_message.tool_call + assert tool_call.id, "Tool call does not contain value for 'id' key" + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert tool_message.meta["finish_reason"] == "tool_use" + + new_messages = [ + initial_messages[0], + tool_message, + ChatMessage.from_tool(tool_result="22° C", origin=tool_call), + ] + # Pass the tool result to the model to get the final response + results = component.run(new_messages) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_call + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + + @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS) + @pytest.mark.integration + def test_live_run_with_tools_streaming(self, model_name, tools): + """ + Integration test that the AmazonBedrockChatGenerator component can run with the Haystack tools parameter. + and the streaming_callback parameter to get the streaming response. 
+ """ + initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")] + component = AmazonBedrockChatGenerator(model=model_name, tools=tools, streaming_callback=print_streaming_chunk) + results = component.run(messages=initial_messages) + + assert len(results["replies"]) > 0, "No replies received" + + # Find the message with tool calls + tool_message = None + for message in results["replies"]: + if message.tool_call: + tool_message = message + break + + assert tool_message is not None, "No message with tool call found" + assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance" + assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant" + + tool_call = tool_message.tool_call + assert tool_call.id, "Tool call does not contain value for 'id' key" + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert tool_message.meta["finish_reason"] == "tool_use" + + new_messages = [ + initial_messages[0], + tool_message, + ChatMessage.from_tool(tool_result="22° C", origin=tool_call), + ] + # Pass the tool result to the model to get the final response + results = component.run(new_messages) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_call + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower() + + @pytest.mark.parametrize("model_name", [MODELS_TO_TEST_WITH_TOOLS[0]]) # just one model is enough + @pytest.mark.integration + def test_pipeline_with_amazon_bedrock_chat_generator(self, model_name, tools): + """ + Test that the AmazonBedrockChatGenerator component can be used in a pipeline + """ + + pipeline = Pipeline() + pipeline.add_component("generator", AmazonBedrockChatGenerator(model=model_name, tools=tools)) + pipeline.add_component("tool_invoker", ToolInvoker(tools=tools)) + + pipeline.connect("generator", "tool_invoker") + + results = pipeline.run( + data={"generator": {"messages": [ChatMessage.from_user("What's the weather like in Paris?")]}} + ) + + assert ( + "The weather in Paris is sunny and 32°C" + == results["tool_invoker"]["tool_messages"][0].tool_call_result.result + ) + + @pytest.mark.parametrize("model_name", [MODELS_TO_TEST_WITH_TOOLS[0]]) # just one model is enough + @pytest.mark.integration + def test_pipeline_with_amazon_bedrock_chat_generator_serde(self, model_name, tools): + """ + Test that the AmazonBedrockChatGenerator component can be serialized and deserialized in a pipeline + """ + # Create original pipeline + pipeline = Pipeline() + pipeline.add_component("generator", AmazonBedrockChatGenerator(model=model_name, tools=tools)) + pipeline.add_component("tool_invoker", ToolInvoker(tools=tools)) + pipeline.connect("generator", "tool_invoker") + + # Serialize and deserialize + pipeline_dict = pipeline.to_dict() + + # Verify tools in serialized dict + generator_tools = pipeline_dict["components"]["generator"]["init_parameters"]["tools"] + tool_invoker_tools = pipeline_dict["components"]["tool_invoker"]["init_parameters"]["tools"] + + # Both components should have the same tool configuration + assert generator_tools == tool_invoker_tools + assert len(generator_tools) == 1 + + # Verify tool details + tool_dict = generator_tools[0] + assert tool_dict["type"] == "haystack.tools.tool.Tool" + assert tool_dict["data"]["name"] == "weather" + assert tool_dict["data"]["description"] == "useful to determine the weather in a given location" + assert 
tool_dict["data"]["parameters"] == { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } + assert tool_dict["data"]["function"] == "tests.test_chat_generator.weather" + + # Load pipeline and verify it works + loaded_pipeline = Pipeline.from_dict(pipeline_dict) + + # Run the deserialized pipeline + results = loaded_pipeline.run( + data={"generator": {"messages": [ChatMessage.from_user("What's the weather like in Paris?")]}} + ) + + assert ( + "The weather in Paris is sunny and 32°C" + == results["tool_invoker"]["tool_messages"][0].tool_call_result.result + )