Skip to content

Commit

Permalink
Update handler in support of v3
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Sep 30, 2024
1 parent 9ba52e7 commit 25bb516
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,18 @@
"Microsoft",
"MIT",
],
"empower-dev/llama3-empower-functions-small-v1.1": [
"Empower-Fucntions-Small-v1.1 (FC)",
"https://huggingface.co/empower-dev/llama3-empower-functions-small-v1.1",
"Empower.dev",
"apache-2.0"
],
"empower-dev/llama3-empower-functions-large-v1.1": [
"Empower-Fucntions-Large-v1.1 (FC)",
"https://huggingface.co/empower-dev/llama3-empower-functions-large-v1.1",
"Empower.dev",
"apache-2.0"
]
}

INPUT_PRICE_PER_MILLION_TOKEN = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from bfcl.model_handler.oss_model.deepseek import DeepseekHandler
from bfcl.model_handler.oss_model.empower import EmpowerHandler
from bfcl.model_handler.oss_model.gemma import GemmaHandler
from bfcl.model_handler.oss_model.glaive import GlaiveHandler
from bfcl.model_handler.oss_model.glm import GLMHandler
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from bfcl.model_handler.oss_model.base_oss_handler import OSSHandler
from bfcl.model_handler.model_style import ModelStyle
import json
from bfcl.model_handler.utils import (
convert_to_tool,
)
from bfcl.model_handler.constant import (
GORILLA_TO_OPENAPI,
)


class EmpowerHandler(OSSHandler):
def __init__(self, model_name, temperature) -> None:
super().__init__(model_name, temperature)

def _preprocess_messages(self, messages):
# remove system message
messages = [
message for message in messages if message['role'] != "system"]

# combine tool responses
result = []
temp_tool_content = None
for message in messages:
if message['role'] == 'tool':
decoded_content = json.loads(message['content'])
if temp_tool_content:
temp_tool_content.append(decoded_content)
else:
temp_tool_content = [decoded_content]
else:
if temp_tool_content:
result.append({
'role': 'tool',
'content': json.dumps(temp_tool_content, indent=2)
})
temp_tool_content = None
result.append(message)
if temp_tool_content:
result.append({
'role': 'tool',
'content': json.dumps(temp_tool_content, indent=2)
})

return result

def _format_prompt(self, messages, functions):
formatted_prompt = "<|begin_of_text|>"

for idx, message in enumerate(self._preprocess_messages(messages)):
if idx == 0:
tools = convert_to_tool(
functions, GORILLA_TO_OPENAPI, ModelStyle.OSSMODEL
)
message['content'] = "In this environment you have access to a set of functions defined in the JSON format you can use to address user's requests, use them if needed.\nFunctions:\n" \
+ json.dumps(tools, indent=2) \
+ "\n\n" \
+ "User Message:\n" \
+ message['content']
else:
if message['role'] == 'tool':
message['role'] = 'user'
message['content'] = '<r>' + message['content']
elif message['role'] == 'user' and not message['content'].startswith('<r>') and not message['content'].startswith('<u>'):
message['content'] = '<u>' + message['content']

formatted_prompt += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}<|eot_id|>"

formatted_prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n"

return formatted_prompt

def decode_ast(self, result, language="Python"):
if not result.startswith('<f>'):
return []

# strip the function/conversation tag <f>/<c>
result_stripped = result[3:]

decoded_output = []
for invoked_function in json.loads(result_stripped):
name = invoked_function["name"]
params = invoked_function["arguments"] if "arguments" in invoked_function else {
}
decoded_output.append({name: params})

return decoded_output

def decode_execute(self, result):
execution_list = []

for function_call in self.decode_ast(result):
for key, value in function_call.items():
argument_list = []
for k, v in value.items():
argument_list.append(f'{k}={repr(v)}')
execution_list.append(
f"{key}({','.join(argument_list)})"
)

return execution_list

0 comments on commit 25bb516

Please sign in to comment.