evaluation.py
import string
import json
import argparse
# from rouge_scorer import rouge_scorer
from rouge_score import rouge_scorer
from transformers import AutoTokenizer
class GPTTokenizer:
    gpt_tokenizer = AutoTokenizer.from_pretrained("gpt2", max_length=1e5)

    def tokenize(self, s):
        tokens = self.gpt_tokenizer.tokenize(s)
        # GPT2 uses byte-level BPE, which includes the space as part of the word.
        # But for the first word of a sentence, there is no space before it.
        # So, we remove all the added space markers ("Ġ").
        tokens = [t.lstrip("Ġ") for t in tokens]
        return tokens
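# Illustration only (not used by the evaluation itself): after stripping the "Ġ" marker,
#   GPTTokenizer().tokenize("Hello world")  ->  roughly ["Hello", "world"]
# This tokenizer is plugged into rouge-score below for the xlingual track, so RougeL is
# computed over GPT-2 subword tokens instead of rouge-score's default English tokenization.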
default_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
xlingual_tokenizer = GPTTokenizer()
xlingual_rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], tokenizer=xlingual_tokenizer)
# Adapted the following from the SQuAD v1.1 evaluation script, without removing the articles.
def normalize_answer(s):
    """Lowercase text and remove punctuation and extra whitespace."""
    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return str(text).lower()

    return white_space_fix(remove_punc(lower(s)))
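# For example: normalize_answer("The  Cat!") == "the cat", and normalize_answer(42) == "42",
# since the input is coerced to a string before lowercasing.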
def exact_match(prediction, ground_truth, xlingual=False):
    return normalize_answer(prediction) == normalize_answer(ground_truth)

def rouge(prediction, ground_truth, xlingual=False):
    if xlingual:
        scorer = xlingual_rouge_scorer
    else:
        scorer = default_rouge_scorer
    scores = scorer.score(prediction=str(prediction), target=str(ground_truth))
    return scores["rougeL"].fmeasure

def metric_max_over_ground_truths(metric_fn, prediction, ground_truths, xlingual=False):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth, xlingual=xlingual)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)
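# Illustration: with several gold references per instance, the best-matching one wins, e.g.
#   metric_max_over_ground_truths(exact_match, "Paris", ["paris", "Paris, France"])  ->  True
# (the first reference matches after normalization, and True counts as 1 when summed below).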
def compute_metrics(predictions, references, xlingual=False):
    assert len(predictions) == len(references), f"# of predictions {len(predictions)} doesn't match # of references {len(references)}."
    em, rougeL = 0, 0
    for pred, gold in zip(predictions, references):
        assert isinstance(gold, list)
        em += metric_max_over_ground_truths(
            exact_match, prediction=pred, ground_truths=gold, xlingual=xlingual
        )
        rougeL += metric_max_over_ground_truths(
            rouge, prediction=pred, ground_truths=gold, xlingual=xlingual
        )
    em = 100.0 * em / len(references)
    rougeL = 100.0 * rougeL / len(references)
    metrics = {"exact_match": em, "rougeL": rougeL}
    metrics = {k: round(v, 4) for k, v in metrics.items()}
    return metrics
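# Illustrative call: `predictions` is a list of strings and `references` a list of lists
# of strings, e.g.
#   compute_metrics(["4", "Paris"], [["4"], ["London"]])
# should give {"exact_match": 50.0, "rougeL": 50.0}, since only the first pair matches.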
def compute_grouped_metrics(predictions, references, groups, xlingual=False):
    assert len(predictions) == len(references) == len(groups)

    examples_by_group = {}
    for pred, gold, group in zip(predictions, references, groups):
        if group not in examples_by_group:
            examples_by_group[group] = []
        examples_by_group[group].append((pred, gold))

    results = {}
    for group, group_examples in examples_by_group.items():
        task_predictions, task_references = zip(*group_examples)
        group_metrics = compute_metrics(task_predictions, task_references, xlingual=xlingual)
        for metric, value in group_metrics.items():
            results[f"{metric}_for_{group}"] = value
    return results
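# Illustration: grouping yields one entry per metric and group, e.g.
#   compute_grouped_metrics(["4"], [["4"]], ["classification"])
#   -> {"exact_match_for_classification": 100.0, "rougeL_for_classification": 100.0}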
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prediction_file", required=True,
        help="Jsonl file with each line corresponding to a prediction. "
             "Each json object should have an `id` and a `prediction` key.")
    parser.add_argument(
        "--reference_file", required=True,
        help="Jsonl file with each line corresponding to a reference. "
             "Each json object should have an `id` and a `references` key. "
             "`task_id`, `task_category` and `track` are optional; they are used to "
             "compute per-task performance, per-category performance, and performance on the "
             "default (English) / xlingual tracks.")
    parser.add_argument(
        "--output_file",
        help="Json file to write the results to.")
    return parser.parse_args()
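# Illustrative input lines (field names follow the help strings above; the ids are made up):
#   reference_file:  {"id": "task001-1", "references": ["yes"], "task_id": "task001", "task_category": "Classification", "track": "default"}
#   prediction_file: {"id": "task001-1", "prediction": "yes"}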
if __name__ == "__main__":
    args = parse_args()

    eval_instances = {}
    with open(args.reference_file) as fin:
        for line in fin:
            instance = json.loads(line)
            # If the track is not provided in the reference file, we set it to `default`
            # and use the default tokenizer in rouge-score.
            if "track" not in instance:
                instance["track"] = "default"
            eval_instances[instance["id"]] = instance

    all_predictions = {}
    with open(args.prediction_file) as fin:
        for line in fin:
            prediction = json.loads(line)
            all_predictions[prediction["id"]] = prediction["prediction"]
    print("read predictions and references")

    all_results = {}
    # Only the default (English) track is evaluated here; add "xlingual" to the list below
    # to also evaluate the cross-lingual track with the GPT-2 tokenizer.
    for track in ["default"]:
        print("Evaluating track:", track)
        instance_ids = [id for id, instance in eval_instances.items() if instance["track"] == track]
        references = [eval_instances[id]["references"] for id in instance_ids]
        predictions = []
        missing_predictions = []
        for id in instance_ids:
            if id in all_predictions:
                predictions.append(all_predictions[id])
            else:
                missing_predictions.append(id)
                predictions.append("")
        if missing_predictions:
            print(f"No prediction for {len(missing_predictions)} instances. Using the empty string as the prediction.")

        results = compute_metrics(predictions, references, xlingual=(track == "xlingual"))
        print("======== Overall Metrics ========")
        for metric, value in results.items():
            print(f"{metric}: {value}")
            all_results[f"{metric}_{track}_track"] = value

        if "task_category" in eval_instances[instance_ids[0]]:
            categories = ["_".join(eval_instances[id]["task_category"].lower().split()) for id in instance_ids]
            results_per_category = compute_grouped_metrics(predictions, references, categories, xlingual=(track == "xlingual"))
            print("======== Metrics per Category ========")
            for metric, value in results_per_category.items():
                print(f"{metric}: {value}")
                all_results[f"{metric}_{track}_track"] = value

        if "task_id" in eval_instances[instance_ids[0]]:
            tasks = [eval_instances[id]["task_id"] for id in instance_ids]
            results_per_task = compute_grouped_metrics(predictions, references, tasks, xlingual=(track == "xlingual"))
            print("======== Metrics per Task ========")
            for metric, value in results_per_task.items():
                print(f"{metric}: {value}")
                all_results[f"{metric}_{track}_track"] = value

    if args.output_file:
        with open(args.output_file, "w") as fout:
            json.dump(all_results, fout, indent=2)
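# Example invocation (file names are placeholders):
#   python evaluation.py \
#       --prediction_file predicted_examples.jsonl \
#       --reference_file test_references.jsonl \
#       --output_file metrics.json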