Skip to content

Commit

Permalink
qwen debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed Nov 14, 2024
1 parent 9a3ea5c commit f463abb
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 11 deletions.
4 changes: 2 additions & 2 deletions bench_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_engine(model_class: str, model_id: str, context_size: int = None):
if model_id == "Qwen/Qwen2.5-72B-Instruct":
model = VLLMEngine(
model_id="Qwen/Qwen2.5-72B-Instruct",
max_context_size=context_size,
max_context_size=context_size or 32768,
model_load_kwargs={
"tensor_parallel_size": 8,
# for more stability
Expand All @@ -109,7 +109,7 @@ def get_engine(model_class: str, model_id: str, context_size: int = None):
if model_id == "Qwen/Qwen2.5-7B-Instruct":
model = VLLMEngine(
model_id="Qwen/Qwen2.5-7B-Instruct",
max_context_size=context_size,
max_context_size=context_size or 32768,
model_load_kwargs={
"tensor_parallel_size": 8,
# for more stability
Expand Down
9 changes: 0 additions & 9 deletions experiments/fanoutqa/qwen/full/results.jsonl

This file was deleted.

155 changes: 155 additions & 0 deletions sandbox/debug_qwen.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-11-14T20:54:45.084276Z",
"start_time": "2024-11-14T20:54:45.080553Z"
}
},
"source": [
"import json\n",
"\n",
"from kani import ChatMessage\n",
"\n",
"msgs = json.loads(r\"\"\"[{\n",
" \"role\": \"user\",\n",
" \"content\": \"What are the names of 6 Metropolitan cities in Korea and their respective symbol flowers?\",\n",
" \"name\": null,\n",
" \"tool_call_id\": null,\n",
" \"tool_calls\": null,\n",
" \"is_tool_call_error\": null\n",
" },\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"\",\n",
" \"name\": null,\n",
" \"tool_call_id\": null,\n",
" \"tool_calls\": [\n",
" {\n",
" \"id\": \"e5cac7c8-d65a-4fc7-b271-504eeb2bd151\",\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"search\",\n",
" \"arguments\": \"{\\\"query\\\": \\\"6 Metropolitan cities in Korea and their symbol flowers\\\"}\"\n",
" }\n",
" }\n",
" ],\n",
" \"is_tool_call_error\": null\n",
" },\n",
" {\n",
" \"role\": \"function\",\n",
" \"content\": \"The function 'search' is not defined. Only use the provided functions.\",\n",
" \"name\": null,\n",
" \"tool_call_id\": \"e5cac7c8-d65a-4fc7-b271-504eeb2bd151\",\n",
" \"tool_calls\": null,\n",
" \"is_tool_call_error\": true\n",
" }]\"\"\")\n",
"msgs = [ChatMessage.model_validate(m) for m in msgs]"
],
"outputs": [],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-14T20:54:46.329461Z",
"start_time": "2024-11-14T20:54:46.325084Z"
}
},
"cell_type": "code",
"source": "msgs",
"id": "c0820f67f497fc56",
"outputs": [
{
"data": {
"text/plain": [
"[ChatMessage(role=<ChatRole.USER: 'user'>, content='What are the names of 6 Metropolitan cities in Korea and their respective symbol flowers?', name=None, tool_call_id=None, tool_calls=None, is_tool_call_error=None),\n",
" ChatMessage(role=<ChatRole.ASSISTANT: 'assistant'>, content='', name=None, tool_call_id=None, tool_calls=[ToolCall(id='e5cac7c8-d65a-4fc7-b271-504eeb2bd151', type='function', function=FunctionCall(name='search', arguments='{\"query\": \"6 Metropolitan cities in Korea and their symbol flowers\"}'))], is_tool_call_error=None),\n",
" ChatMessage(role=<ChatRole.FUNCTION: 'function'>, content=\"The function 'search' is not defined. Only use the provided functions.\", name=None, tool_call_id='e5cac7c8-d65a-4fc7-b271-504eeb2bd151', tool_calls=None, is_tool_call_error=True)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-14T20:55:32.590683Z",
"start_time": "2024-11-14T20:55:31.103450Z"
}
},
"cell_type": "code",
"source": [
"from kani.engines.huggingface.chat_template_pipeline import ChatTemplatePromptPipeline\n",
"from transformers import AutoTokenizer\n",
"\n",
"tok = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-72B-Instruct\")\n",
"pipeline = ChatTemplatePromptPipeline(tok)"
],
"id": "fbebba433cb13f1c",
"outputs": [],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-14T20:55:39.950601Z",
"start_time": "2024-11-14T20:55:39.932885Z"
}
},
"cell_type": "code",
"source": "pipeline(msgs)",
"id": "968e0e2c9d686aab",
"outputs": [
{
"data": {
"text/plain": [
"'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n<|im_start|>user\\nWhat are the names of 6 Metropolitan cities in Korea and their respective symbol flowers?<|im_end|>\\n<|im_start|>assistant\\n<tool_call>\\n{\"name\": \"search\", \"arguments\": {\"query\": \"6 Metropolitan cities in Korea and their symbol flowers\"}}\\n</tool_call><|im_end|>\\n<|im_start|>user\\n<tool_response>\\nThe function \\'search\\' is not defined. Only use the provided functions.\\n</tool_response><|im_end|>\\n<|im_start|>assistant\\n'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 5
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "",
"id": "9722a34151605ca2"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
21 changes: 21 additions & 0 deletions sandbox/debug_qwen_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

prompt = (
"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful"
" assistant.<|im_end|>\n<|im_start|>user\nWhat are the names of 6 Metropolitan cities in Korea and their respective"
' symbol flowers?<|im_end|>\n<|im_start|>assistant\n<tool_call>\n{"name": "search", "arguments": {"query": "6'
" Metropolitan cities in Korea and their symbol"
" flowers\"}}\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\nThe function 'search' is not defined."
" Only use the provided functions.\n</tool_response><|im_end|>\n<|im_start|>assistant\n"
)

model_name = "Qwen/Qwen2.5-72B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

generated_ids = model.generate(**model_inputs, max_new_tokens=512)
generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

0 comments on commit f463abb

Please sign in to comment.