Skip to content

Commit

Permalink
add memory to llm agent
Browse files Browse the repository at this point in the history
  • Loading branch information
elisalimli committed Apr 30, 2024
1 parent a863272 commit a7527d7
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 43 deletions.
6 changes: 6 additions & 0 deletions libs/superagent/app/agents/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, List, Optional

from langchain.agents import AgentExecutor
Expand Down Expand Up @@ -78,6 +79,11 @@ def prompt(self) -> Any:
def tools(self) -> Any:
...

# TODO: Set a proper return type when we remove Langchain agent type
# NOTE(review): `cached_property` stores whatever the first attribute access
# returns. Because this is an `async def`, that value is a coroutine object,
# not the awaited result, and a coroutine object can be awaited only once;
# a second `await self.memory` on the same instance raises RuntimeError.
# Confirm that every caller awaits it at most once per instance, or cache
# the resolved value instead.
@cached_property
async def memory(self) -> Any:
    """Return the agent's memory; the base implementation is an empty stub."""
    ...

@abstractmethod
def get_agent(self) -> AgentExecutor:
...
Expand Down
8 changes: 4 additions & 4 deletions libs/superagent/app/agents/langchain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
from functools import cached_property

from decouple import config
from langchain.agents import AgentType, initialize_agent
Expand Down Expand Up @@ -62,9 +63,8 @@ def _get_llm(self):
max_tokens=llm_data.params.max_tokens,
)

async def _get_memory(
self,
) -> None | MotorheadMemory | ConversationBufferWindowMemory:
@cached_property
async def memory(self) -> None | MotorheadMemory | ConversationBufferWindowMemory:
# if memory is already set, in the main agent base class, return it
if not self.session_id:
raise ValueError("Session ID is required to initialize memory")
Expand Down Expand Up @@ -95,7 +95,7 @@ async def _get_memory(

async def get_agent(self):
llm = self._get_llm()
memory = await self._get_memory()
memory = await self.memory
tools = self.tools
prompt = self.prompt

Expand Down
213 changes: 174 additions & 39 deletions libs/superagent/app/agents/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
import datetime
import json
import logging
from dataclasses import dataclass
from functools import cached_property, partial
from typing import Any

from decouple import config
Expand All @@ -12,9 +15,14 @@
get_llm_provider,
get_supported_openai_params,
stream_chunk_builder,
token_counter,
)

from app.agents.base import AgentBase
from app.memory.base import BaseMessage
from app.memory.buffer_memory import BufferMemory
from app.memory.memory_stores.redis import RedisMemoryStore
from app.memory.message import MessageType
from app.tools import get_tools
from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
from app.utils.prisma import prisma
Expand All @@ -27,30 +35,19 @@
logger = logging.getLogger(__name__)


@dataclass
class ToolCallResponse:
    """Outcome of a single tool invocation, as returned by `call_tool`."""

    # Log entry describing the requested call (tool name and raw arguments).
    action_log: AgentActionMessageLog
    # Value returned by the tool on success; on failure, the error message
    # string that `call_tool` built for the failing step.
    result: Any
    # Copied from the tool's own `return_direct` flag when the tool ran
    # successfully; `call_tool` sets it to False on every failure path.
    return_direct: bool = False
    # False when argument parsing failed, the tool name was not found, or the
    # tool raised while running; True otherwise.
    success: bool = True


async def call_tool(
agent_data: Agent, session_id: str, function: Any
) -> tuple[AgentActionMessageLog, Any]:
) -> ToolCallResponse:
name = function.get("name")
try:
args = json.loads(function.get("arguments"))
except Exception as e:
logger.error(f"Error parsing function arguments for {name}: {e}")
raise e

tools = get_tools(
agent_data=agent_data,
session_id=session_id,
)
tool_to_call = None
for tool in tools:
if tool.name == name:
tool_to_call = tool
break
if not tool_to_call:
raise Exception(f"Function {name} not found in tools")

logging.info(f"Calling tool {name} with arguments {args}")

args = function.get("arguments")
action_log = AgentActionMessageLog(
tool=name,
tool_input=args,
Expand All @@ -67,15 +64,60 @@ async def call_tool(
)
],
)
tool_res = None

try:
tool_res = await tool_to_call._arun(**args)
logging.info(f"Tool {name} returned {tool_res}")
args = json.loads(args)
except Exception as e:
tool_res = f"Error calling {tool_to_call.name} tool with arguments {args}: {e}"
logging.error(f"Error calling tool {name}: {e}")
msg = f"Error parsing function arguments for {name}: {e}"
logger.error(msg)
return ToolCallResponse(
action_log=action_log,
result=msg,
return_direct=False,
success=False,
)

tools = get_tools(
agent_data=agent_data,
session_id=session_id,
)

logging.info(f"Calling tool {name} with arguments {args}")

tool_to_call = None
for tool in tools:
if tool.name == name:
tool_to_call = tool
break

if not tool_to_call:
msg = f"Function {name} not found in tools, avaliable tool names: {', '.join([tool.name for tool in tools])}"
logger.error(msg)
return ToolCallResponse(
action_log=action_log,
result=msg,
return_direct=False,
success=False,
)

return (action_log, tool_res, tool_to_call.return_direct)
try:
result = await tool_to_call._arun(**args)
logging.info(f"Tool {name} returned {result}")
return ToolCallResponse(
action_log=action_log,
result=result,
return_direct=tool_to_call.return_direct,
success=True,
)
except Exception as e:
msg = f"Error calling {tool_to_call.name} tool with arguments {args}: {e}"
logger.error(msg)
return ToolCallResponse(
action_log=action_log,
result=msg,
return_direct=False,
success=False,
)


class LLMAgent(AgentBase):
Expand Down Expand Up @@ -107,6 +149,22 @@ def tools(self):
for tool in tools
]

@cached_property
def memory(self) -> BufferMemory:
    """Build the buffer memory for this agent's session.

    The memory is backed by a Redis store whose URI comes from the
    ``REDIS_MEMORY_URL`` setting (default ``redis://localhost:6379/0``) and
    which is keyed by ``self.session_id``. Token counting is delegated to
    ``token_counter`` bound to the agent's configured model. Being a
    ``cached_property``, the object is created on first access and reused
    for the lifetime of the instance.
    """
    # The store is created first, matching the original construction order.
    store = RedisMemoryStore(
        uri=config("REDIS_MEMORY_URL", "redis://localhost:6379/0"),
        session_id=self.session_id,
    )
    model_name = self.llm_data.model
    return BufferMemory(
        memory_store=store,
        model=model_name,
        tokenizer_fn=partial(token_counter, model=model_name),
    )

@property
def prompt(self):
base_prompt = self.agent_data.prompt or DEFAULT_PROMPT
Expand All @@ -123,6 +181,15 @@ def prompt(self):
else:
prompt = base_prompt

messages = self.memory.get_messages(
inital_token_usage=len(prompt),
)
if len(messages) > 0:
prompt += "\n\n Previous messages: \n"
for message in messages:
prompt += (
f"""{message.type.value.capitalize()}: {message.content}\n\n"""
)
return prompt

@property
Expand Down Expand Up @@ -178,22 +245,23 @@ async def _execute_tool_calls(self, tool_calls: list[dict], **kwargs):
session_id=self.session_id,
function=tool_call.get("function"),
)
(action_log, tool_res, return_direct) = intermediate_step
self.intermediate_steps.append((action_log, tool_res))
self.intermediate_steps.append(
(intermediate_step.action_log, intermediate_step.result)
)
new_message = {
"role": "tool",
"name": tool_call.get("function").get("name"),
"content": tool_res,
"content": intermediate_step.result,
}
if tool_call.get("id"):
new_message["tool_call_id"] = tool_call.get("id")

messages.append(new_message)
if return_direct:
if intermediate_step.return_direct:
if self.enable_streaming:
await self._stream_by_lines(tool_res)
await self._stream_by_lines(intermediate_step.result)
self.streaming_callback.done.set()
return tool_res
return intermediate_step.result

self.messages = messages
kwargs["messages"] = self.messages
Expand Down Expand Up @@ -278,7 +346,7 @@ async def _process_completion_response(self, res):
if content:
output += content
if self._stream_directly:
await self.streaming_callback.on_llm_new_token(content)
await self._stream_by_lines(content)

return (tool_calls, new_messages, output)

Expand Down Expand Up @@ -306,6 +374,27 @@ async def _acompletion(self, **kwargs) -> Any:

tool_calls, new_messages, output = result

if output:
await self.memory.aadd_message(
message=BaseMessage(
type=MessageType.AI,
content=output,
)
)

if tool_calls:
await asyncio.gather(
*[
self.memory.aadd_message(
message=BaseMessage(
type=MessageType.TOOL_CALL,
content=json.dumps(tool_call),
)
)
for tool_call in tool_calls
]
)

self.messages = new_messages

if tool_calls:
Expand Down Expand Up @@ -334,6 +423,13 @@ async def ainvoke(self, input, *_, **kwargs):
},
]

await self.memory.aadd_message(
message=BaseMessage(
type=MessageType.HUMAN,
content=self.input,
)
)

if self.enable_streaming:
self._set_streaming_callback(kwargs.get("config", {}).get("callbacks", []))

Expand Down Expand Up @@ -396,6 +492,13 @@ async def ainvoke(self, input, *_, **kwargs):
self.input = input
tool_results = []

await self.memory.aadd_message(
message=BaseMessage(
type=MessageType.HUMAN,
content=self.input,
)
)

if self.enable_streaming:
self._set_streaming_callback(kwargs.get("config", {}).get("callbacks", []))

Expand Down Expand Up @@ -423,22 +526,47 @@ async def ainvoke(self, input, *_, **kwargs):
)

tool_calls = res.choices[0].message.get("tool_calls", [])
await asyncio.gather(
*[
self.memory.aadd_message(
message=BaseMessage(
type=MessageType.TOOL_CALL,
content=json.dumps(tool_call),
)
)
for tool_call in tool_calls
]
)

for tool_call in tool_calls:
(action_log, tool_res, return_direct) = await call_tool(
intermediate_step = await call_tool(
agent_data=self.agent_data,
session_id=self.session_id,
function=tool_call.function.dict(),
)
tool_results.append((action_log, tool_res))
if return_direct:

# TODO: handle the failure in tool call case
# if not intermediate_step.success:
# self.memory.add_message(
# message=BaseMessage(
# type=MessageType.TOOL_RESULT,
# content=intermediate_step.result,
# )
# )

tool_results.append(
(intermediate_step.action_log, intermediate_step.result)
)

if intermediate_step.return_direct:
if self.enable_streaming:
await self._stream_by_lines(tool_res)
await self._stream_by_lines(intermediate_step.result)
self.streaming_callback.done.set()

return {
"intermediate_steps": tool_results,
"input": self.input,
"output": tool_res,
"output": intermediate_step.result,
}

if len(tool_results) > 0:
Expand Down Expand Up @@ -473,6 +601,13 @@ async def ainvoke(self, input, *_, **kwargs):
else:
output = res.choices[0].message.content

await self.memory.aadd_message(
message=BaseMessage(
type=MessageType.AI,
content=output,
)
)

return {
"intermediate_steps": tool_results,
"input": self.input,
Expand Down

0 comments on commit a7527d7

Please sign in to comment.