Skip to content

Commit

Permalink
Fix tool calls with response model (#1812)
Browse files Browse the repository at this point in the history
## Description

Tool calling does not work correctly when a response model is set. With this
change, the response model is used whenever one is present, overriding any
tool calling, and the tool calls are then displayed separately.

This also solves the issue for teams.

Fixes #1792
  • Loading branch information
dirkbrnd authored Jan 20, 2025
1 parent e1e52b8 commit 8954eea
Show file tree
Hide file tree
Showing 15 changed files with 181 additions and 23 deletions.
7 changes: 1 addition & 6 deletions cookbook/providers/azure_openai/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,5 @@
)
knowledge_base.load(recreate=False) # Comment out after first run

agent = Agent(
model=AzureOpenAIChat(id="gpt-4o"),
knowledge=knowledge_base,
show_tool_calls=True,
debug_mode=True
)
agent = Agent(model=AzureOpenAIChat(id="gpt-4o"), knowledge=knowledge_base, show_tool_calls=True, debug_mode=True)
agent.print_response("How to make Thai curry?", markdown=True)
7 changes: 5 additions & 2 deletions cookbook/providers/mistral/structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic import BaseModel, Field
from phi.agent import Agent, RunResponse # noqa
from phi.model.mistral import MistralChat
from phi.tools.duckduckgo import DuckDuckGo

mistral_api_key = os.getenv("MISTRAL_API_KEY")

Expand All @@ -25,13 +26,15 @@ class MovieScript(BaseModel):
id="mistral-large-latest",
api_key=mistral_api_key,
),
tools=[DuckDuckGo()],
description="You help people write movie scripts.",
response_model=MovieScript,
# debug_mode=True,
show_tool_calls=True,
debug_mode=True,
)

# Get the response in a variable
# json_mode_response: RunResponse = json_mode_agent.run("New York")
# pprint(json_mode_response.content)

json_mode_agent.print_response("New York")
json_mode_agent.print_response("Find a cool movie idea about London and write it.")
12 changes: 12 additions & 0 deletions cookbook/teams/01_hn_team.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,22 @@
2. Run: `python cookbook/teams/01_hn_team.py` to run the agent
"""

from typing import List

from pydantic import BaseModel

from phi.agent import Agent
from phi.tools.hackernews import HackerNews
from phi.tools.duckduckgo import DuckDuckGo
from phi.tools.newspaper4k import Newspaper4k


class Article(BaseModel):
    """Schema used as the team's ``response_model`` for structured output."""

    title: str
    summary: str
    reference_links: List[str]


hn_researcher = Agent(
name="HackerNews Researcher",
role="Gets top stories from hackernews.",
Expand Down Expand Up @@ -37,6 +48,7 @@
"Then, ask the web searcher to search for each story to get more information.",
"Finally, provide a thoughtful and engaging summary.",
],
response_model=Article,
show_tool_calls=True,
markdown=True,
)
Expand Down
1 change: 1 addition & 0 deletions cookbook/tools/reddit_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
- password: Your Reddit account password
"""

from phi.agent import Agent
from phi.tools.reddit import RedditTools

Expand Down
2 changes: 1 addition & 1 deletion cookbook/tools/trello_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@
)

agent.print_response(
"Create a board called ai-agent and inside it create list called 'todo' and 'doing' and inside each of them create card called 'create agent'",
"Create a board called ai-agent and inside it create list called 'todo' and 'doing' and inside each of them create card called 'create agent'",
stream=True,
)
17 changes: 15 additions & 2 deletions phi/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1991,8 +1991,13 @@ def run(

# If a response_model is set, return the response as a structured output
if self.response_model is not None and self.parse_response:
# Set show_tool_calls=False if we have response_model
self.show_tool_calls = False
logger.debug("Setting show_tool_calls=False as response_model is set")

# Set stream=False and run the agent
logger.debug("Setting stream=False as response_model is set")

run_response: RunResponse = next(
self._run(
message=message,
Expand Down Expand Up @@ -2036,6 +2041,7 @@ def run(
self.run_response.content_type = self.response_model.__name__
else:
logger.warning("Failed to convert response to response_model")

except Exception as e:
logger.warning(f"Failed to convert response to output model: {e}")
else:
Expand Down Expand Up @@ -2306,6 +2312,10 @@ async def arun(

# If a response_model is set, return the response as a structured output
if self.response_model is not None and self.parse_response:
# Set show_tool_calls=False if we have a response_model
self.show_tool_calls = False
logger.debug("Setting show_tool_calls=False as response_model is set")

# Set stream=False and run the agent
logger.debug("Setting stream=False as response_model is set")
run_response = await self._arun(
Expand All @@ -2331,8 +2341,7 @@ async def arun(
structured_output = None
try:
structured_output = self.response_model.model_validate_json(run_response.content)
except ValidationError as exc:
logger.warning(f"Failed to convert response to pydantic model: {exc}")
except ValidationError:
# Check if response starts with ```json
if run_response.content.startswith("```json"):
run_response.content = run_response.content.replace("```json\n", "").replace("\n```", "")
Expand All @@ -2348,6 +2357,9 @@ async def arun(
if self.run_response is not None:
self.run_response.content = structured_output
self.run_response.content_type = self.response_model.__name__
else:
logger.warning("Failed to convert response to response_model")

except Exception as e:
logger.warning(f"Failed to convert response to output model: {e}")
else:
Expand Down Expand Up @@ -2816,6 +2828,7 @@ def print_response(
_response_content += resp.content
if resp.extra_data is not None and resp.extra_data.reasoning_steps is not None:
reasoning_steps = resp.extra_data.reasoning_steps

response_content_stream = Markdown(_response_content) if self.markdown else _response_content

panels = [status]
Expand Down
2 changes: 1 addition & 1 deletion phi/document/chunking/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
try:
from chonkie import SemanticChunker
except ImportError:
logger.warning("`chonkie` is required for semantic chunking, please install using `pip install chonkie`")
logger.warning("`chonkie` is required for semantic chunking, please install using `pip install chonkie[all]`")


class SemanticChunking(ChunkingStrategy):
Expand Down
2 changes: 1 addition & 1 deletion phi/embedder/sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_embedding(self, text: Union[str, List[str]]) -> List[float]:
model = SentenceTransformer(model_name_or_path=self.model)
embedding = model.encode(text)
try:
return embedding
return embedding # type: ignore
except Exception as e:
logger.warning(e)
return []
Expand Down
8 changes: 8 additions & 0 deletions phi/playground/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def agent_run(
for file in files:
if file.content_type == "application/pdf":
from phi.document.reader.pdf import PDFReader

contents = file.file.read()
pdf_file = BytesIO(contents)
pdf_file.name = file.filename
Expand All @@ -152,6 +153,7 @@ def agent_run(
agent.knowledge.load_documents(file_content)
elif file.content_type == "text/csv":
from phi.document.reader.csv_reader import CSVReader

contents = file.file.read()
csv_file = BytesIO(contents)
csv_file.name = file.filename
Expand All @@ -160,6 +162,7 @@ def agent_run(
agent.knowledge.load_documents(file_content)
elif file.content_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
from phi.document.reader.docx import DocxReader

contents = file.file.read()
docx_file = BytesIO(contents)
docx_file.name = file.filename
Expand All @@ -168,6 +171,7 @@ def agent_run(
agent.knowledge.load_documents(file_content)
elif file.content_type == "text/plain":
from phi.document.reader.text import TextReader

contents = file.file.read()
text_file = BytesIO(contents)
text_file.name = file.filename
Expand Down Expand Up @@ -519,6 +523,7 @@ async def agent_run(
for file in files:
if file.content_type == "application/pdf":
from phi.document.reader.pdf import PDFReader

contents = await file.read()
pdf_file = BytesIO(contents)
pdf_file.name = file.filename
Expand All @@ -527,6 +532,7 @@ async def agent_run(
agent.knowledge.load_documents(file_content)
elif file.content_type == "text/csv":
from phi.document.reader.csv_reader import CSVReader

contents = await file.read()
csv_file = BytesIO(contents)
csv_file.name = file.filename
Expand All @@ -535,6 +541,7 @@ async def agent_run(
agent.knowledge.load_documents(file_content)
elif file.content_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
from phi.document.reader.docx import DocxReader

contents = await file.read()
docx_file = BytesIO(contents)
docx_file.name = file.filename
Expand All @@ -543,6 +550,7 @@ async def agent_run(
agent.knowledge.load_documents(file_content)
elif file.content_type == "text/plain":
from phi.document.reader.text import TextReader

contents = await file.read()
text_file = BytesIO(contents)
text_file.name = file.filename
Expand Down
18 changes: 9 additions & 9 deletions phi/tools/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,15 @@ def get_channel_history(self, channel: str, limit: int = 100) -> str:
try:
response = self.client.conversations_history(channel=channel, limit=limit)
messages = [
{
"text": msg.get("text", ""),
"user": "webhook" if msg.get("subtype") == "bot_message" else msg.get("user", "unknown"),
"ts": msg.get("ts", ""),
"sub_type": msg.get("subtype", "unknown"),
"attachments": msg.get("attachments", []) if msg.get("subtype") == "bot_message" else "n/a"
}
for msg in response.get("messages", [])
]
{
"text": msg.get("text", ""),
"user": "webhook" if msg.get("subtype") == "bot_message" else msg.get("user", "unknown"),
"ts": msg.get("ts", ""),
"sub_type": msg.get("subtype", "unknown"),
"attachments": msg.get("attachments", []) if msg.get("subtype") == "bot_message" else "n/a",
}
for msg in response.get("messages", [])
]
return json.dumps(messages)
except SlackApiError as e:
logger.error(f"Error getting channel history: {e}")
Expand Down
61 changes: 61 additions & 0 deletions phi/utils/string.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import hashlib
import json
from typing import Optional, Dict, Any


def hash_string_sha256(input_string):
Expand All @@ -15,3 +17,62 @@ def hash_string_sha256(input_string):
hex_digest = sha256_hash.hexdigest()

return hex_digest


def extract_valid_json(content: str) -> Optional[Dict[str, Any]]:
    """
    Extract the first valid JSON object embedded in a string.

    Every "{" in ``content`` is tried, left to right, as the possible start of
    a JSON object. Parsing is delegated to the JSON decoder rather than to a
    manual brace counter, so braces that appear inside JSON string values
    (e.g. ``{"a": "}"}``) do not truncate the candidate, and an unbalanced "{"
    earlier in the text does not hide a valid object that follows it.

    Args:
        content (str): The input string containing potential JSON data.

    Returns:
        Optional[Dict[str, Any]]: The first JSON object that parses
        successfully, or None if the string contains no valid JSON object.
        Only objects are returned; top-level arrays and scalars are ignored.
    """
    decoder = json.JSONDecoder()

    # Find the first opening brace; each iteration advances to the next one.
    start_idx = content.find("{")
    while start_idx != -1:
        try:
            # raw_decode parses one JSON document from the start of the given
            # string and ignores whatever text trails it.
            parsed, _ = decoder.raw_decode(content[start_idx:])
        except json.JSONDecodeError:
            # Not valid JSON from this brace; fall through to the next one.
            pass
        else:
            if isinstance(parsed, dict):
                return parsed

        # Move just past the current opening brace to look for another candidate
        start_idx = content.find("{", start_idx + 1)

    # No "{" started a valid JSON object
    return None
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ check_untyped_defs = true
no_implicit_optional = true
warn_unused_configs = true
plugins = ["pydantic.mypy"]
exclude = ["phienv*", "aienv*", "scratch*", "wip*", "tmp*", "cookbook/assistants/examples/*", "phi/assistant/openai/*"]
exclude = ["phienv*", "aienv*", "scratch*", "wip*", "tmp*", "cookbook/assistants/examples/*", "phi/assistant/openai/*", "tests/*"]

[[tool.mypy.overrides]]
module = [
Expand Down
File renamed without changes.
Empty file added tests/unit/utils/__init__.py
Empty file.
65 changes: 65 additions & 0 deletions tests/unit/utils/test_string.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from phi.utils.string import extract_valid_json


def test_extract_valid_json_with_valid_json():
    """A flat JSON object surrounded by prose is returned as a dict."""
    text = 'Here is some text {"key": "value"} and more text.'
    assert extract_valid_json(text) == {"key": "value"}


def test_extract_valid_json_with_nested_json():
    """An object containing another object is returned whole."""
    text = 'Start {"key": {"nested_key": "nested_value"}} End'
    result = extract_valid_json(text)
    assert result == {"key": {"nested_key": "nested_value"}}


def test_extract_valid_json_with_multiple_json_objects():
    """When several objects are present, only the first one is returned."""
    text = 'First {"key1": "value1"} Second {"key2": "value2"}'
    first_object = {"key1": "value1"}
    assert extract_valid_json(text) == first_object


def test_extract_valid_json_with_no_json():
    """Text without any braces yields None."""
    result = extract_valid_json("This is a string without JSON.")
    assert result is None


def test_extract_valid_json_with_invalid_json():
    """A brace-delimited span that is not valid JSON yields None."""
    result = extract_valid_json("This string contains {invalid JSON}.")
    assert result is None


def test_extract_valid_json_with_json_array():
    """Top-level JSON arrays are ignored; only JSON objects are extracted."""
    text = 'Here is a JSON array: ["item1", "item2"].'
    assert extract_valid_json(text) is None


def test_extract_valid_json_with_empty_json():
    """An empty object is valid JSON and is returned as an empty dict."""
    result = extract_valid_json("Some text {} more text.")
    assert result == {}


def test_extract_valid_json_with_multiline_json():
    """An object spread over several lines is still extracted."""
    text = """
Here is some text {
"key": "value",
"another_key": "another_value"
} and more text.
"""
    result = extract_valid_json(text)
    assert result == {"key": "value", "another_key": "another_value"}


def test_extract_valid_json_with_json_in_quotes():
    """JSON embedded in a quoted string with backslash-escaped quotes is not parsed."""
    text = 'Text before "{\\"key\\": \\"value\\"}" text after.'
    # The braces enclose backslash-escaped quotes, which is not valid JSON.
    assert extract_valid_json(text) is None

0 comments on commit 8954eea

Please sign in to comment.