diff --git a/fern/apis/prod/openapi/openapi.yaml b/fern/apis/prod/openapi/openapi.yaml
index 27aba78d1..3b8cdd66a 100644
--- a/fern/apis/prod/openapi/openapi.yaml
+++ b/fern/apis/prod/openapi/openapi.yaml
@@ -2,7 +2,7 @@ openapi: 3.0.2
info:
title: Superagent
description: 🥷 Run AI-agents with an API
- version: 0.2.29
+ version: 0.2.32
servers:
- url: https://api.beta.superagent.sh
paths:
@@ -195,8 +195,8 @@ paths:
$ref: '#/components/schemas/HTTPValidationError'
security:
- HTTPBearer: []
- x-fern-sdk-group-name: agent
x-fern-sdk-method-name: invoke
+ x-fern-sdk-group-name: agent
/api/v1/agents/{agent_id}/llms:
post:
tags:
@@ -1507,6 +1507,33 @@ paths:
$ref: '#/components/schemas/HTTPValidationError'
security:
- HTTPBearer: []
+ delete:
+ tags:
+ - Vector Database
+ summary: Delete
+ description: Delete a Vector Database
+ operationId: delete_api_v1_vector_dbs__vector_db_id__delete
+ parameters:
+ - required: true
+ schema:
+ title: Vector Db Id
+ type: string
+ name: vector_db_id
+ in: path
+ responses:
+ '200':
+ description: Successful Response
+ content:
+ application/json:
+ schema: {}
+ '422':
+ description: Validation Error
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
+ security:
+ - HTTPBearer: []
patch:
tags:
- Vector Database
@@ -1783,6 +1810,8 @@ components:
- TOGETHER_AI
- ANTHROPIC
- BEDROCK
+ - GROQ
+ - MISTRAL
type: string
description: An enumeration.
OpenAiAssistantParameters:
diff --git a/libs/superagent/app/agents/base.py b/libs/superagent/app/agents/base.py
index e00786ef0..305a9cb8a 100644
--- a/libs/superagent/app/agents/base.py
+++ b/libs/superagent/app/agents/base.py
@@ -6,7 +6,7 @@
from app.models.request import LLMParams as LLMParamsRequest
from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
-from prisma.enums import AgentType
+from prisma.enums import AgentType, LLMProvider
from prisma.models import LLM, Agent
@@ -21,9 +21,21 @@ class LLMParams(BaseModel):
class LLMData(BaseModel):
llm: LLM
params: LLMParams
+ model: str
class AgentBase(ABC):
+ _input: str
+ _messages: list = []
+ prompt: Any
+ tools: Any
+ session_id: str
+ enable_streaming: bool
+ output_schema: str
+ callbacks: List[CustomAsyncIteratorCallbackHandler]
+ agent_data: Agent
+ llm_data: LLMData
+
def __init__(
self,
session_id: str,
@@ -40,10 +52,6 @@ def __init__(
self.llm_data = llm_data
self.agent_data = agent_data
- _input: str
- prompt: Any
- tools: Any
-
@property
def input(self):
return self._input
@@ -52,6 +60,14 @@ def input(self):
def input(self, value: str):
self._input = value
+ @property
+ def messages(self):
+ return self._messages
+
+ @messages.setter
+ def messages(self, value: list):
+ self._messages = value
+
@property
@abstractmethod
def prompt(self) -> Any:
@@ -95,7 +111,31 @@ def llm_data(self):
**(params),
}
- return LLMData(llm=llm, params=LLMParams.parse_obj(options))
+ params = LLMParams(
+ temperature=options.get("temperature"),
+ max_tokens=options.get("max_tokens"),
+ aws_access_key_id=(
+ options.get("aws_access_key_id")
+ if llm.provider == LLMProvider.BEDROCK
+ else None
+ ),
+ aws_secret_access_key=(
+ options.get("aws_secret_access_key")
+ if llm.provider == LLMProvider.BEDROCK
+ else None
+ ),
+ aws_region_name=(
+ options.get("aws_region_name")
+ if llm.provider == LLMProvider.BEDROCK
+ else None
+ ),
+ )
+
+ return LLMData(
+ llm=llm,
+ params=LLMParams.parse_obj(options),
+ model=self.agent_data.llmModel or self.agent_data.metadata.get("model"),
+ )
async def get_agent(self):
if self.agent_data.type == AgentType.OPENAI_ASSISTANT:
diff --git a/libs/superagent/app/agents/langchain.py b/libs/superagent/app/agents/langchain.py
index 73754b58c..c5fc39e95 100644
--- a/libs/superagent/app/agents/langchain.py
+++ b/libs/superagent/app/agents/langchain.py
@@ -46,7 +46,7 @@ def _get_llm(self):
if llm_data.llm.provider == LLMProvider.OPENAI:
return ChatOpenAI(
- model=LLM_MAPPING[self.agent_data.llmModel],
+ model=LLM_MAPPING[self.llm_data.model],
openai_api_key=llm_data.llm.apiKey,
streaming=self.enable_streaming,
callbacks=self.callbacks,
diff --git a/libs/superagent/app/agents/llm.py b/libs/superagent/app/agents/llm.py
index 647a92760..9836e1fec 100644
--- a/libs/superagent/app/agents/llm.py
+++ b/libs/superagent/app/agents/llm.py
@@ -7,7 +7,7 @@
from langchain_core.agents import AgentActionMessageLog
from langchain_core.messages import AIMessage
from langchain_core.utils.function_calling import convert_to_openai_function
-from litellm import acompletion, completion
+from litellm import completion, get_llm_provider, get_supported_openai_params
from app.agents.base import AgentBase
from app.tools import get_tools
@@ -25,9 +25,9 @@
async def call_tool(
agent_data: Agent, session_id: str, function: Any
) -> tuple[AgentActionMessageLog, Any]:
- name = function.name
+ name = function.get("name")
try:
- args = json.loads(function.arguments)
+ args = json.loads(function.get("arguments"))
except Exception as e:
logger.error(f"Error parsing function arguments for {name}: {e}")
raise e
@@ -44,33 +44,46 @@ async def call_tool(
if not tool_to_call:
raise Exception(f"Function {name} not found in tools")
- res = await tool_to_call._arun(**args)
-
- return (
- AgentActionMessageLog(
- tool=name,
- tool_input=args,
- log=f"\nInvoking: `{name}` with `{args}`\n\n\n",
- message_log=[
- AIMessage(
- content="",
- additional_kwargs={
- "function_call": {
- "arguments": args,
- "name": name,
- }
- },
- )
- ],
- ),
- res,
+ logging.info(f"Calling tool {name} with arguments {args}")
+
+ action_log = AgentActionMessageLog(
+ tool=name,
+ tool_input=args,
+ log=f"\nInvoking: `{name}` with `{args}`\n\n\n",
+ message_log=[
+ AIMessage(
+ content="",
+ additional_kwargs={
+ "function_call": {
+ "arguments": args,
+ "name": name,
+ }
+ },
+ )
+ ],
)
+ tool_res = None
+ try:
+ tool_res = await tool_to_call._arun(**args)
+ logging.info(f"Tool {name} returned {tool_res}")
+ except Exception as e:
+ tool_res = f"Error calling {tool_to_call.name} tool with arguments {args}: {e}"
+ logging.error(f"Error calling tool {name}: {e}")
+
+ return (action_log, tool_res, tool_to_call.return_direct)
class LLMAgent(AgentBase):
@property
def tools(self):
- pass
+ tools = get_tools(
+ agent_data=self.agent_data,
+ session_id=self.session_id,
+ )
+ return [
+ {"type": "function", "function": convert_to_openai_function(tool)}
+ for tool in tools
+ ]
@property
def prompt(self):
@@ -91,8 +104,163 @@ def prompt(self):
return prompt
@property
- def messages(self):
- return [
+ def _is_tool_calling_supported(self):
+ (model, custom_llm_provider, _, _) = get_llm_provider(self.llm_data.model)
+ supported_params = get_supported_openai_params(
+ model=model, custom_llm_provider=custom_llm_provider
+ )
+
+ return "tools" in supported_params
+
+ async def _stream_by_lines(self, output: str):
+ output_by_lines = output.split("\n")
+ if len(output_by_lines) > 1:
+ for line in output_by_lines:
+ await self.streaming_callback.on_llm_new_token(line)
+ await self.streaming_callback.on_llm_new_token("\n")
+ else:
+ await self.streaming_callback.on_llm_new_token(output_by_lines[0])
+
+ async def get_agent(self):
+ if self._is_tool_calling_supported:
+ logger.info("Using native function calling")
+ return AgentExecutor(**self.__dict__)
+
+ return AgentExecutorOpenAIFunc(**self.__dict__)
+
+
+class AgentExecutor(LLMAgent):
+ """Agent Executor for LLM (with native function calling)"""
+
+ NOT_TOOLS_STREAMING_SUPPORTED_PROVIDERS = [
+ LLMProvider.GROQ,
+ LLMProvider.BEDROCK,
+ ]
+
+ intermediate_steps = []
+
+ async def _execute_tool_calls(self, tool_calls: list[dict], **kwargs):
+ messages: list = kwargs.get("messages")
+ for tool_call in tool_calls:
+ intermediate_step = await call_tool(
+ agent_data=self.agent_data,
+ session_id=self.session_id,
+ function=tool_call.get("function"),
+ )
+ (action_log, tool_res, return_direct) = intermediate_step
+ self.intermediate_steps.append((action_log, tool_res))
+ new_message = {
+ "role": "tool",
+ "name": tool_call.get("function").get("name"),
+ "content": tool_res,
+ }
+ if tool_call.get("id"):
+ new_message["tool_call_id"] = tool_call.get("id")
+
+ messages.append(new_message)
+ if return_direct:
+ if self.enable_streaming:
+ await self._stream_by_lines(tool_res)
+ self.streaming_callback.done.set()
+ return tool_res
+
+ self.messages = messages
+ kwargs["messages"] = self.messages
+ return await self._completion(**kwargs)
+
+ def _cleanup_output(self, output):
+ # anthropic returns a XML formatted response
+ # we need to get the content between tags
+ if self.llm_data.llm.provider == LLMProvider.ANTHROPIC:
+ from xmltodict import parse as xml_parse
+
+ xml_output = "" + output + ""
+ output = xml_parse(xml_output)
+ output = output["root"]
+
+ if isinstance(output, str):
+ return output
+ else:
+ if "result" in output:
+ output = output.get("result")
+ else:
+ output = output.get("#text")
+ return output
+
+ def _transform_completion_to_streaming(self, res, **kwargs):
+ # hacky way to convert non-streaming response to streaming response
+ if not kwargs.get("stream"):
+ for choice in res.choices:
+ choice.delta = choice.message
+ res = [res]
+ return res
+
+ async def _completion(self, **kwargs) -> Any:
+ logger.info(f"Calling LLM with kwargs: {kwargs}")
+ new_messages = self.messages
+
+ if kwargs.get("stream"):
+ await self.streaming_callback.on_llm_start()
+
+ should_stream_directly = (
+ self.enable_streaming
+ and self.llm_data.llm.provider
+ not in self.NOT_TOOLS_STREAMING_SUPPORTED_PROVIDERS
+ and self.llm_data.llm.provider != LLMProvider.ANTHROPIC
+ )
+
+ # TODO: Remove this when Groq and Bedrock supports streaming with tools
+ if self.llm_data.llm.provider in self.NOT_TOOLS_STREAMING_SUPPORTED_PROVIDERS:
+ logger.info(
+ f"Disabling streaming for {self.llm_data.llm.provider}, as tools are used"
+ )
+ kwargs["stream"] = False
+
+ res = completion(**kwargs)
+ res = self._transform_completion_to_streaming(res, **kwargs)
+
+ tool_calls = []
+ output = ""
+
+ for chunk in res:
+ new_message = chunk.choices[0].delta.dict()
+ # clean up tool calls
+ if new_message.get("tool_calls"):
+ new_message["role"] = "assistant"
+ new_tool_calls = new_message.get("tool_calls", [])
+ for tool_call in new_tool_calls:
+ tool_call["type"] = "function"
+ if "index" in tool_call:
+ del tool_call["index"]
+
+ new_messages.append(new_message)
+ tool_calls.extend(new_tool_calls)
+
+ content = new_message.get("content", "")
+
+ if content:
+ output += content
+ if should_stream_directly:
+ await self.streaming_callback.on_llm_new_token(content)
+
+ self.messages = new_messages
+
+ if tool_calls:
+ return await self._execute_tool_calls(tool_calls, **kwargs)
+
+ output = self._cleanup_output(output)
+
+ if not should_stream_directly:
+ await self._stream_by_lines(output)
+
+ if self.enable_streaming:
+ self.streaming_callback.done.set()
+
+ return output
+
+ async def ainvoke(self, input, *_, **kwargs):
+ self.input = input
+ self.messages = [
{
"content": self.prompt,
"role": "system",
@@ -103,22 +271,33 @@ def messages(self):
},
]
- async def get_agent(self):
- agent_executor = LLMAgentOpenAIFunctionCallingExecutor(**self.__dict__)
- return agent_executor
+ if self.enable_streaming:
+ for callback in kwargs["config"]["callbacks"]:
+ if isinstance(callback, CustomAsyncIteratorCallbackHandler):
+ self.streaming_callback = callback
+ if not self.streaming_callback:
+ raise Exception("Streaming Callback not found")
-class LLMAgentOpenAIFunctionCallingExecutor(LLMAgent):
- @property
- def tools(self):
- tools = get_tools(
- agent_data=self.agent_data,
- session_id=self.session_id,
+ output = await self._completion(
+ model=self.llm_data.model,
+ api_key=self.llm_data.llm.apiKey,
+ messages=self.messages,
+ tools=self.tools if len(self.tools) > 0 else None,
+ tool_choice="auto" if len(self.tools) > 0 else None,
+ stream=self.enable_streaming,
+ **self.llm_data.params.dict(exclude_unset=True),
)
- return [
- {"type": "function", "function": convert_to_openai_function(tool)}
- for tool in tools
- ]
+
+ return {
+ "intermediate_steps": self.intermediate_steps,
+ "input": self.input,
+ "output": output,
+ }
+
+
+class AgentExecutorOpenAIFunc(LLMAgent):
+ """Agent Executor that binded with OpenAI Function Calling"""
@property
def messages_function_calling(self):
@@ -133,10 +312,29 @@ def messages_function_calling(self):
},
]
+ @property
+ def messages(self):
+ return [
+ {
+ "content": self.prompt,
+ "role": "system",
+ },
+ {
+ "content": self.input,
+ "role": "user",
+ },
+ ]
+
async def ainvoke(self, input, *_, **kwargs):
self.input = input
- model = self.agent_data.metadata.get("model", "gpt-3.5-turbo-0125")
- tool_responses = []
+ tool_results = []
+ if self.enable_streaming:
+ for callback in kwargs["config"]["callbacks"]:
+ if isinstance(callback, CustomAsyncIteratorCallbackHandler):
+ self.streaming_callback = callback
+
+ if not self.streaming_callback:
+ raise Exception("Streaming Callback not found")
if len(self.tools) > 0:
openai_llm = await prisma.llm.find_first(
@@ -154,41 +352,45 @@ async def ainvoke(self, input, *_, **kwargs):
)
res = completion(
+ api_key=openai_api_key,
model="gpt-3.5-turbo-0125",
messages=self.messages_function_calling,
tools=self.tools,
stream=False,
- api_key=openai_api_key,
)
tool_calls = res.choices[0].message.get("tool_calls", [])
for tool_call in tool_calls:
- try:
- res = await call_tool(
- agent_data=self.agent_data,
- session_id=self.session_id,
- function=tool_call.function,
- )
- except Exception as e:
- logger.error(
- f"Error calling function {tool_call.function.name}: {e}"
- )
- continue
- tool_responses.append(res)
-
- if len(tool_responses) > 0:
+ (action_log, tool_res, return_direct) = await call_tool(
+ agent_data=self.agent_data,
+ session_id=self.session_id,
+ function=tool_call.function.dict(),
+ )
+ tool_results.append((action_log, tool_res))
+ if return_direct:
+ if self.enable_streaming:
+ await self._stream_by_lines(tool_res)
+ self.streaming_callback.done.set()
+
+ return {
+ "intermediate_steps": tool_results,
+ "input": self.input,
+ "output": tool_res,
+ }
+
+ if len(tool_results) > 0:
INPUT_TEMPLATE = "{input}\n Context: {context}\n"
self.input = INPUT_TEMPLATE.format(
input=self.input,
context="\n\n".join(
- [tool_response for (_, tool_response) in tool_responses]
+ [tool_response for (_, tool_response) in tool_results]
),
)
params = self.llm_data.params.dict(exclude_unset=True)
- res = await acompletion(
+ res = completion(
api_key=self.llm_data.llm.apiKey,
- model=model,
+ model=self.llm_data.model,
messages=self.messages,
stream=self.enable_streaming,
**params,
@@ -205,7 +407,7 @@ async def ainvoke(self, input, *_, **kwargs):
raise Exception("Streaming Callback not found")
await streaming_callback.on_llm_start()
- async for chunk in res:
+ for chunk in res:
token = chunk.choices[0].delta.content
if token:
output += token
@@ -216,7 +418,7 @@ async def ainvoke(self, input, *_, **kwargs):
output = res.choices[0].message.content
return {
- "intermediate_steps": tool_responses,
+ "intermediate_steps": tool_results,
"input": self.input,
"output": output,
}
diff --git a/libs/superagent/app/api/agents.py b/libs/superagent/app/api/agents.py
index 9204c7c74..20e7019e8 100644
--- a/libs/superagent/app/api/agents.py
+++ b/libs/superagent/app/api/agents.py
@@ -501,10 +501,14 @@ async def send_message(
from langchain.output_parsers.json import SimpleJsonOutputParser
parser = SimpleJsonOutputParser()
- parsed_schema = str(parser.parse(schema_tokens))
+ try:
+ parsed_res = parser.parse(schema_tokens)
+ except Exception as e:
+ logger.error(f"Error parsing output: {e}")
+ parsed_res = {}
# stream line by line to prevent streaming large data in one go
- for line in parsed_schema.split("\n"):
+ for line in json.dumps(parsed_res).split("\n"):
async for val in stream_dict_keys(
{"event": "message", "data": line}
):
@@ -603,8 +607,12 @@ async def send_message(
if output_schema:
from langchain.output_parsers.json import SimpleJsonOutputParser
- json_parser = SimpleJsonOutputParser()
- output["output"] = json_parser.parse(text=output["output"])
+ parser = SimpleJsonOutputParser()
+ try:
+ output["output"] = parser.parse(text=output["output"])
+ except Exception as e:
+ logger.error(f"Error parsing output: {e}")
+ output["output"] = {}
return {"success": True, "data": output}
diff --git a/libs/superagent/app/api/workflow_configs/saml_schema.py b/libs/superagent/app/api/workflow_configs/saml_schema.py
index 837da1242..decc22b97 100644
--- a/libs/superagent/app/api/workflow_configs/saml_schema.py
+++ b/libs/superagent/app/api/workflow_configs/saml_schema.py
@@ -149,6 +149,9 @@ class LLMAgentTool(BaseAgentToolModel, LLMAgent):
LLMProvider.TOGETHER_AI.value,
LLMProvider.ANTHROPIC.value,
LLMProvider.BEDROCK.value,
+ LLMProvider.GROQ.value,
+ LLMProvider.MISTRAL.value,
+ LLMProvider.COHERE_CHAT.value,
]
@@ -159,6 +162,9 @@ class Workflow(BaseModel):
perplexity: Optional[LLMAgent]
together_ai: Optional[LLMAgent]
bedrock: Optional[LLMAgent]
+ groq: Optional[LLMAgent]
+ mistral: Optional[LLMAgent]
+ cohere_chat: Optional[LLMAgent]
anthropic: Optional[LLMAgent]
llm: Optional[LLMAgent] = Field(
description="Deprecated! Use LLM providers instead. e.g. `perplexity` or `together_ai`"
diff --git a/libs/superagent/app/api/workflows.py b/libs/superagent/app/api/workflows.py
index 36c99572e..cef7ad71c 100644
--- a/libs/superagent/app/api/workflows.py
+++ b/libs/superagent/app/api/workflows.py
@@ -304,10 +304,15 @@ async def send_message() -> AsyncIterable[str]:
from langchain.output_parsers.json import SimpleJsonOutputParser
parser = SimpleJsonOutputParser()
- parsed_schema = str(parser.parse(schema_tokens))
+ try:
+ parsed_res = parser.parse(schema_tokens)
+ except Exception as e:
+ # TODO: stream schema parsing error as well
+ logger.error(f"Error in parsing schema: {e}")
+ parsed_res = {}
# stream line by line to prevent streaming large data in one go
- for line in parsed_schema.split("\n"):
+ for line in json.dumps(parsed_res).split("\n"):
agent_name = workflow_step["agent_name"]
async for val in stream_dict_keys(
{
diff --git a/libs/superagent/app/main.py b/libs/superagent/app/main.py
index bd5bfd13c..0491ecfb5 100644
--- a/libs/superagent/app/main.py
+++ b/libs/superagent/app/main.py
@@ -36,7 +36,7 @@
title="Superagent",
docs_url="/",
description="🥷 Run AI-agents with an API",
- version="0.2.29",
+ version="0.2.32",
servers=[{"url": config("SUPERAGENT_API_URL")}],
)
diff --git a/libs/superagent/app/models/tools.py b/libs/superagent/app/models/tools.py
index bc6d365f6..faa63ef62 100644
--- a/libs/superagent/app/models/tools.py
+++ b/libs/superagent/app/models/tools.py
@@ -1,6 +1,6 @@
from typing import Optional
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
class AlgoliaInput(BaseModel):
@@ -61,7 +61,7 @@ class E2BCodeExecutorInput(BaseModel):
class BrowserInput(BaseModel):
- url: str
+ url: str = Field(..., description="A valid url including protocol to analyze")
class GPTVisionInputModel(BaseModel):
diff --git a/libs/superagent/app/tools/__init__.py b/libs/superagent/app/tools/__init__.py
index 6c6fcebdc..5fa5a73cb 100644
--- a/libs/superagent/app/tools/__init__.py
+++ b/libs/superagent/app/tools/__init__.py
@@ -150,14 +150,14 @@ def conform_function_name(url):
"""
Validates OpenAI function names and modifies them to conform to the regex
"""
- regex_pattern = r"^[a-zA-Z0-9_-]{1,64}$"
+ regex_pattern = r"^[A-Za-z0-9_]{1,64}$"
# Check if the URL matches the regex
if re.match(regex_pattern, url):
return url # URL is already valid
else:
# Modify the URL to conform to the regex
- valid_url = re.sub(r"[^a-zA-Z0-9_-]", "", url)[:64]
+ valid_url = re.sub(r"[^A-Za-z0-9_]", "", url)[:64]
return valid_url
diff --git a/libs/superagent/app/tools/superrag.py b/libs/superagent/app/tools/superrag.py
index d312b3560..f1dfdcec1 100644
--- a/libs/superagent/app/tools/superrag.py
+++ b/libs/superagent/app/tools/superrag.py
@@ -1,3 +1,4 @@
+import json
import logging
from langchain_community.tools import BaseTool
@@ -57,7 +58,7 @@ async def _arun(
credentials = get_superrag_compatible_credentials(provider.options)
- return self.superrag_service.query(
+ res = self.superrag_service.query(
{
"vector_database": {"type": database_provider, "config": credentials},
"index_name": index_name,
@@ -67,3 +68,4 @@ async def _arun(
"interpreter_mode": interpreter_mode,
}
)
+ return json.dumps(res)
diff --git a/libs/superagent/app/utils/callbacks.py b/libs/superagent/app/utils/callbacks.py
index 02171ac24..21bc63770 100644
--- a/libs/superagent/app/utils/callbacks.py
+++ b/libs/superagent/app/utils/callbacks.py
@@ -42,7 +42,7 @@ async def on_agent_finish(self, finish: AgentFinish, **_: Any) -> Any:
while not self.queue.empty():
await asyncio.sleep(0.1)
- self.done.set()
+ self.done.set()
async def on_llm_start(self, *_: Any, **__: Any) -> None:
# If two calls are made in a row, this resets the state
@@ -93,6 +93,7 @@ async def aiter(self) -> AsyncIterator[str]:
if token_or_done is True:
continue
self.is_stream_started = True
+
yield token_or_done
diff --git a/libs/superagent/app/workflows/base.py b/libs/superagent/app/workflows/base.py
index 648207cf6..3ddc63eb6 100644
--- a/libs/superagent/app/workflows/base.py
+++ b/libs/superagent/app/workflows/base.py
@@ -1,3 +1,4 @@
+import logging
from typing import Any, List
from agentops.langchain_callback_handler import (
@@ -9,6 +10,8 @@
from app.agents.base import AgentFactory
from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
+logger = logging.getLogger(__name__)
+
class WorkflowBase:
def __init__(
@@ -53,10 +56,14 @@ async def arun(self, input: Any):
)
if output_schema:
# TODO: throw error if output is not valid
- json_parser = SimpleJsonOutputParser()
- agent_response["output"] = json_parser.parse(
- text=agent_response["output"]
- )
+ parser = SimpleJsonOutputParser()
+ try:
+ agent_response["output"] = parser.parse(
+ text=agent_response["output"]
+ )
+ except Exception as e:
+ logger.error(f"Error parsing output: {e}")
+ agent_response["output"] = {}
previous_output = agent_response.get("output")
steps_output.append(agent_response)
diff --git a/libs/superagent/poetry.lock b/libs/superagent/poetry.lock
index a5f2389a3..9857cdffe 100644
--- a/libs/superagent/poetry.lock
+++ b/libs/superagent/poetry.lock
@@ -2311,13 +2311,13 @@ requests = ">=2,<3"
[[package]]
name = "litellm"
-version = "1.35.2"
+version = "1.35.21"
description = "Library to easily interface with LLM API providers"
optional = false
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
files = [
- {file = "litellm-1.35.2-py3-none-any.whl", hash = "sha256:686ee040154d7062b0078d882fa6399c5c7cc5ec9b5266490dee68f1b8905a36"},
- {file = "litellm-1.35.2.tar.gz", hash = "sha256:062e5be75196da7348ae0c4f60d396f0b23ee874708ed81c40f7675161213385"},
+ {file = "litellm-1.35.21-py3-none-any.whl", hash = "sha256:907230b7ff57c853e32d04274c2bb01f75e77d49220bd3d4d8fa02cfe6d3492a"},
+ {file = "litellm-1.35.21.tar.gz", hash = "sha256:be0f9452fa357996e194c88eebc94f742be2fa623afd137a91b1e60ce5c3821f"},
]
[package.dependencies]
@@ -6025,4 +6025,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.8.1, <3.12"
-content-hash = "c390f22730e24482e7f42cc8140d339a1025fb1e25021c6d886cd4dafab3a622"
+content-hash = "9049d2eda40cf7a7809de8eeac32efac0753e9006d50b8ef98aca4ef75f0e703"
diff --git a/libs/superagent/prisma/migrations/20240418181431_add_mistral/migration.sql b/libs/superagent/prisma/migrations/20240418181431_add_mistral/migration.sql
new file mode 100644
index 000000000..fa30cbda9
--- /dev/null
+++ b/libs/superagent/prisma/migrations/20240418181431_add_mistral/migration.sql
@@ -0,0 +1,2 @@
+-- AlterEnum
+ALTER TYPE "LLMProvider" ADD VALUE 'MISTRAL';
diff --git a/libs/superagent/prisma/migrations/20240418183001_add_groq/migration.sql b/libs/superagent/prisma/migrations/20240418183001_add_groq/migration.sql
new file mode 100644
index 000000000..62f59f9fc
--- /dev/null
+++ b/libs/superagent/prisma/migrations/20240418183001_add_groq/migration.sql
@@ -0,0 +1,2 @@
+-- AlterEnum
+ALTER TYPE "LLMProvider" ADD VALUE 'GROQ';
diff --git a/libs/superagent/prisma/migrations/20240420075553_add_cohere/migration.sql b/libs/superagent/prisma/migrations/20240420075553_add_cohere/migration.sql
new file mode 100644
index 000000000..3cb444fb4
--- /dev/null
+++ b/libs/superagent/prisma/migrations/20240420075553_add_cohere/migration.sql
@@ -0,0 +1,2 @@
+-- AlterEnum
+ALTER TYPE "LLMProvider" ADD VALUE 'COHERE_CHAT';
\ No newline at end of file
diff --git a/libs/superagent/prisma/schema.prisma b/libs/superagent/prisma/schema.prisma
index 6f03bca21..c3125752d 100644
--- a/libs/superagent/prisma/schema.prisma
+++ b/libs/superagent/prisma/schema.prisma
@@ -24,6 +24,9 @@ enum LLMProvider {
TOGETHER_AI
ANTHROPIC
BEDROCK
+ GROQ
+ MISTRAL
+ COHERE_CHAT
}
enum LLMModel {
diff --git a/libs/superagent/pyproject.toml b/libs/superagent/pyproject.toml
index 63c75df77..7854e3041 100644
--- a/libs/superagent/pyproject.toml
+++ b/libs/superagent/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "superagent"
-version = "0.2.29"
+version = "0.2.32"
description = "🥷 Run AI-agents with an API"
authors = ["Ismail Pelaseyed"]
readme = "../../README.md"
@@ -50,7 +50,7 @@ openai = "^1.1.1"
langchain-experimental = "^0.0.37"
pydub = "^0.25.1"
algoliasearch = "^3.0.0"
-litellm = "1.35.2"
+litellm = "1.35.21"
weaviate-client = "^3.25.3"
qdrant-client = "^1.6.9"
vecs = "^0.4.2"
diff --git a/libs/ui/app/integrations/client-page.tsx b/libs/ui/app/integrations/client-page.tsx
index bd6e3c976..e37568a7f 100644
--- a/libs/ui/app/integrations/client-page.tsx
+++ b/libs/ui/app/integrations/client-page.tsx
@@ -15,7 +15,7 @@ export default function IntegrationsClientPage({
configuredLLMs: any
}) {
return (
-
+
STORAGE
diff --git a/libs/ui/app/integrations/llm.tsx b/libs/ui/app/integrations/llm.tsx
index 8f2bf5ba3..c375763cb 100644
--- a/libs/ui/app/integrations/llm.tsx
+++ b/libs/ui/app/integrations/llm.tsx
@@ -54,6 +54,25 @@ const antrophicSchema = z.object({
apiKey: z.string().nonempty("API key is required"),
options: z.object({}),
})
+
+const groqSchema = z.object({
+ llmType: z.literal(LLMProvider.GROQ),
+ apiKey: z.string().nonempty("API key is required"),
+ options: z.object({}),
+})
+
+const mistralSchema = z.object({
+ llmType: z.literal(LLMProvider.MISTRAL),
+ apiKey: z.string().nonempty("API key is required"),
+ options: z.object({}),
+})
+
+const cohereSchema = z.object({
+ llmType: z.literal(LLMProvider.COHERE_CHAT),
+ apiKey: z.string().nonempty("API key is required"),
+ options: z.object({}),
+})
+
const amazonBedrockSchema = z.object({
llmType: z.literal(LLMProvider.BEDROCK),
apiKey: z.literal(""),
@@ -79,6 +98,9 @@ const formSchema = z.discriminatedUnion("llmType", [
perplexityAiSchema,
togetherAiSchema,
antrophicSchema,
+ groqSchema,
+ mistralSchema,
+ cohereSchema,
amazonBedrockSchema,
azureOpenAiSchema,
])
diff --git a/libs/ui/config/saml.ts b/libs/ui/config/saml.ts
index 9e92d2045..156c3298c 100644
--- a/libs/ui/config/saml.ts
+++ b/libs/ui/config/saml.ts
@@ -7,7 +7,7 @@ workflows:
name: Browser assistant
intro: |-
👋 Hi there! How can I help search for answers on the internet.
- prompt: Use the browser to answer any questions
+ prompt: Use the browser tool to answer any questions
tools:
- browser:
name: browser
@@ -22,7 +22,7 @@ workflows:
- superagent:
name: Browser assistant
llm: gpt-3.5-turbo-16k-0613
- prompt: Use the browser to answer all questions
+ prompt: Use the browser tool to answer all questions
intro: 👋 Hi there! How can I help you?
tools:
- browser:
diff --git a/libs/ui/config/site.ts b/libs/ui/config/site.ts
index 3d2088966..f65099579 100644
--- a/libs/ui/config/site.ts
+++ b/libs/ui/config/site.ts
@@ -500,6 +500,19 @@ export const siteConfig = {
},
],
},
+ {
+ disabled: false,
+ formDescription: "Please enter your Groq API key.",
+ provider: LLMProvider.GROQ,
+ name: "Groq",
+ metadata: [
+ {
+ key: "apiKey",
+ type: "input",
+ label: "Groq API Key",
+ },
+ ],
+ },
{
disabled: false,
formDescription: "Please enter your AWS credentials.",
@@ -523,6 +536,32 @@ export const siteConfig = {
},
],
},
+ {
+ disabled: false,
+ formDescription: "Please enter your Mistral API key.",
+ provider: LLMProvider.MISTRAL,
+ name: "Mistral",
+ metadata: [
+ {
+ key: "apiKey",
+ type: "input",
+ label: "Mistral API Key",
+ },
+ ],
+ },
+ {
+ disabled: false,
+ formDescription: "Please enter your Cohere API key.",
+ provider: LLMProvider.COHERE_CHAT,
+ name: "Cohere",
+ metadata: [
+ {
+ key: "apiKey",
+ type: "input",
+ label: "Cohere API Key",
+ },
+ ],
+ },
{
disabled: false,
formDescription: "Please enter your Azure OpenAI API key.",
diff --git a/libs/ui/models/models.ts b/libs/ui/models/models.ts
index 414c0758a..7e0cfc67f 100644
--- a/libs/ui/models/models.ts
+++ b/libs/ui/models/models.ts
@@ -4,6 +4,9 @@ export const LLMProvider = {
TOGETHER_AI: "TOGETHER_AI",
ANTHROPIC: "ANTHROPIC",
BEDROCK: "BEDROCK",
+ GROQ: "GROQ",
+ MISTRAL: "MISTRAL",
+ COHERE_CHAT: "COHERE_CHAT",
AZURE_OPENAI: "AZURE_OPENAI",
} as const