Skip to content

Commit

Permalink
Merge pull request #94 from PathOnAI/tool_refactor
Browse files Browse the repository at this point in the history
1) add ToolRegistry, 2) only pass tools to function calling based age…
  • Loading branch information
IBMC265 authored Oct 1, 2024
2 parents feaf48d + a2274fa commit 87987d9
Show file tree
Hide file tree
Showing 14 changed files with 205 additions and 129 deletions.
26 changes: 13 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,29 @@ cp .env.example .env
```

### (2) QuickStart
* use web agent to finish some task and save the workflow
* use prompting-based web agent to finish some task and save the workflow
```bash
python -m main --agent_type FunctionCallingAgent --starting_url https://www.google.com --goal 'search dining table' --plan 'search dining table' --log_folder log
python -m main --agent_type PromptAgent --starting_url https://www.google.com --goal 'search dining table' --plan 'search dining table' --log_folder log
python -m main --agent_type HighLevelPlanningAgent --starting_url https://www.airbnb.com --goal "set destination as San Francisco, then search the results" --plan "(1) enter the 'San Francisco' as destination, (2) and click search" --log_folder log
python -m main --agent_type ContextAwarePlanningAgent --starting_url https://www.google.com --goal 'search dining table' --plan 'search dining table' --log_folder log
python -m main --agent_type FunctionCallingAgent --starting_url https://www.google.com --goal 'Find the pdf of the paper "GPT-4V(ision) is a Generalist Web Agent, if Grounded"' --plan 'Find the pdf of the paper "GPT-4V(ision) is a Generalist Web Agent, if Grounded"' --log_folder log
python -m main --agent_type HighLevelPlanningAgent --starting_url https://www.google.com --goal 'Find the pdf of the paper "GPT-4V(ision) is a Generalist Web Agent, if Grounded"' --plan 'Find the pdf of the paper "GPT-4V(ision) is a Generalist Web Agent, if Grounded"' --log_folder log
python -m main --agent_type ContextAwarePlanningAgent --starting_url https://www.google.com --goal 'Find the pdf of the paper "GPT-4V(ision) is a Generalist Web Agent, if Grounded"' --plan 'Find the pdf of the paper "GPT-4V(ision) is a Generalist Web Agent, if Grounded"' --log_folder log
python -m main --agent_type FunctionCallingAgent --starting_url https://www.google.com --goal 'Find the pdf of the paper "GPT-4V(ision) is a Generalist Web Agent, if Grounded"' --plan 'Find the pdf of the paper "GPT-4V(ision) is a Generalist Web Agent, if Grounded"' --log_folder log
python -m prompting_main --agent_type PromptAgent --starting_url https://www.google.com --goal 'search dining table' --plan 'search dining table' --log_folder log
```
* we also provide function-calling-based web agent
```bash
python -m function_calling_main --agent_type FunctionCallingAgent --starting_url https://www.google.com --goal 'search dining table' --plan 'search dining table' --log_folder log
python -m function_calling_main --agent_type HighLevelPlanningAgent --starting_url https://www.google.com --goal 'search dining table' --plan 'search dining table' --log_folder log
python -m function_calling_main --agent_type ContextAwarePlanningAgent --starting_url https://www.google.com --goal 'search dining table' --plan 'search dining table' --log_folder log
```
https://www.loom.com/share/1018bcc4e21c4a7eb517b60c2931ee3c
https://www.loom.com/share/aa48256478714d098faac740239c9013
https://www.loom.com/share/89f5fa69b8cb49c8b6a60368ddcba103


* replay the workflow verified by the web agent
If you haven't used the web agent to try any tests yet, first copy our example.json file.
```bash
cp log/flow/example.json log/flow/steps.json
```
then you can replay the session
```bash
python litewebagent/utils/replay.py --log_folder log
python litewebagent/action/replay.py --log_folder log
```
* enable user agent interaction

Expand All @@ -70,9 +70,9 @@ https://www.loom.com/share/93e3490a6d684cddb0fbefce4813902a
### (3) test different input features
We use axtree by default. Alternatively, you can provide a comma-separated string listing the desired input feature types.
```bash
python -m main --agent_type FunctionCallingAgent --starting_url https://www.airbnb.com --goal 'set destination as San Francisco, then search the results' --plan '(1) enter the "San Francisco" as destination, (2) and click search' --log_folder log
python -m main --agent_type FunctionCallingAgent --starting_url https://www.airbnb.com --goal 'set destination as San Francisco, then search the results' --plan '(1) enter the "San Francisco" as destination, (2) and click search' --features interactive_elements --log_folder log
python -m main --agent_type FunctionCallingAgent --starting_url https://www.airbnb.com --goal 'set destination as San Francisco, then search the results' --plan '(1) enter the "San Francisco" as destination, (2) and click search' --features axtree,interactive_elements --log_folder log
python -m function_calling_main --agent_type FunctionCallingAgent --starting_url https://www.airbnb.com --goal 'set destination as San Francisco, then search the results' --plan '(1) enter the "San Francisco" as destination, (2) and click search' --log_folder log
python -m function_calling_main --agent_type FunctionCallingAgent --starting_url https://www.airbnb.com --goal 'set destination as San Francisco, then search the results' --plan '(1) enter the "San Francisco" as destination, (2) and click search' --features interactive_elements --log_folder log
python -m function_calling_main --agent_type FunctionCallingAgent --starting_url https://www.airbnb.com --goal 'set destination as San Francisco, then search the results' --plan '(1) enter the "San Francisco" as destination, (2) and click search' --features axtree,interactive_elements --log_folder log
```

### (4) search_agent
Expand Down
6 changes: 3 additions & 3 deletions examples/google_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dotenv import load_dotenv

_ = load_dotenv()
from litewebagent.agents.webagent import setup_web_agent
from litewebagent.agents.webagent import setup_function_calling_web_agent

agent_type = "FunctionCallingAgent"
starting_url = "https://www.google.com"
Expand All @@ -13,8 +13,8 @@
branching_factor = None
storage_state = 'state.json'

agent = setup_web_agent(starting_url, goal, model_name=model, agent_type=agent_type, features=features,
branching_factor=branching_factor, log_folder=log_folder, storage_state=storage_state)
agent = setup_function_calling_web_agent(starting_url, goal, model_name=model, agent_type=agent_type, features=features,
tool_names = ["navigation", "select_option", "upload_file"], branching_factor=branching_factor, log_folder=log_folder, storage_state=storage_state)
response = agent.send_prompt(plan)
print(response)
print(agent.messages)
45 changes: 45 additions & 0 deletions function_calling_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from dotenv import load_dotenv
import argparse

_ = load_dotenv()
from litewebagent.agents.webagent import setup_function_calling_web_agent

def main(args):
# Use the features from command-line arguments
features = args.features.split(',') if args.features else None
branching_factor = args.branching_factor if args.branching_factor else None

# Use the tool_names from command-line arguments
tool_names = args.tool_names.split(',') if args.tool_names else ["navigation", "select_option", "upload_file"]

agent = setup_function_calling_web_agent(args.starting_url, args.goal, model_name=args.model, agent_type=args.agent_type,
features=features, tool_names=tool_names, branching_factor=branching_factor, log_folder=args.log_folder,
storage_state=args.storage_state)

response = agent.send_prompt(args.plan)
print(response)
print(agent.messages)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run web automation tasks with different agent types.")
parser.add_argument('--agent_type', type=str, default="FunctionCallingAgent",
choices=["FunctionCallingAgent", "HighLevelPlanningAgent", "ContextAwarePlanningAgent"],
help="Type of agent to use (default: FunctionCallingAgent)")
parser.add_argument('--model', type=str, default="gpt-4o-mini",
help="Model to use for the agent (default: gpt-4o-mini)")
parser.add_argument('--starting_url', type=str, required=True,
help="Starting URL for the web automation task")
parser.add_argument('--plan', type=str, required=True,
help="Plan for the web automation task")
parser.add_argument('--goal', type=str, required=True,
help="Goal for the web automation task")
parser.add_argument('--storage_state', type=str, default="state.json",
help="Storage state json file")
parser.add_argument('--features', type=str, default="axtree",
help="Comma-separated list of features to use (default: axtree)")
parser.add_argument('--tool_names', type=str, default="navigation,select_option,upload_file",
help="Comma-separated list of tool names to use (default: navigation,select_option,upload_file)")
parser.add_argument('--branching_factor', type=int, default=None)
parser.add_argument('--log_folder', type=str, default='log', help='Path to the log folder')
args = parser.parse_args()
main(args)
5 changes: 2 additions & 3 deletions litewebagent/action/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,19 @@
extract_focused_element_bid,
)
from litewebagent.browser_env.extract_elements import extract_interactive_elements
from dotenv import load_dotenv
from openai import OpenAI
import os
import re
import json
from litewebagent.utils.utils import encode_image

from dotenv import load_dotenv
_ = load_dotenv()
from elevenlabs.client import ElevenLabs
from elevenlabs import play

# Initialize the Eleven Labs client
elevenlabs_client = ElevenLabs(api_key=os.getenv("ELEVEN_API_KEY"))
openai_client = OpenAI()
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
import argparse
from litewebagent.action.highlevel import HighLevelActionSet
from litewebagent.utils.playwright_manager import PlaywrightManager
Expand Down
4 changes: 1 addition & 3 deletions litewebagent/agents/PromptAgents/PromptAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@ def send_prompt(self, plan: str) -> Dict:
summary = response.choices[0].message.content
return summary

def __init__(self, model_name, tools, available_tools, messages, goal, playwright_manager, log_folder):
def __init__(self, model_name, messages, goal, playwright_manager, log_folder):
self.model_name = model_name
self.tools = tools
self.available_tools = available_tools
self.messages = messages
self.goal = goal
self.playwright_manager = playwright_manager
Expand Down
4 changes: 0 additions & 4 deletions litewebagent/agents/SearchAgents/PromptSearchAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,13 @@ def __init__(
self,
starting_url: str,
model_name: str,
tools: List[str],
available_tools: List[str],
messages: List[Dict[str, Any]],
goal: str,
playwright_manager: PlaywrightManager,
log_folder
):
self.model_name = model_name
self.starting_url = starting_url
self.tools = tools
self.available_tools = available_tools
self.messages = messages
self.goal = goal
self.playwright_manager = playwright_manager
Expand Down
137 changes: 61 additions & 76 deletions litewebagent/agents/webagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from litewebagent.agents.FunctionCallingAgents.ContextAwarePlanningAgent import ContextAwarePlanningAgent
from litewebagent.agents.SearchAgents.PromptSearchAgent import PromptSearchAgent
from litewebagent.agents.PromptAgents.PromptAgent import PromptAgent
from litewebagent.tools.fc_functions import navigation, upload_file, select_option
from litewebagent.utils.utils import setup_logger
from litewebagent.utils.playwright_manager import setup_playwright
from litewebagent.tools.registry import ToolRegistry

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

Expand All @@ -30,81 +30,21 @@ def wrapper(task_description):
return wrapper


tools = [
{
"type": "function",
"function": {
"name": "navigation",
"description": "Perform a web navigation task, including click, search ",
"parameters": {
"type": "object",
"properties": {
"task_description": {
"type": "string",
"description": "The description of the web navigation task"
},
},
"required": [
"task_description",
]
}
}
},
{
"type": "function",
"function": {
"name": "upload_file",
"description": "upload file.",
"parameters": {
"type": "object",
"properties": {
"task_description": {
"type": "string",
"description": "The description of the web navigation task"
},
},
"required": [
"task_description",
]
}
}
},
{
"type": "function",
"function": {
"name": "select_option",
"description": "Select an option from a dropdown or list.",
"parameters": {
"type": "object",
"properties": {
"task_description": {
"type": "string",
"description": "The description of the option selection task"
}
},
"required": [
"task_description"
]
}
}
}
]


def setup_web_agent(starting_url, goal, model_name="gpt-4o-mini", agent_type="DemoAgent", features=['axtree'],
branching_factor=None, log_folder="log", storage_state='state.json'):
def setup_function_calling_web_agent(starting_url, goal, model_name="gpt-4o-mini", agent_type="DemoAgent",
features=['axtree'], tool_names = ["navigation", "select_option", "upload_file"],
branching_factor=None, log_folder="log", storage_state='state.json'):
logger = setup_logger(log_folder)
playwright_manager = setup_playwright(log_folder=log_folder, storage_state=storage_state)

if features is None:
features = DEFAULT_FEATURES

available_tools = {
"navigation": create_function_wrapper(navigation, features, branching_factor, playwright_manager, log_folder),
"upload_file": create_function_wrapper(upload_file, features, branching_factor, playwright_manager, log_folder),
"select_option": create_function_wrapper(select_option, features, branching_factor, playwright_manager,
log_folder),
}
tool_registry = ToolRegistry()
available_tools = {}
tools = []
for tool_name in tool_names:
available_tools[tool_name] = create_function_wrapper(tool_registry.get_tool(tool_name).func, features,
branching_factor, playwright_manager, log_folder)
tools.append(tool_registry.get_tool_description(tool_name))

messages = [
{
Expand Down Expand Up @@ -142,9 +82,6 @@ def setup_web_agent(starting_url, goal, model_name="gpt-4o-mini", agent_type="De
agent = FunctionCallingAgent(model_name=model_name, tools=tools, available_tools=available_tools,
messages=messages,
goal=goal, playwright_manager=playwright_manager, log_folder=log_folder)
elif agent_type == "PromptAgent":
agent = PromptAgent(model_name=model_name, tools=tools, available_tools=available_tools,
messages=messages, goal=goal, playwright_manager=playwright_manager, log_folder=log_folder)
elif agent_type == "HighLevelPlanningAgent":
agent = HighLevelPlanningAgent(model_name=model_name, tools=tools, available_tools=available_tools,
messages=messages, goal=goal, playwright_manager=playwright_manager,
Expand All @@ -160,6 +97,55 @@ def setup_web_agent(starting_url, goal, model_name="gpt-4o-mini", agent_type="De
return agent


def setup_prompting_web_agent(starting_url, goal, model_name="gpt-4o-mini", agent_type="DemoAgent", features=['axtree'],
branching_factor=None, log_folder="log", storage_state='state.json'):
logger = setup_logger(log_folder)
playwright_manager = setup_playwright(log_folder=log_folder, storage_state=storage_state)
if features is None:
features = DEFAULT_FEATURES

messages = [
{
"role": "system",
"content": """You are a web search agent designed to perform specific tasks on web pages as instructed by the user. Your primary objectives are:
1. Execute ONLY the task explicitly provided by the user.
2. Perform the task efficiently and accurately using the available functions.
3. If there are errors, retry using a different approach within the scope of the given task.
4. Once the current task is completed, stop and wait for further instructions.
Critical guidelines:
- Strictly limit your actions to the current task. Do not attempt additional tasks or next steps.
- Use only the functions provided to you. Do not attempt to use functions or methods that are not explicitly available.
- For navigation or interaction with page elements, always use the appropriate bid (browser element ID) when required by a function.
- Do not try to navigate to external websites or use URLs directly.
- If a task cannot be completed with the available functions, report the limitation rather than attempting unsupported actions.
- After completing a task, report its completion and await new instructions. Do not suggest or initiate further actions.
Remember: Your role is to execute the given task precisely as instructed, using only the provided functions and within the confines of the current web page. Do not exceed these boundaries under any circumstances."""
}
]
file_path = os.path.join(log_folder, 'flow', 'steps.json')
os.makedirs(os.path.dirname(file_path), exist_ok=True)
page = playwright_manager.get_page()
page.goto(starting_url)
# Maximize the window on macOS
page.set_viewport_size({"width": 1440, "height": 900})

with open(file_path, 'w') as file:
file.write(goal + '\n')
file.write(starting_url + '\n')

if agent_type == "PromptAgent":
agent = PromptAgent(model_name=model_name,
messages=messages, goal=goal, playwright_manager=playwright_manager, log_folder=log_folder)
else:
error_message = f"Unsupported agent type: {agent_type}. Please use 'FunctionCallingAgent', 'HighLevelPlanningAgent', 'ContextAwarePlanningAgent', 'PromptAgent' or 'PromptSearchAgent' ."
logger.error(error_message)
return {"error": error_message}
return agent


def setup_search_agent(starting_url, goal, model_name="gpt-4o-mini", agent_type="PromptSearchAgent",
features=['axtree'], branching_factor=None, log_folder="log", storage_state='state.json'):
logger = setup_logger(log_folder)
Expand Down Expand Up @@ -207,8 +193,7 @@ def setup_search_agent(starting_url, goal, model_name="gpt-4o-mini", agent_type=
file.write(starting_url + '\n')

if agent_type == "PromptSearchAgent":
agent = PromptSearchAgent(starting_url=starting_url, model_name=model_name, tools=tools,
available_tools=available_tools,
agent = PromptSearchAgent(starting_url=starting_url, model_name=model_name,
messages=messages, goal=goal, playwright_manager=playwright_manager,
log_folder=log_folder)
else:
Expand Down
Loading

0 comments on commit 87987d9

Please sign in to comment.