Skip to content

Commit

Permalink
fix: cohere streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
elisalimli committed Apr 29, 2024
1 parent bb29b68 commit 6fb91fb
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion libs/superagent/app/agents/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def __init__(
NOT_TOOLS_STREAMING_SUPPORTED_PROVIDERS = [
LLMProvider.GROQ,
LLMProvider.BEDROCK,
LLMProvider.COHERE_CHAT,
]

intermediate_steps = []
Expand Down Expand Up @@ -236,7 +237,10 @@ async def _completion(self, **kwargs) -> Any:
await self.streaming_callback.on_llm_start()

# TODO: Remove this when Groq and Bedrock supports streaming with tools
if self.llm_data.llm.provider in self.NOT_TOOLS_STREAMING_SUPPORTED_PROVIDERS:
if (
self.llm_data.llm.provider in self.NOT_TOOLS_STREAMING_SUPPORTED_PROVIDERS
and len(self.tools) > 0
):
logger.info(
f"Disabling streaming for {self.llm_data.llm.provider}, as tools are used"
)
Expand All @@ -251,6 +255,7 @@ async def _completion(self, **kwargs) -> Any:

for chunk in res:
new_message = chunk.choices[0].delta.dict()
print("new_message", new_message)
# clean up tool calls
if new_message.get("tool_calls"):
new_message["role"] = "assistant"
Expand All @@ -277,6 +282,7 @@ async def _completion(self, **kwargs) -> Any:

output = self._cleanup_output(output)

print("self._stream_directly", self._stream_directly)
if not self._stream_directly:
await self._stream_by_lines(output)

Expand Down Expand Up @@ -310,6 +316,7 @@ async def ainvoke(self, input, *_, **kwargs):
stream=self.enable_streaming,
**self.llm_data.params.dict(exclude_unset=True),
)
print("output", output)

return {
"intermediate_steps": self.intermediate_steps,
Expand Down

0 comments on commit 6fb91fb

Please sign in to comment.