From 1511e157a6c0d91db173f84414c54d6f58647121 Mon Sep 17 00:00:00 2001
From: alisalim17
Date: Tue, 30 Apr 2024 22:12:39 +0400
Subject: [PATCH] add memory to llm agent

---
 libs/superagent/app/agents/base.py      |   6 +
 libs/superagent/app/agents/langchain.py |   8 +-
 libs/superagent/app/agents/llm.py       | 213 +++++++++++++++++++-----
 3 files changed, 184 insertions(+), 43 deletions(-)

diff --git a/libs/superagent/app/agents/base.py b/libs/superagent/app/agents/base.py
index 305a9cb8a..b9e614d6e 100644
--- a/libs/superagent/app/agents/base.py
+++ b/libs/superagent/app/agents/base.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from functools import cached_property
 from typing import Any, List, Optional
 
 from langchain.agents import AgentExecutor
@@ -78,6 +79,11 @@ def prompt(self) -> Any:
     def tools(self) -> Any:
         ...
 
+    # TODO: Set a proper return type when we remove Langchain agent type
+    @cached_property
+    async def memory(self) -> Any:
+        ...
+
     @abstractmethod
     def get_agent(self) -> AgentExecutor:
         ...

diff --git a/libs/superagent/app/agents/langchain.py b/libs/superagent/app/agents/langchain.py
index c5fc39e95..63209af6a 100644
--- a/libs/superagent/app/agents/langchain.py
+++ b/libs/superagent/app/agents/langchain.py
@@ -1,4 +1,5 @@
 import datetime
+from functools import cached_property
 
 from decouple import config
 from langchain.agents import AgentType, initialize_agent
@@ -62,9 +63,8 @@ def _get_llm(self):
             max_tokens=llm_data.params.max_tokens,
         )
 
-    async def _get_memory(
-        self,
-    ) -> None | MotorheadMemory | ConversationBufferWindowMemory:
+    @cached_property
+    async def memory(self) -> None | MotorheadMemory | ConversationBufferWindowMemory:
         # if memory is already set, in the main agent base class, return it
         if not self.session_id:
             raise ValueError("Session ID is required to initialize memory")
@@ -95,7 +95,7 @@ async def _get_memory(
 
     async def get_agent(self):
         llm = self._get_llm()
-        memory = await self._get_memory()
+        memory = await self.memory
         tools = self.tools
         prompt = self.prompt
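
Note on the `@cached_property` + `async def` pattern introduced above: the property caches the coroutine object returned by the first attribute access, not its result, so it must be awaited exactly once per instance (as `get_agent` does with `await self.memory`). A minimal standalone sketch of that behavior, independent of the Superagent codebase:

# Sketch only: why an async cached_property can be awaited once per instance.
import asyncio
from functools import cached_property


class Demo:
    @cached_property
    async def memory(self):
        # stand-in for the real memory initialization
        return {"chat_history": []}


async def main():
    demo = Demo()
    print(await demo.memory)  # first await runs and consumes the cached coroutine
    try:
        await demo.memory     # same coroutine object, already awaited
    except RuntimeError as e:
        print(f"second await fails: {e}")  # "cannot reuse already awaited coroutine"


asyncio.run(main())
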
diff --git a/libs/superagent/app/agents/llm.py b/libs/superagent/app/agents/llm.py
index 476a3a8f6..fd1f10aa1 100644
--- a/libs/superagent/app/agents/llm.py
+++ b/libs/superagent/app/agents/llm.py
@@ -1,6 +1,9 @@
+import asyncio
 import datetime
 import json
 import logging
+from dataclasses import dataclass
+from functools import cached_property, partial
 from typing import Any
 
 from decouple import config
@@ -12,9 +15,14 @@
     get_llm_provider,
     get_supported_openai_params,
     stream_chunk_builder,
+    token_counter,
 )
 
 from app.agents.base import AgentBase
+from app.memory.base import BaseMessage
+from app.memory.buffer_memory import BufferMemory
+from app.memory.memory_stores.redis import RedisMemoryStore
+from app.memory.message import MessageType
 from app.tools import get_tools
 from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
 from app.utils.prisma import prisma
@@ -27,30 +35,19 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class ToolCallResponse:
+    action_log: AgentActionMessageLog
+    result: Any
+    return_direct: bool = False
+    success: bool = True
+
+
 async def call_tool(
     agent_data: Agent, session_id: str, function: Any
-) -> tuple[AgentActionMessageLog, Any]:
+) -> ToolCallResponse:
     name = function.get("name")
-    try:
-        args = json.loads(function.get("arguments"))
-    except Exception as e:
-        logger.error(f"Error parsing function arguments for {name}: {e}")
-        raise e
-
-    tools = get_tools(
-        agent_data=agent_data,
-        session_id=session_id,
-    )
-    tool_to_call = None
-    for tool in tools:
-        if tool.name == name:
-            tool_to_call = tool
-            break
-    if not tool_to_call:
-        raise Exception(f"Function {name} not found in tools")
-
-    logging.info(f"Calling tool {name} with arguments {args}")
-
+    args = function.get("arguments")
     action_log = AgentActionMessageLog(
         tool=name,
         tool_input=args,
@@ -67,15 +64,60 @@ async def call_tool(
             )
         ],
     )
-    tool_res = None
+
     try:
-        tool_res = await tool_to_call._arun(**args)
-        logging.info(f"Tool {name} returned {tool_res}")
+        args = json.loads(args)
     except Exception as e:
-        tool_res = f"Error calling {tool_to_call.name} tool with arguments {args}: {e}"
-        logging.error(f"Error calling tool {name}: {e}")
+        msg = f"Error parsing function arguments for {name}: {e}"
+        logger.error(msg)
+        return ToolCallResponse(
+            action_log=action_log,
+            result=msg,
+            return_direct=False,
+            success=False,
+        )
+
+    tools = get_tools(
+        agent_data=agent_data,
+        session_id=session_id,
+    )
+
+    logger.info(f"Calling tool {name} with arguments {args}")
+
+    tool_to_call = None
+    for tool in tools:
+        if tool.name == name:
+            tool_to_call = tool
+            break
+
+    if not tool_to_call:
+        msg = f"Function {name} not found in tools, available tool names: {', '.join([tool.name for tool in tools])}"
+        logger.error(msg)
+        return ToolCallResponse(
+            action_log=action_log,
+            result=msg,
+            return_direct=False,
+            success=False,
+        )
 
-    return (action_log, tool_res, tool_to_call.return_direct)
+    try:
+        result = await tool_to_call._arun(**args)
+        logger.info(f"Tool {name} returned {result}")
+        return ToolCallResponse(
+            action_log=action_log,
+            result=result,
+            return_direct=tool_to_call.return_direct,
+            success=True,
+        )
+    except Exception as e:
+        msg = f"Error calling {tool_to_call.name} tool with arguments {args}: {e}"
+        logger.error(msg)
+        return ToolCallResponse(
+            action_log=action_log,
+            result=msg,
+            return_direct=False,
+            success=False,
+        )
 
 
 class LLMAgent(AgentBase):
@@ -107,6 +149,22 @@ def tools(self):
             for tool in tools
         ]
 
+    @cached_property
+    def memory(self) -> BufferMemory:
+        redis_memory_store = RedisMemoryStore(
+            uri=config("REDIS_MEMORY_URL", "redis://localhost:6379/0"),
+            session_id=self.session_id,
+        )
+        tokenizer_fn = partial(token_counter, model=self.llm_data.model)
+
+        buffer_memory = BufferMemory(
+            memory_store=redis_memory_store,
+            model=self.llm_data.model,
+            tokenizer_fn=tokenizer_fn,
+        )
+
+        return buffer_memory
+
     @property
     def prompt(self):
         base_prompt = self.agent_data.prompt or DEFAULT_PROMPT
@@ -123,6 +181,15 @@ def prompt(self):
         else:
             prompt = base_prompt
 
+        messages = self.memory.get_messages(
+            inital_token_usage=len(prompt),
+        )
+        if len(messages) > 0:
+            prompt += "\n\nPrevious messages:\n"
+            for message in messages:
+                prompt += (
+                    f"""{message.type.value.capitalize()}: {message.content}\n\n"""
+                )
         return prompt
 
     @property
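
For reference, the history block appended by `prompt` renders stored messages as plain text. A small sketch of the rendering it produces; the `BaseMessage` and `MessageType` below are simplified stand-ins, since the real classes live in `app/memory` and are not part of this diff:

# Sketch of the history rendering added to the `prompt` property.
from dataclasses import dataclass
from enum import Enum


class MessageType(Enum):
    HUMAN = "human"
    AI = "ai"


@dataclass
class BaseMessage:
    type: MessageType
    content: str


prompt = "You are a helpful assistant."
messages = [
    BaseMessage(MessageType.HUMAN, "What is the capital of France?"),
    BaseMessage(MessageType.AI, "The capital of France is Paris."),
]

if len(messages) > 0:
    prompt += "\n\nPrevious messages:\n"
    for message in messages:
        prompt += f"{message.type.value.capitalize()}: {message.content}\n\n"

print(prompt)
# You are a helpful assistant.
#
# Previous messages:
# Human: What is the capital of France?
#
# Ai: The capital of France is Paris.
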
@@ -178,22 +245,23 @@ async def _execute_tool_calls(self, tool_calls: list[dict], **kwargs):
                 session_id=self.session_id,
                 function=tool_call.get("function"),
             )
-            (action_log, tool_res, return_direct) = intermediate_step
-            self.intermediate_steps.append((action_log, tool_res))
+            self.intermediate_steps.append(
+                (intermediate_step.action_log, intermediate_step.result)
+            )
 
             new_message = {
                 "role": "tool",
                 "name": tool_call.get("function").get("name"),
-                "content": tool_res,
+                "content": intermediate_step.result,
             }
             if tool_call.get("id"):
                 new_message["tool_call_id"] = tool_call.get("id")
             messages.append(new_message)
 
-            if return_direct:
+            if intermediate_step.return_direct:
                 if self.enable_streaming:
-                    await self._stream_by_lines(tool_res)
+                    await self._stream_by_lines(intermediate_step.result)
                     self.streaming_callback.done.set()
-                return tool_res
+                return intermediate_step.result
 
         self.messages = messages
         kwargs["messages"] = self.messages
@@ -278,7 +346,7 @@ async def _process_completion_response(self, res):
             if content:
                 output += content
                 if self._stream_directly:
-                    await self.streaming_callback.on_llm_new_token(content)
+                    await self._stream_by_lines(content)
 
         return (tool_calls, new_messages, output)
 
@@ -306,6 +374,27 @@ async def _acompletion(self, **kwargs) -> Any:
 
         tool_calls, new_messages, output = result
 
+        if output:
+            await self.memory.aadd_message(
+                message=BaseMessage(
+                    type=MessageType.AI,
+                    content=output,
+                )
+            )
+
+        if tool_calls:
+            await asyncio.gather(
+                *[
+                    self.memory.aadd_message(
+                        message=BaseMessage(
+                            type=MessageType.TOOL_CALL,
+                            content=json.dumps(tool_call),
+                        )
+                    )
+                    for tool_call in tool_calls
+                ]
+            )
+
         self.messages = new_messages
 
         if tool_calls:
@@ -334,6 +423,13 @@ async def ainvoke(self, input, *_, **kwargs):
             },
         ]
 
+        await self.memory.aadd_message(
+            message=BaseMessage(
+                type=MessageType.HUMAN,
+                content=self.input,
+            )
+        )
+
         if self.enable_streaming:
             self._set_streaming_callback(kwargs.get("config", {}).get("callbacks", []))
 
@@ -396,6 +492,13 @@ async def ainvoke(self, input, *_, **kwargs):
         self.input = input
         tool_results = []
 
+        await self.memory.aadd_message(
+            message=BaseMessage(
+                type=MessageType.HUMAN,
+                content=self.input,
+            )
+        )
+
         if self.enable_streaming:
             self._set_streaming_callback(kwargs.get("config", {}).get("callbacks", []))
 
@@ -423,22 +526,47 @@ async def ainvoke(self, input, *_, **kwargs):
             )
 
             tool_calls = res.choices[0].message.get("tool_calls", [])
+            await asyncio.gather(
+                *[
+                    self.memory.aadd_message(
+                        message=BaseMessage(
+                            type=MessageType.TOOL_CALL,
+                            content=json.dumps(tool_call),
+                        )
+                    )
+                    for tool_call in tool_calls
+                ]
+            )
+
             for tool_call in tool_calls:
-                (action_log, tool_res, return_direct) = await call_tool(
+                intermediate_step = await call_tool(
                     agent_data=self.agent_data,
                     session_id=self.session_id,
                     function=tool_call.function.dict(),
                 )
-                tool_results.append((action_log, tool_res))
-                if return_direct:
+
+                # TODO: handle the failure in tool call case
+                # if not intermediate_step.success:
+                #     self.memory.add_message(
+                #         message=BaseMessage(
+                #             type=MessageType.TOOL_RESULT,
+                #             content=intermediate_step.result,
+                #         )
+                #     )
+
+                tool_results.append(
+                    (intermediate_step.action_log, intermediate_step.result)
+                )
+
+                if intermediate_step.return_direct:
                     if self.enable_streaming:
-                        await self._stream_by_lines(tool_res)
+                        await self._stream_by_lines(intermediate_step.result)
                         self.streaming_callback.done.set()
 
                     return {
                         "intermediate_steps": tool_results,
                         "input": self.input,
-                        "output": tool_res,
+                        "output": intermediate_step.result,
                     }
 
         if len(tool_results) > 0:
@@ -473,6 +601,13 @@ async def ainvoke(self, input, *_, **kwargs):
         else:
             output = res.choices[0].message.content
 
+        await self.memory.aadd_message(
+            message=BaseMessage(
+                type=MessageType.AI,
+                content=output,
+            )
+        )
+
         return {
             "intermediate_steps": tool_results,
             "input": self.input,
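
One detail worth calling out in the two persistence fan-outs above: `asyncio.gather` takes awaitables as separate positional arguments, so the list comprehension has to be unpacked with `*` (awaiting `asyncio.gather([...])` on a bare list raises `TypeError`, since a list is not an awaitable). A self-contained sketch of the pattern, with a stand-in coroutine in place of `memory.aadd_message`:

# Sketch of the asyncio.gather fan-out used to persist tool-call messages.
import asyncio


async def aadd_message(content: str) -> str:
    await asyncio.sleep(0)  # stand-in for the Redis round trip
    return f"stored: {content}"


async def main():
    tool_calls = ["search(query='weather')", "calculator(a=1, b=2)"]
    results = await asyncio.gather(
        *[aadd_message(tc) for tc in tool_calls]  # note the `*` unpacking
    )
    print(results)


asyncio.run(main())
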