-
Notifications
You must be signed in to change notification settings - Fork 99
/
Copy pathpredict.py
115 lines (100 loc) · 4.36 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# predict.py
import os
import json
from tqdm import tqdm
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
from peft import PeftModel
import argparse
# Command-line interface.
# Arguments without which the script cannot run are marked required, so argparse
# reports them up front instead of the script failing later on a None value
# (e.g. `None.lower()` or `open(None)`).
parser = argparse.ArgumentParser()
parser.add_argument('--llm_ckp', type=str, required=True, help='checkpoint of LLM')
parser.add_argument('--lora_path', type=str, default=None, help='lora adapters path')
parser.add_argument('--data_path', type=str, required=True, help='data to predict, should be json-lines format')
parser.add_argument('--prompt_key', type=str, required=True, help='the key of prompts in the data file')
parser.add_argument('--target_key', type=str, required=True, help='the key of targets/labels in the data file')
# Default batch size so the batching loop never receives None.
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
args = parser.parse_args()
# Load the base causal LM with device_map="auto" and cast it to fp16.
model = AutoModelForCausalLM.from_pretrained(args.llm_ckp, trust_remote_code=True, device_map="auto").half()
# Checkpoints whose path mentions llama/alpaca use LlamaTokenizer;
# everything else goes through AutoTokenizer.
ckp_name = args.llm_ckp.lower()
if 'llama' in ckp_name or 'alpaca' in ckp_name:
    tokenizer = LlamaTokenizer.from_pretrained(args.llm_ckp, trust_remote_code=True)
else:
    tokenizer = AutoTokenizer.from_pretrained(args.llm_ckp, trust_remote_code=True)
# Batched prediction tokenizes with padding=True, which fails when the tokenizer
# defines no pad token (commonly the case for LLaMA-family tokenizers).
# Fall back to an existing special token. The attention mask returned by the
# tokenizer marks padded positions, so the choice of pad id does not leak into
# the prompt.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token or tokenizer.eos_token
# Optionally wrap the base model with LoRA adapters (kept in fp16 as well).
if args.lora_path:
    print("Using LoRA!")
    model = PeftModel.from_pretrained(model, args.lora_path).half()
else:
    print("Using Original LLM!")
# Read the JSON-lines data file: one JSON object per non-blank line. From each
# object the prompt and the target are taken via the keys given on the command
# line; the two lists stay index-aligned.
prompts, targets = [], []
# Explicit utf8 so non-ASCII data does not depend on the platform locale
# (the predictions file is written as utf8 too).
with open(args.data_path, 'r', encoding='utf8') as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines, e.g. extra newlines at end of file
        d = json.loads(line)
        prompts.append(d[args.prompt_key])
        targets.append(d[args.target_key])
def predict(prompts, max_length=1024, max_new_tokens=200, device='cuda:0', **generate_kwargs):
    """Generate one completion per prompt using the module-level model and tokenizer.

    Args:
        prompts: a single prompt string or a list of prompt strings.
        max_length: each prompt is truncated to at most this many tokens.
        max_new_tokens: cap on generated tokens per prompt. The expected output
            format is short, so longer generations are not needed.
        device: device the tokenized inputs are moved to before generation.
        **generate_kwargs: extra keyword arguments forwarded to
            ``model.generate``; they override the defaults set here. Example
            settings that were kept for InternLM checkpoints:
            ``temperature=0.8, top_p=0.8, eos_token_id=(2, 103028)``.

    Returns:
        A list of decoded completions, one per input prompt and in input order.
        Prompt tokens are removed and special tokens are skipped.

    Raises:
        TypeError: if ``prompts`` is neither a str nor a list.
    """
    if isinstance(prompts, str):
        prompts = [prompts]
    if not isinstance(prompts, list):
        raise TypeError('input should be list of text')
    if not prompts:
        return []
    # Pad on the left so every prompt ends at the same position and generation
    # continues directly from the last real prompt token.
    tokenizer.padding_side = 'left'
    input_tensors = tokenizer(prompts, max_length=max_length, padding=True,
                              truncation=True, return_tensors='pt')
    # With left padding all rows share this width, so a single slice index
    # removes the prompt part from every generated sequence.
    prompt_length = input_tensors.input_ids.shape[1]
    input_tensors = input_tensors.to(device)
    gen_kwargs = {'max_new_tokens': max_new_tokens, 'repetition_penalty': 1.1}
    gen_kwargs.update(generate_kwargs)
    outputs = model.generate(**input_tensors, **gen_kwargs)
    # Keep only the newly generated tokens, dropping the echoed prompt.
    return tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)
# Batch prediction: run `predict` over the prompts in chunks of `bs`, collect the
# outputs in input order, and count predictions that decoded to an empty string.
bs = args.batch_size
if bs is None or bs < 1:
    raise ValueError(f'--batch_size must be a positive integer, got {bs!r}')
predicted_results = []
empty_number = 0
total_number = 0
# Chunk starts 0, bs, 2*bs, ...; the final chunk may be shorter but is never empty.
for start in tqdm(range(0, len(prompts), bs)):
    batch = prompts[start:start + bs]
    batch_results = predict(batch)
    predicted_results.extend(batch_results)
    # Print the first two prompt/prediction pairs of each batch as a quick sanity check.
    for prompt, each in zip(batch[:2], batch_results[:2]):
        print('\n*****prompt******')
        print(prompt)
        print(' ===prediction===>')
        print(each)
    empty_number += sum(1 for res in batch_results if res == '')
    total_number += len(batch_results)
print('---------------------------------------------')
print(f'empty number: {empty_number}/{total_number}')
print('---------------------------------------------')
# Write predictions as JSON Lines (one object per line, despite the .json suffix)
# to data/eval/<model-or-adapter-name>-<data-file-name>_predictions.json.
# normpath drops a trailing '/', so a directory argument like 'out/lora/' still
# yields 'lora' instead of an empty name component.
if args.lora_path:
    name1 = os.path.basename(os.path.normpath(args.lora_path))
else:
    name1 = os.path.basename(os.path.normpath(args.llm_ckp))
name2 = os.path.basename(args.data_path)
os.makedirs('data/eval', exist_ok=True)
# Build the path once so the file written and the path reported cannot diverge.
output_path = f'data/eval/{name1}-{name2}_predictions.json'
with open(output_path, 'w', encoding='utf8') as f:
    for prompt, target, prediction in zip(prompts, targets, predicted_results):
        record = {
            'prompt': prompt,
            'target': target,
            'prediction': prediction,
        }
        f.write(json.dumps(record, ensure_ascii=False))
        f.write('\n')
print(f'prediction file saved at [`{output_path}`]')