Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#432] Groq Provider tool call tweaks #811

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 71 additions & 23 deletions llama_stack/providers/remote/inference/groq/groq_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import json
import warnings
from typing import AsyncGenerator, Literal
from typing import AsyncGenerator, Literal, Union

from groq import Stream
from groq.types.chat.chat_completion import ChatCompletion
Expand All @@ -30,6 +30,8 @@

from llama_models.llama3.api.datatypes import ToolParamDefinition

from pydantic import BaseModel

from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
Expand Down Expand Up @@ -150,15 +152,26 @@ def convert_chat_completion_response(
_convert_groq_tool_call(tool_call)
for tool_call in choice.message.tool_calls
]
return ChatCompletionResponse(
completion_message=CompletionMessage(
tool_calls=tool_calls,
stop_reason=StopReason.end_of_message,
# Content is not optional
content="",
),
logprobs=None,
)
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
# If we couldn't parse a tool call, jsonify the tool calls and return them
return ChatCompletionResponse(
completion_message=CompletionMessage(
stop_reason=StopReason.end_of_message,
content=json.dumps(tool_calls, default=lambda x: x.model_dump()),
),
logprobs=None,
)
else:
# Otherwise, return tool calls as normal
return ChatCompletionResponse(
completion_message=CompletionMessage(
tool_calls=tool_calls,
stop_reason=StopReason.end_of_message,
# Content is not optional
content="",
),
logprobs=None,
)
else:
return ChatCompletionResponse(
completion_message=CompletionMessage(
Expand Down Expand Up @@ -214,15 +227,27 @@ async def convert_chat_completion_response_stream(

# We assume Groq produces fully formed tool calls for each chunk
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
if isinstance(tool_call, ToolCall):
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
)
)
else:
# Otherwise it's an UnparseableToolCall - return the raw tool call
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
content=tool_call.model_dump_json(),
parse_status=ToolCallParseStatus.failed,
),
)
)
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
Expand All @@ -234,12 +259,35 @@ async def convert_chat_completion_response_stream(
event_type = ChatCompletionResponseEventType.progress


def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
class UnparseableToolCall(BaseModel):
    """
    Stand-in for a ToolCall whose arguments could not be decoded as JSON.

    Has the same field names as ToolCall, except that ``arguments`` is kept
    as a plain string rather than a parsed mapping.
    """

    # Identifier of the tool call.
    call_id: str
    # Name of the tool that was requested.
    tool_name: str
    # Argument payload kept as an unparsed string.
    arguments: str
Copy link
Contributor Author

@aidando73 aidando73 Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ToolCall.arguments must be a Dict, so the raw (unparseable) argument string can't be stored in a ToolCall; hence the separate UnparseableToolCall model.

@json_schema_type
class ToolCall(BaseModel):
    call_id: str
    tool_name: Union[BuiltinTool, str]
    arguments: Dict[str, RecursiveType]



def _convert_groq_tool_call(
    tool_call: ChatCompletionMessageToolCall,
) -> Union[ToolCall, UnparseableToolCall]:
    """
    Convert a Groq tool call to a ToolCall.

    Returns an UnparseableToolCall carrying the raw argument string when the
    arguments are not valid JSON, or when they decode to something other than
    a JSON object (ToolCall.arguments must be a dict).
    """
    raw_arguments = tool_call.function.arguments

    try:
        arguments = json.loads(raw_arguments)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; TypeError is raised by
        # json.loads when the payload is not str/bytes/bytearray.
        arguments = None

    # Valid JSON such as "10" or "[1, 2]" is still not a usable argument
    # mapping, so treat it the same way as a decode failure.
    if not isinstance(arguments, dict):
        return UnparseableToolCall(
            call_id=tool_call.id,
            tool_name=tool_call.function.name,
            arguments=raw_arguments,
        )

    return ToolCall(
        call_id=tool_call.id,
        tool_name=tool_call.function.name,
        arguments=arguments,
    )
55 changes: 55 additions & 0 deletions llama_stack/providers/tests/inference/groq/test_groq_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from groq.types.shared.function_definition import FunctionDefinition
from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponseEventType,
Expand Down Expand Up @@ -347,6 +348,26 @@ def test_converts_multiple_tool_calls(self):
),
]

def test_converts_unparseable_tool_calls(self):
    """Tool calls with non-JSON arguments come back as JSON text in the message content."""
    malformed_call = ChatCompletionMessageToolCall(
        id="tool_call_id",
        type="function",
        function=Function(name="log", arguments="(number=10, base=2)"),
    )
    response = self._dummy_chat_completion_response_with_tool_call()
    response.choices[0].message.tool_calls = [malformed_call]

    converted = convert_chat_completion_response(response)

    expected_content = '[{"call_id": "tool_call_id", "tool_name": "log", "arguments": "(number=10, base=2)"}]'
    assert converted.completion_message.content == expected_content

def _dummy_chat_completion_response(self):
return ChatCompletion(
id="chatcmpl-123",
Expand Down Expand Up @@ -478,6 +499,40 @@ def tool_call_stream():
arguments={"origin": "AU", "destination": "LAX"},
)

@pytest.mark.asyncio
async def test_returns_tool_calls_stream_with_unparseable_tool_calls(self):
    """A streamed tool call with non-JSON arguments is emitted as raw JSON text with a failed parse status."""

    def tool_call_stream():
        # First chunk: a tool call whose arguments are not valid JSON.
        chunk = self._dummy_chat_completion_chunk_with_tool_call()
        chunk.choices[0].delta.tool_calls = [
            ChoiceDeltaToolCall(
                index=0,
                type="function",
                id="tool_call_id",
                function=ChoiceDeltaToolCallFunction(
                    name="get_flight_info",
                    arguments="(origin=AU, destination=LAX)",
                ),
            ),
        ]
        yield chunk

        # Second chunk: marks the end of the stream.
        chunk = self._dummy_chat_completion_chunk_with_tool_call()
        chunk.choices[0].delta.content = None
        chunk.choices[0].finish_reason = "stop"
        yield chunk

    stream = tool_call_stream()
    converted = convert_chat_completion_response_stream(stream)

    # Named so as not to shadow the builtin ``iter``.
    chunk_iterator = converted.__aiter__()
    chunk = await chunk_iterator.__anext__()
    assert chunk.event.event_type == ChatCompletionResponseEventType.start
    assert (
        chunk.event.delta.content
        == '{"call_id":"tool_call_id","tool_name":"get_flight_info","arguments":"(origin=AU, destination=LAX)"}'
    )
    assert chunk.event.delta.parse_status == ToolCallParseStatus.failed

def _dummy_chat_completion_chunk(self):
return ChatCompletionChunk(
id="chatcmpl-123",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,13 +377,6 @@ async def test_chat_completion_with_tool_calling(
sample_tool_definition,
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if (
provider.__provider_spec__.provider_type == "remote::groq"
and "Llama-3.2" in inference_model
):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


messages = sample_messages + [
UserMessage(
Expand Down
Loading