<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>09_NLP_Evaluation</title>
<link rel="stylesheet" href="https://stackedit.io/style.css" />
</head>
<body class="stackedit">
<div class="stackedit__html"><h1 id="nlp-evaluation">09 NLP Evaluation</h1>
<h2 id="assignment">Assignment</h2>
<p>Pick any of your past code and:</p>
<ol>
<li>
<p>Implement the following metrics (either on separate models or same, your choice):</p>
<ol>
<li>Recall, Precision, and F1 Score</li>
<li>BLEU</li>
<li>Perplexity (explain whether you are using bigram, trigram, or something else, what does your PPL score represent?)</li>
<li>BERTScore (here are <a href="https://colab.research.google.com/drive/1kpL8Y_AnUUiCxFjhxSrxCsc6-sDMNb_Q">1 (Links to an external site.)</a> <a href="https://huggingface.co/metrics/bertscore">2 (Links to an external site.)</a> examples)</li>
</ol>
</li>
<li>
<p>Once done, proceed to answer questions in the Assignment-Submission Page.</p>
<p>Questions asked are:</p>
<ol>
<li>Share the link to the readme file where you have explained all 4 metrics.</li>
<li>Share the link(s) where we can find the code and training logs for all of your 4 metrics</li>
<li>Share the last 2-3 epochs/stage logs for all of your 4 metrics separately (A, B, C, D) and describe your understanding about the numbers you’re seeing, are they good/bad? Why?</li>
</ol>
</li>
</ol>
<h2 id="solution">Solution</h2>
<p><a href="https://github.com/extensive-nlp/ttc_nlp"><code>ttc_nlp</code></a>: This package was developed to keep the models and datasets organized. It is installed at the start of every Colab run, and it pins its dependency versions, so there should be no breaking changes going forward.</p>
<h3 id="text-classification-model-and-evaluation">Text Classification Model and Evaluation</h3>
<div align="center">
<a href="https://nbviewer.jupyter.org/github/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/09_NLP_Evaluation/ClassificationEvaluation.ipynb"><img alt="Open In NBViewer" src="https://img.shields.io/badge/render-nbviewer-orange?logo=Jupyter"></a> | <a href="https://githubtocolab.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/09_NLP_Evaluation/ClassificationEvaluation.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a>
</div>
<br>
<p>Dataset: <a href="https://nlp.stanford.edu/sentiment/index.html">SST</a></p>
<table>
<thead>
<tr>
<th>Model</th>
<th>Precision</th>
<th>Recall</th>
<th>F1</th>
</tr>
</thead>
<tbody>
<tr>
<td>LSTM</td>
<td>0.414</td>
<td>0.357</td>
<td>0.412</td>
</tr>
</tbody>
</table><pre><code>Test Epoch 7/9: F1 Score: 0.41271, Precision: 0.41481, Recall: 0.35720
Classification Report
precision recall f1-score support
very negative 0.31 0.10 0.15 270
negative 0.44 0.65 0.52 603
neutral 0.29 0.19 0.23 376
positive 0.38 0.57 0.46 491
very positive 0.66 0.27 0.39 385
accuracy 0.41 2125
macro avg 0.41 0.36 0.35 2125
weighted avg 0.42 0.41 0.38 2125
</code></pre>
<p>Confusion Matrix</p>
<p><img src="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/09_NLP_Evaluation/sst_cm.png?raw=true" alt="confusion matrix"></p>
<pre><code>Test Epoch 9/9: F1 Score: 0.38965, Precision: 0.38578, Recall: 0.37095
Classification Report
precision recall f1-score support
very negative 0.35 0.27 0.30 270
negative 0.44 0.42 0.43 603
neutral 0.23 0.19 0.21 376
positive 0.36 0.52 0.43 491
very positive 0.55 0.45 0.50 385
accuracy 0.39 2125
macro avg 0.39 0.37 0.37 2125
weighted avg 0.39 0.39 0.39 2125
</code></pre>
<p><img src="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/09_NLP_Evaluation/sst_cm_10.png?raw=true" alt="confusion matrix of 10th epoch"></p>
<p>I’ve shown two epochs to illustrate a point: Epoch 7 with an F1 score of 0.413 and Epoch 9 with an F1 score of 0.390.</p>
<p>At first glance the F1 score is decreasing, which would suggest the model is not learning. But note that the Negative and Positive classes have many more samples than the others. So even though the F1 score decreased, the “weighted” F1 score increased from 0.38 to 0.39, because the model started to pay attention to the minority classes as well: “very negative”, for instance, went from 27 correct classifications at Epoch 7 to 73 at Epoch 9.</p>
<p>A score of 0.41 is not that good, considering that published results reach 0.60+, but those models use <a href="https://paperswithcode.com/sota/sentiment-analysis-on-sst-5-fine-grained">Transformers or BiDirectional LSTMs with CNNs</a>. Our model is simple, and just by using augmentations we have achieved a respectable accuracy, I would say.</p>
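<p>To see the macro vs. weighted distinction concretely, here is a short sketch that recomputes both averages from the per-class F1 scores and supports in the epoch-9 classification report above:</p>

```python
# Recompute macro vs. weighted F1 from the epoch-9 classification report.
# Per-class F1 and support (number of true samples) for:
# very negative, negative, neutral, positive, very positive
f1_per_class = [0.30, 0.43, 0.21, 0.43, 0.50]
support = [270, 603, 376, 491, 385]

# Macro: unweighted mean over classes -- every class counts equally.
macro_f1 = sum(f1_per_class) / len(f1_per_class)

# Weighted: mean weighted by support -- big classes dominate.
weighted_f1 = sum(f * s for f, s in zip(f1_per_class, support)) / sum(support)

print(round(macro_f1, 2), round(weighted_f1, 2))  # 0.37 0.39, as in the report
```

<p>The weighted average leans toward the large Negative and Positive classes, which is why it can rise even while the unweighted mean falls.</p>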
<p><strong>Stat Scores</strong></p>
<p>To compute Precision, Recall, or F1 we first need the True Positives, True Negatives, False Positives, and False Negatives. The function below, from <a href="https://github.com/PyTorchLightning/metrics">torchmetrics</a>, does exactly that:</p>
<pre class=" language-python"><code class="prism language-python"><span class="token keyword">def</span> <span class="token function">_stat_scores</span><span class="token punctuation">(</span>
preds<span class="token punctuation">:</span> Tensor<span class="token punctuation">,</span>
target<span class="token punctuation">:</span> Tensor<span class="token punctuation">,</span>
<span class="token builtin">reduce</span><span class="token punctuation">:</span> Optional<span class="token punctuation">[</span><span class="token builtin">str</span><span class="token punctuation">]</span> <span class="token operator">=</span> <span class="token string">"micro"</span><span class="token punctuation">,</span>
<span class="token punctuation">)</span> <span class="token operator">-</span><span class="token operator">></span> Tuple<span class="token punctuation">[</span>Tensor<span class="token punctuation">,</span> Tensor<span class="token punctuation">,</span> Tensor<span class="token punctuation">,</span> Tensor<span class="token punctuation">]</span><span class="token punctuation">:</span>
<span class="token triple-quoted-string string">"""Calculate the number of tp, fp, tn, fn.
Args:
preds:
An ``(N, C)`` or ``(N, C, X)`` tensor of predictions (0 or 1)
target:
An ``(N, C)`` or ``(N, C, X)`` tensor of true labels (0 or 1)
reduce:
One of ``'micro'``, ``'macro'``, ``'samples'``
Return:
Returns a list of 4 tensors; tp, fp, tn, fn.
The shape of the returned tensors depends on the shape of the inputs
and the ``reduce`` parameter:
If inputs are of the shape ``(N, C)``, then
- If ``reduce='micro'``, the returned tensors are 1 element tensors
- If ``reduce='macro'``, the returned tensors are ``(C,)`` tensors
- If ``reduce='samples'``, the returned tensors are ``(N,)`` tensors
If inputs are of the shape ``(N, C, X)``, then
- If ``reduce='micro'``, the returned tensors are ``(N,)`` tensors
- If ``reduce='macro'``, the returned tensors are ``(N,C)`` tensors
- If ``reduce='samples'``, the returned tensors are ``(N,X)`` tensors
"""</span>
dim<span class="token punctuation">:</span> Union<span class="token punctuation">[</span><span class="token builtin">int</span><span class="token punctuation">,</span> List<span class="token punctuation">[</span><span class="token builtin">int</span><span class="token punctuation">]</span><span class="token punctuation">]</span> <span class="token operator">=</span> <span class="token number">1</span> <span class="token comment"># for "samples"</span>
<span class="token keyword">if</span> <span class="token builtin">reduce</span> <span class="token operator">==</span> <span class="token string">"micro"</span><span class="token punctuation">:</span>
dim <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span> <span class="token keyword">if</span> preds<span class="token punctuation">.</span>ndim <span class="token operator">==</span> <span class="token number">2</span> <span class="token keyword">else</span> <span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">]</span>
<span class="token keyword">elif</span> <span class="token builtin">reduce</span> <span class="token operator">==</span> <span class="token string">"macro"</span><span class="token punctuation">:</span>
dim <span class="token operator">=</span> <span class="token number">0</span> <span class="token keyword">if</span> preds<span class="token punctuation">.</span>ndim <span class="token operator">==</span> <span class="token number">2</span> <span class="token keyword">else</span> <span class="token number">2</span>
true_pred<span class="token punctuation">,</span> false_pred <span class="token operator">=</span> target <span class="token operator">==</span> preds<span class="token punctuation">,</span> target <span class="token operator">!=</span> preds
pos_pred<span class="token punctuation">,</span> neg_pred <span class="token operator">=</span> preds <span class="token operator">==</span> <span class="token number">1</span><span class="token punctuation">,</span> preds <span class="token operator">==</span> <span class="token number">0</span>
tp <span class="token operator">=</span> <span class="token punctuation">(</span>true_pred <span class="token operator">*</span> pos_pred<span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>dim<span class="token operator">=</span>dim<span class="token punctuation">)</span>
fp <span class="token operator">=</span> <span class="token punctuation">(</span>false_pred <span class="token operator">*</span> pos_pred<span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>dim<span class="token operator">=</span>dim<span class="token punctuation">)</span>
tn <span class="token operator">=</span> <span class="token punctuation">(</span>true_pred <span class="token operator">*</span> neg_pred<span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>dim<span class="token operator">=</span>dim<span class="token punctuation">)</span>
fn <span class="token operator">=</span> <span class="token punctuation">(</span>false_pred <span class="token operator">*</span> neg_pred<span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>dim<span class="token operator">=</span>dim<span class="token punctuation">)</span>
<span class="token keyword">return</span> tp<span class="token punctuation">.</span><span class="token builtin">long</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> fp<span class="token punctuation">.</span><span class="token builtin">long</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> tn<span class="token punctuation">.</span><span class="token builtin">long</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> fn<span class="token punctuation">.</span><span class="token builtin">long</span><span class="token punctuation">(</span><span class="token punctuation">)</span>
</code></pre>
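<p>To make the counting concrete, here is a minimal pure-Python sketch of the <code>reduce="micro"</code> case for <code>(N, C)</code> one-hot inputs, mirroring the logic of <code>_stat_scores</code> without the tensor machinery:</p>

```python
# Pure-Python sketch of micro-reduced tp/fp/tn/fn counting on (N, C)
# one-hot predictions -- same bookkeeping as _stat_scores above.
def stat_scores_micro(preds, target):
    tp = fp = tn = fn = 0
    for p_row, t_row in zip(preds, target):
        for p, t in zip(p_row, t_row):
            if p == 1 and t == 1:
                tp += 1          # predicted positive, truly positive
            elif p == 1 and t == 0:
                fp += 1          # predicted positive, truly negative
            elif p == 0 and t == 0:
                tn += 1          # predicted negative, truly negative
            else:
                fn += 1          # predicted negative, truly positive
    return tp, fp, tn, fn

# 3 samples, 3 classes, one-hot: samples 0 and 1 correct, sample 2 wrong
preds  = [[0, 1, 0], [1, 0, 0], [0, 0, 1]]
target = [[0, 1, 0], [1, 0, 0], [0, 1, 0]]
print(stat_scores_micro(preds, target))  # (2, 1, 5, 1)
```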
<p><strong>Precision</strong></p>
<p><span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mtext>Precision</mtext><mo>=</mo><mfrac><mrow><mi>T</mi><mi>P</mi></mrow><mrow><mi>T</mi><mi>P</mi><mo>+</mo><mi>F</mi><mi>P</mi></mrow></mfrac></mrow><annotation encoding="application/x-tex">\text{Precision} = \frac{TP}{TP+FP}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord text"><span class="mord">Precision</span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.27566em; vertical-align: -0.403331em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.872331em;"><span class="" style="top: -2.655em;"><span class="pstrut" style="height: 3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">TP</span><span class="mbin mtight">+</span><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">FP</span></span></span></span><span class="" style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span class="" style="top: -3.394em;"><span class="pstrut" style="height: 3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">TP</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.403331em;"><span 
class=""></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></p>
<p><strong>Recall</strong></p>
<p><span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mtext>Recall</mtext><mo>=</mo><mfrac><mrow><mi>T</mi><mi>P</mi></mrow><mrow><mi>T</mi><mi>P</mi><mo>+</mo><mi>F</mi><mi>N</mi></mrow></mfrac></mrow><annotation encoding="application/x-tex">\text{Recall} = \frac{TP}{TP+FN}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.69444em; vertical-align: 0em;"></span><span class="mord text"><span class="mord">Recall</span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.27566em; vertical-align: -0.403331em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.872331em;"><span class="" style="top: -2.655em;"><span class="pstrut" style="height: 3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">TP</span><span class="mbin mtight">+</span><span class="mord mathnormal mtight" style="margin-right: 0.10903em;">FN</span></span></span></span><span class="" style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span class="" style="top: -3.394em;"><span class="pstrut" style="height: 3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">TP</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.403331em;"><span 
class=""></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></p>
<p><strong>F1 Score</strong></p>
<p>Harmonic Mean of Precision and Recall</p>
<p><span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mtext>F1</mtext><mo>=</mo><mfrac><mrow><mn>2</mn><mo>×</mo><mi>P</mi><mi>r</mi><mi>e</mi><mi>c</mi><mi>i</mi><mi>s</mi><mi>o</mi><mi>n</mi><mo>×</mo><mi>R</mi><mi>e</mi><mi>c</mi><mi>a</mi><mi>l</mi><mi>l</mi></mrow><mrow><mi>P</mi><mi>r</mi><mi>e</mi><mi>c</mi><mi>i</mi><mi>s</mi><mi>i</mi><mi>o</mi><mi>n</mi><mo>+</mo><mi>R</mi><mi>e</mi><mi>c</mi><mi>a</mi><mi>l</mi><mi>l</mi></mrow></mfrac></mrow><annotation encoding="application/x-tex">\text{F1} = \frac{2\times Precison\times Recall}{Precision+Recall}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord text"><span class="mord">F1</span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.28344em; vertical-align: -0.403331em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.880108em;"><span class="" style="top: -2.655em;"><span class="pstrut" style="height: 3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">P</span><span class="mord mathnormal mtight">rec</span><span class="mord mathnormal mtight">i</span><span class="mord mathnormal mtight">s</span><span class="mord mathnormal mtight">i</span><span class="mord mathnormal mtight">o</span><span class="mord mathnormal mtight">n</span><span class="mbin mtight">+</span><span class="mord mathnormal mtight" style="margin-right: 0.00773em;">R</span><span class="mord mathnormal 
mtight">ec</span><span class="mord mathnormal mtight">a</span><span class="mord mathnormal mtight" style="margin-right: 0.01968em;">ll</span></span></span></span><span class="" style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span class="" style="top: -3.394em;"><span class="pstrut" style="height: 3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span><span class="mbin mtight">×</span><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">P</span><span class="mord mathnormal mtight">rec</span><span class="mord mathnormal mtight">i</span><span class="mord mathnormal mtight">so</span><span class="mord mathnormal mtight">n</span><span class="mbin mtight">×</span><span class="mord mathnormal mtight" style="margin-right: 0.00773em;">R</span><span class="mord mathnormal mtight">ec</span><span class="mord mathnormal mtight">a</span><span class="mord mathnormal mtight" style="margin-right: 0.01968em;">ll</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.403331em;"><span class=""></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></p>
<p><strong>Intuition behind Precision and Recall</strong></p>
<p>You can think of precision as the proportion of your positive predictions that actually turn out to be positive. Recall, on the other hand, can be thought of as accuracy over just the positives: the proportion of actual positives that you labeled correctly.</p>
<p>In the multi-class case, precision and recall are usually applied on a per-category basis. That is, if you are trying to guess whether a picture shows a cat, a dog, or some other animal, you would compute precision and recall for cats and dogs separately. Then it’s just the binary case again: for the precision on cats, take the number of times you correctly guessed “cat” divided by the total number of times you guessed “cat”. Similarly, for the recall on cats, take the number of times you correctly guessed “cat” divided by the total number of pictures that actually were cats.</p>
<p>I like to think of it this way: precision is about how precise I am. Out of everything I predicted as a given class, how much did I get right? Say I predicted 100 images as cats (some of which may actually be dogs), and 50 of them really were cats; then my precision is 0.5.<br>
Recall asks, out of the actual cat images only, how many did I catch? Dog images don’t enter into it.</p>
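<p>Putting numbers to the cat example (the 80 total cat images is a made-up figure for illustration):</p>

```python
# Hypothetical cat-detector numbers: 100 images predicted "cat",
# 50 of them actually cats, and 80 cat images in the whole set.
tp = 50            # predicted cat, was a cat
fp = 100 - 50      # predicted cat, was not a cat
fn = 80 - 50       # was a cat, but we missed it

precision = tp / (tp + fp)   # 50/100 = 0.5
recall = tp / (tp + fn)      # 50/80  = 0.625
f1 = 2 * precision * recall / (precision + recall)

print(precision, recall, round(f1, 3))  # 0.5 0.625 0.556
```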
<h3 id="language-translation-model-and-evaluation">Language Translation Model and Evaluation</h3>
<div align="center">
<a href="https://nbviewer.jupyter.org/github/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/09_NLP_Evaluation/TranslationTransformer.ipynb"><img alt="Open In NBViewer" src="https://img.shields.io/badge/render-nbviewer-orange?logo=Jupyter"></a> | <a href="https://githubtocolab.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/09_NLP_Evaluation/TranslationTransformer.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a>
</div>
<br>
<p>Dataset: <a href="https://github.com/multi30k/dataset">Multi30k</a></p>
<table>
<thead>
<tr>
<th>Model</th>
<th>PPL</th>
<th>BLEU Score</th>
<th>BERT Score</th>
</tr>
</thead>
<tbody>
<tr>
<td>Seq2Seq w/ Multi Head Transformer</td>
<td>7.572</td>
<td>32.758</td>
<td>P=0.942 R=0.939 F1=0.940</td>
</tr>
</tbody>
</table><p><img src="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/09_NLP_Evaluation/bleu_bert.png?raw=true" alt="blue_bert"></p>
<p><img src="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/09_NLP_Evaluation/cross_entropy_ppl.png?raw=true" alt="crossentropy_ppl"></p>
<p><strong>Perplexity</strong></p>
<p>Perplexity comes from information theory: it is a measure of how well a probability distribution or probability model predicts a sample, and it can be used to compare probability models. A low perplexity indicates that the distribution is good at predicting the sample. [Wikipedia]</p>
<p><span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>P</mi><mi>P</mi><mi>L</mi><mo stretchy="false">(</mo><mi>p</mi><mo stretchy="false">)</mo><mo>=</mo><msup><mi>e</mi><mrow><mo>−</mo><msub><mo>∑</mo><mi>x</mi></msub><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mi>l</mi><mi>o</mi><msub><mi>g</mi><mi>e</mi></msub><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow></mrow></mrow></msup></mrow><annotation encoding="application/x-tex">PPL(p)=e^{-\sum_x{p(x)log_{e}{p(x)}}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">PP</span><span class="mord mathnormal">L</span><span class="mopen">(</span><span class="mord mathnormal">p</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.888em; vertical-align: 0em;"></span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.888em;"><span class="" style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">−</span><span class="mspace mtight" style="margin-right: 0.195167em;"></span><span class="mop mtight"><span class="mop op-symbol small-op mtight" style="position: relative; top: -5e-06em;">∑</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist"><span class="" style="top: -2.17856em; margin-left: 0em; 
margin-right: 0.0714286em;"><span class="pstrut" style="height: 2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight">x</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.321439em;"><span class=""></span></span></span></span></span></span><span class="mspace mtight" style="margin-right: 0.195167em;"></span><span class="mord mtight"><span class="mord mathnormal mtight">p</span><span class="mopen mtight">(</span><span class="mord mathnormal mtight">x</span><span class="mclose mtight">)</span><span class="mord mathnormal mtight" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal mtight">o</span><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right: 0.03588em;">g</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.164543em;"><span class="" style="top: -2.357em; margin-left: -0.03588em; margin-right: 0.0714286em;"><span class="pstrut" style="height: 2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">e</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.143em;"><span class=""></span></span></span></span></span></span><span class="mord mtight"><span class="mord mathnormal mtight">p</span><span class="mopen mtight">(</span><span class="mord mathnormal mtight">x</span><span class="mclose mtight">)</span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></p>
<p>But you can observe that the exponent is the cross entropy (taken in nats, i.e. with the natural logarithm). Hence</p>
<p>CE = −∑<sub><em>x</em></sub> <em>p</em>(<em>x</em>) log<sub><em>e</em></sub> <em>p</em>(<em>x</em>)<br>
PPL = <em>e</em><sup>CE</sup></p>
<p>Intuitively, perplexity can be understood as a measure of uncertainty. The perplexity of a language model can be seen as the level of perplexity when predicting the following symbol. Consider a language model with an entropy of three bits, in which each bit encodes two possible outcomes of equal probability. This means that when predicting the next symbol, that language model has to choose among <span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mn>2</mn><mn>3</mn></msup><mo>=</mo><mn>8</mn></mrow><annotation encoding="application/x-tex">2^3=8</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.814108em; vertical-align: 0em;"></span><span class="mord"><span class="mord">2</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span class="" style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">3</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">8</span></span></span></span></span> possible options. 
Thus, we can argue that this language model has a perplexity of <span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>8</mn></mrow><annotation encoding="application/x-tex">8</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">8</span></span></span></span></span>. <a href="https://thegradient.pub/understanding-evaluation-metrics-for-language-models/">Source</a></p>
<p>The PPL calculated for this model was in Unigram, which was <span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>7.572</mn></mrow><annotation encoding="application/x-tex">7.572</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">7.572</span></span></span></span></span>, this would be interpreted as the model has to choose among <span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mtext> </mtext><mn>8</mn></mrow><annotation encoding="application/x-tex">~8</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mspace nobreak"> </span><span class="mord">8</span></span></span></span></span> possible options of words to predict the next outcome. Its Good Enough ? <code>¯\_(ツ)_/¯</code></p>
<p><strong>BLEU Score</strong></p>
<p>There’s this nice interpretation of BLEU Score from <a href="https://cloud.google.com/translate/automl/docs/evaluate">Google Cloud</a></p>
<table>
<thead>
<tr>
<th>BLEU Score</th>
<th>Interpretation</th>
</tr>
</thead>
<tbody>
<tr>
<td>< 10</td>
<td>Almost useless</td>
</tr>
<tr>
<td>10 - 19</td>
<td>Hard to get the gist</td>
</tr>
<tr>
<td>20 - 29</td>
<td>The gist is clear, but has significant grammatical errors</td>
</tr>
<tr>
<td>30 - 40</td>
<td>Understandable to good translations</td>
</tr>
<tr>
<td>40 - 50</td>
<td>High quality translations</td>
</tr>
<tr>
<td>50 - 60</td>
<td>Very high quality, adequate, and fluent translations</td>
</tr>
<tr>
<td>> 60</td>
<td>Quality often better than human</td>
</tr>
</tbody>
</table><p>BLEU first makes n-grams (basically combine n words) from the predicted sentences and compare it with the n-grams of the actual target sentences. This matching is independent of the position of the n-gram. More the number of matches, more better the model is at translating.</p>
<p>We got a BLEU Score of <code>32.758</code>, so it comes under “Understandable to good translation”, and it is ! Note that this score was got from using unigram, bigram and trigram of the corpuses.</p>
<pre class=" language-python"><code class="prism language-python">translate<span class="token punctuation">(</span>transformer<span class="token punctuation">,</span> <span class="token string">"Eine Gruppe von Menschen steht vor einem Iglu ."</span><span class="token punctuation">)</span>
<span class="token operator">>></span><span class="token operator">></span>
A group of people stand <span class="token keyword">in</span> front of an outdoor airport <span class="token punctuation">.</span>
</code></pre>
<p>On Google Translate this gives</p>
<pre><code>A group of people stands in front of an igloo
</code></pre>
<p>So the model got everything other than the igloo, quite possibly because it would have not encountered this meaning before.</p>
<p>Implementation:</p>
<pre class=" language-python"><code class="prism language-python"><span class="token keyword">def</span> <span class="token function">bleu_score</span><span class="token punctuation">(</span>candidate_corpus<span class="token punctuation">,</span> references_corpus<span class="token punctuation">,</span> max_n<span class="token operator">=</span><span class="token number">4</span><span class="token punctuation">,</span> weights<span class="token operator">=</span><span class="token punctuation">[</span><span class="token number">0.25</span><span class="token punctuation">]</span> <span class="token operator">*</span> <span class="token number">4</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token triple-quoted-string string">"""Computes the BLEU score between a candidate translation corpus and a references
translation corpus. Based on https://www.aclweb.org/anthology/P02-1040.pdf
Arguments:
candidate_corpus: an iterable of candidate translations. Each translation is an
iterable of tokens
references_corpus: an iterable of iterables of reference translations. Each
translation is an iterable of tokens
max_n: the maximum n-gram we want to use. E.g. if max_n=3, we will use unigrams,
bigrams and trigrams
weights: a list of weights used for each n-gram category (uniform by default)
Examples:
>>> from torchtext.data.metrics import bleu_score
>>> candidate_corpus = [['My', 'full', 'pytorch', 'test'], ['Another', 'Sentence']]
>>> references_corpus = [[['My', 'full', 'pytorch', 'test'], ['Completely', 'Different']], [['No', 'Match']]]
>>> bleu_score(candidate_corpus, references_corpus)
0.8408964276313782
"""</span>
<span class="token keyword">assert</span> max_n <span class="token operator">==</span> <span class="token builtin">len</span><span class="token punctuation">(</span>weights<span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token string">'Length of the "weights" list has be equal to max_n'</span>
<span class="token keyword">assert</span> <span class="token builtin">len</span><span class="token punctuation">(</span>candidate_corpus<span class="token punctuation">)</span> <span class="token operator">==</span> <span class="token builtin">len</span><span class="token punctuation">(</span>references_corpus<span class="token punctuation">)</span><span class="token punctuation">,</span>\
<span class="token string">'The length of candidate and reference corpus should be the same'</span>
clipped_counts <span class="token operator">=</span> torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span>max_n<span class="token punctuation">)</span>
total_counts <span class="token operator">=</span> torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span>max_n<span class="token punctuation">)</span>
weights <span class="token operator">=</span> torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span>weights<span class="token punctuation">)</span>
candidate_len <span class="token operator">=</span> <span class="token number">0.0</span>
refs_len <span class="token operator">=</span> <span class="token number">0.0</span>
<span class="token keyword">for</span> <span class="token punctuation">(</span>candidate<span class="token punctuation">,</span> refs<span class="token punctuation">)</span> <span class="token keyword">in</span> <span class="token builtin">zip</span><span class="token punctuation">(</span>candidate_corpus<span class="token punctuation">,</span> references_corpus<span class="token punctuation">)</span><span class="token punctuation">:</span>
candidate_len <span class="token operator">+=</span> <span class="token builtin">len</span><span class="token punctuation">(</span>candidate<span class="token punctuation">)</span>
<span class="token comment"># Get the length of the reference that's closest in length to the candidate</span>
refs_len_list <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token builtin">float</span><span class="token punctuation">(</span><span class="token builtin">len</span><span class="token punctuation">(</span>ref<span class="token punctuation">)</span><span class="token punctuation">)</span> <span class="token keyword">for</span> ref <span class="token keyword">in</span> refs<span class="token punctuation">]</span>
refs_len <span class="token operator">+=</span> <span class="token builtin">min</span><span class="token punctuation">(</span>refs_len_list<span class="token punctuation">,</span> key<span class="token operator">=</span><span class="token keyword">lambda</span> x<span class="token punctuation">:</span> <span class="token builtin">abs</span><span class="token punctuation">(</span><span class="token builtin">len</span><span class="token punctuation">(</span>candidate<span class="token punctuation">)</span> <span class="token operator">-</span> x<span class="token punctuation">)</span><span class="token punctuation">)</span>
reference_counters <span class="token operator">=</span> _compute_ngram_counter<span class="token punctuation">(</span>refs<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span> max_n<span class="token punctuation">)</span>
<span class="token keyword">for</span> ref <span class="token keyword">in</span> refs<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">:</span><span class="token punctuation">]</span><span class="token punctuation">:</span>
reference_counters <span class="token operator">=</span> reference_counters <span class="token operator">|</span> _compute_ngram_counter<span class="token punctuation">(</span>ref<span class="token punctuation">,</span> max_n<span class="token punctuation">)</span>
candidate_counter <span class="token operator">=</span> _compute_ngram_counter<span class="token punctuation">(</span>candidate<span class="token punctuation">,</span> max_n<span class="token punctuation">)</span>
clipped_counter <span class="token operator">=</span> candidate_counter <span class="token operator">&</span> reference_counters
<span class="token keyword">for</span> ngram <span class="token keyword">in</span> clipped_counter<span class="token punctuation">:</span>
clipped_counts<span class="token punctuation">[</span><span class="token builtin">len</span><span class="token punctuation">(</span>ngram<span class="token punctuation">)</span> <span class="token operator">-</span> <span class="token number">1</span><span class="token punctuation">]</span> <span class="token operator">+=</span> clipped_counter<span class="token punctuation">[</span>ngram<span class="token punctuation">]</span>
<span class="token keyword">for</span> ngram <span class="token keyword">in</span> candidate_counter<span class="token punctuation">:</span> <span class="token comment"># TODO: no need to loop through the whole counter</span>
total_counts<span class="token punctuation">[</span><span class="token builtin">len</span><span class="token punctuation">(</span>ngram<span class="token punctuation">)</span> <span class="token operator">-</span> <span class="token number">1</span><span class="token punctuation">]</span> <span class="token operator">+=</span> candidate_counter<span class="token punctuation">[</span>ngram<span class="token punctuation">]</span>
<span class="token keyword">if</span> <span class="token builtin">min</span><span class="token punctuation">(</span>clipped_counts<span class="token punctuation">)</span> <span class="token operator">==</span> <span class="token number">0</span><span class="token punctuation">:</span>
<span class="token keyword">return</span> <span class="token number">0.0</span>
<span class="token keyword">else</span><span class="token punctuation">:</span>
pn <span class="token operator">=</span> clipped_counts <span class="token operator">/</span> total_counts
log_pn <span class="token operator">=</span> weights <span class="token operator">*</span> torch<span class="token punctuation">.</span>log<span class="token punctuation">(</span>pn<span class="token punctuation">)</span>
score <span class="token operator">=</span> torch<span class="token punctuation">.</span>exp<span class="token punctuation">(</span><span class="token builtin">sum</span><span class="token punctuation">(</span>log_pn<span class="token punctuation">)</span><span class="token punctuation">)</span>
bp <span class="token operator">=</span> math<span class="token punctuation">.</span>exp<span class="token punctuation">(</span><span class="token builtin">min</span><span class="token punctuation">(</span><span class="token number">1</span> <span class="token operator">-</span> refs_len <span class="token operator">/</span> candidate_len<span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
<span class="token keyword">return</span> bp <span class="token operator">*</span> score<span class="token punctuation">.</span>item<span class="token punctuation">(</span><span class="token punctuation">)</span>
</code></pre>
<p><strong>BERT Score</strong></p>
<p><img src="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/09_NLP_Evaluation/Architecture_BERTScore.png?raw=true" alt="bertscore architecture"></p>
<p>BertScore basically addresses two common pitfalls in n-gram-based metrics. Firstly, the n-gram models fail to robustly match paraphrases which leads to performance underestimation when semantically-correct phrases are penalized because of their difference from the surface form of the reference.</p>
<p>Each token in <span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>x</mi></mrow><annotation encoding="application/x-tex">x</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal">x</span></span></span></span></span> is matched to the most similar token in <span class="katex--inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mover accent="true"><mi>x</mi><mo>^</mo></mover></mrow><annotation encoding="application/x-tex">\hat{x}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.69444em; vertical-align: 0em;"></span><span class="mord accent"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.69444em;"><span class="" style="top: -3em;"><span class="pstrut" style="height: 3em;"></span><span class="mord mathnormal">x</span></span><span class="" style="top: -3em;"><span class="pstrut" style="height: 3em;"></span><span class="accent-body" style="left: -0.22222em;"><span class="mord">^</span></span></span></span></span></span></span></span></span></span></span> and vice-versa for calculating Recall and Precision respectively. The matching is greedy and isolated. Precision and Recall are combined for calculating the F1 score.</p>
<p>The Scores we get are relative to BERT model performing on the dataset. We get a score of <code>0.94</code> pretty good ? too good to be true ? yes could be, but the validation dataset has only 1K samples.</p>
<p>The Model used to evaluate was <code>RoBERT</code></p>
<pre><code>roberta-large_L17_no-idf_version=0.3.9(hug_trans=4.8.2) P: 0.940923 R: 0.940774 F1: 0.940776
</code></pre>
<p>And here’s a sample run which shows the similarity matrix generated by BERTScore</p>
<p><img src="https://github.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/09_NLP_Evaluation/similarity_matrix.png?raw=true" alt="similarity matrix"></p>
<p>The BERT Score implementation was taken from <a href="https://github.com/Tiiiger/bert_score"><code>bert_score</code></a>, the source code of the scoring function can be found <a href="https://github.com/Tiiiger/bert_score/blob/master/bert_score/score.py">here</a></p>
<pre class=" language-python"><code class="prism language-python"><span class="token keyword">def</span> <span class="token function">score</span><span class="token punctuation">(</span>
cands<span class="token punctuation">,</span>
refs<span class="token punctuation">,</span>
model_type<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">,</span>
num_layers<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">,</span>
verbose<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">,</span>
idf<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">,</span>
device<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">,</span>
batch_size<span class="token operator">=</span><span class="token number">64</span><span class="token punctuation">,</span>
nthreads<span class="token operator">=</span><span class="token number">4</span><span class="token punctuation">,</span>
all_layers<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">,</span>
lang<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">,</span>
return_hash<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">,</span>
rescale_with_baseline<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">,</span>
baseline_path<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">,</span>
<span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token triple-quoted-string string">"""
BERTScore metric.
Args:
- :param: `cands` (list of str): candidate sentences
- :param: `refs` (list of str or list of list of str): reference sentences
- :param: `model_type` (str): bert specification, default using the suggested
model for the target langauge; has to specify at least one of
`model_type` or `lang`
- :param: `num_layers` (int): the layer of representation to use.
default using the number of layer tuned on WMT16 correlation data
- :param: `verbose` (bool): turn on intermediate status update
- :param: `idf` (bool or dict): use idf weighting, can also be a precomputed idf_dict
- :param: `device` (str): on which the contextual embedding model will be allocated on.
If this argument is None, the model lives on cuda:0 if cuda is available.
- :param: `nthreads` (int): number of threads
- :param: `batch_size` (int): bert score processing batch size
- :param: `lang` (str): language of the sentences; has to specify
at least one of `model_type` or `lang`. `lang` needs to be
specified when `rescale_with_baseline` is True.
- :param: `return_hash` (bool): return hash code of the setting
- :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline
- :param: `baseline_path` (str): customized baseline file
Return:
- :param: `(P, R, F)`: each is of shape (N); N = number of input
candidate reference pairs. if returning hashcode, the
output will be ((P, R, F), hashcode). If a candidate have
multiple references, the returned score of this candidate is
the *best* score among all references.
"""</span>
<span class="token keyword">assert</span> <span class="token builtin">len</span><span class="token punctuation">(</span>cands<span class="token punctuation">)</span> <span class="token operator">==</span> <span class="token builtin">len</span><span class="token punctuation">(</span>refs<span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token string">"Different number of candidates and references"</span>
<span class="token keyword">assert</span> lang <span class="token keyword">is</span> <span class="token operator">not</span> <span class="token boolean">None</span> <span class="token operator">or</span> model_type <span class="token keyword">is</span> <span class="token operator">not</span> <span class="token boolean">None</span><span class="token punctuation">,</span> <span class="token string">"Either lang or model_type should be specified"</span>
ref_group_boundaries <span class="token operator">=</span> <span class="token boolean">None</span>
<span class="token keyword">if</span> <span class="token operator">not</span> <span class="token builtin">isinstance</span><span class="token punctuation">(</span>refs<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token builtin">str</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
ref_group_boundaries <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">]</span>
ori_cands<span class="token punctuation">,</span> ori_refs <span class="token operator">=</span> cands<span class="token punctuation">,</span> refs
cands<span class="token punctuation">,</span> refs <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token punctuation">[</span><span class="token punctuation">]</span>
count <span class="token operator">=</span> <span class="token number">0</span>
<span class="token keyword">for</span> cand<span class="token punctuation">,</span> ref_group <span class="token keyword">in</span> <span class="token builtin">zip</span><span class="token punctuation">(</span>ori_cands<span class="token punctuation">,</span> ori_refs<span class="token punctuation">)</span><span class="token punctuation">:</span>
cands <span class="token operator">+=</span> <span class="token punctuation">[</span>cand<span class="token punctuation">]</span> <span class="token operator">*</span> <span class="token builtin">len</span><span class="token punctuation">(</span>ref_group<span class="token punctuation">)</span>
refs <span class="token operator">+=</span> ref_group
ref_group_boundaries<span class="token punctuation">.</span>append<span class="token punctuation">(</span><span class="token punctuation">(</span>count<span class="token punctuation">,</span> count <span class="token operator">+</span> <span class="token builtin">len</span><span class="token punctuation">(</span>ref_group<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
count <span class="token operator">+=</span> <span class="token builtin">len</span><span class="token punctuation">(</span>ref_group<span class="token punctuation">)</span>
<span class="token keyword">if</span> rescale_with_baseline<span class="token punctuation">:</span>
<span class="token keyword">assert</span> lang <span class="token keyword">is</span> <span class="token operator">not</span> <span class="token boolean">None</span><span class="token punctuation">,</span> <span class="token string">"Need to specify Language when rescaling with baseline"</span>
<span class="token keyword">if</span> model_type <span class="token keyword">is</span> <span class="token boolean">None</span><span class="token punctuation">:</span>
lang <span class="token operator">=</span> lang<span class="token punctuation">.</span>lower<span class="token punctuation">(</span><span class="token punctuation">)</span>
model_type <span class="token operator">=</span> lang2model<span class="token punctuation">[</span>lang<span class="token punctuation">]</span>
<span class="token keyword">if</span> num_layers <span class="token keyword">is</span> <span class="token boolean">None</span><span class="token punctuation">:</span>
num_layers <span class="token operator">=</span> model2layers<span class="token punctuation">[</span>model_type<span class="token punctuation">]</span>
tokenizer <span class="token operator">=</span> get_tokenizer<span class="token punctuation">(</span>model_type<span class="token punctuation">)</span>
model <span class="token operator">=</span> get_model<span class="token punctuation">(</span>model_type<span class="token punctuation">,</span> num_layers<span class="token punctuation">,</span> all_layers<span class="token punctuation">)</span>
<span class="token keyword">if</span> device <span class="token keyword">is</span> <span class="token boolean">None</span><span class="token punctuation">:</span>
device <span class="token operator">=</span> <span class="token string">"cuda"</span> <span class="token keyword">if</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>is_available<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">else</span> <span class="token string">"cpu"</span>
model<span class="token punctuation">.</span>to<span class="token punctuation">(</span>device<span class="token punctuation">)</span>
<span class="token keyword">if</span> <span class="token operator">not</span> idf<span class="token punctuation">:</span>
idf_dict <span class="token operator">=</span> defaultdict<span class="token punctuation">(</span><span class="token keyword">lambda</span><span class="token punctuation">:</span> <span class="token number">1.0</span><span class="token punctuation">)</span>
<span class="token comment"># set idf for [SEP] and [CLS] to 0</span>
idf_dict<span class="token punctuation">[</span>tokenizer<span class="token punctuation">.</span>sep_token_id<span class="token punctuation">]</span> <span class="token operator">=</span> <span class="token number">0</span>
idf_dict<span class="token punctuation">[</span>tokenizer<span class="token punctuation">.</span>cls_token_id<span class="token punctuation">]</span> <span class="token operator">=</span> <span class="token number">0</span>
<span class="token keyword">elif</span> <span class="token builtin">isinstance</span><span class="token punctuation">(</span>idf<span class="token punctuation">,</span> <span class="token builtin">dict</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token keyword">if</span> verbose<span class="token punctuation">:</span>
<span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"using predefined IDF dict..."</span><span class="token punctuation">)</span>
idf_dict <span class="token operator">=</span> idf
<span class="token keyword">else</span><span class="token punctuation">:</span>
<span class="token keyword">if</span> verbose<span class="token punctuation">:</span>
<span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"preparing IDF dict..."</span><span class="token punctuation">)</span>
start <span class="token operator">=</span> time<span class="token punctuation">.</span>perf_counter<span class="token punctuation">(</span><span class="token punctuation">)</span>
idf_dict <span class="token operator">=</span> get_idf_dict<span class="token punctuation">(</span>refs<span class="token punctuation">,</span> tokenizer<span class="token punctuation">,</span> nthreads<span class="token operator">=</span>nthreads<span class="token punctuation">)</span>
<span class="token keyword">if</span> verbose<span class="token punctuation">:</span>
<span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"done in {:.2f} seconds"</span><span class="token punctuation">.</span><span class="token builtin">format</span><span class="token punctuation">(</span>time<span class="token punctuation">.</span>perf_counter<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token operator">-</span> start<span class="token punctuation">)</span><span class="token punctuation">)</span>
<span class="token keyword">if</span> verbose<span class="token punctuation">:</span>
<span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"calculating scores..."</span><span class="token punctuation">)</span>
start <span class="token operator">=</span> time<span class="token punctuation">.</span>perf_counter<span class="token punctuation">(</span><span class="token punctuation">)</span>
all_preds <span class="token operator">=</span> bert_cos_score_idf<span class="token punctuation">(</span>
model<span class="token punctuation">,</span>
refs<span class="token punctuation">,</span>
cands<span class="token punctuation">,</span>
tokenizer<span class="token punctuation">,</span>
idf_dict<span class="token punctuation">,</span>
verbose<span class="token operator">=</span>verbose<span class="token punctuation">,</span>
device<span class="token operator">=</span>device<span class="token punctuation">,</span>
batch_size<span class="token operator">=</span>batch_size<span class="token punctuation">,</span>
all_layers<span class="token operator">=</span>all_layers<span class="token punctuation">,</span>
<span class="token punctuation">)</span><span class="token punctuation">.</span>cpu<span class="token punctuation">(</span><span class="token punctuation">)</span>
<span class="token keyword">if</span> ref_group_boundaries <span class="token keyword">is</span> <span class="token operator">not</span> <span class="token boolean">None</span><span class="token punctuation">:</span>
max_preds <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">]</span>
<span class="token keyword">for</span> beg<span class="token punctuation">,</span> end <span class="token keyword">in</span> ref_group_boundaries<span class="token punctuation">:</span>
max_preds<span class="token punctuation">.</span>append<span class="token punctuation">(</span>all_preds<span class="token punctuation">[</span>beg<span class="token punctuation">:</span>end<span class="token punctuation">]</span><span class="token punctuation">.</span><span class="token builtin">max</span><span class="token punctuation">(</span>dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
all_preds <span class="token operator">=</span> torch<span class="token punctuation">.</span>stack<span class="token punctuation">(</span>max_preds<span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span>
use_custom_baseline <span class="token operator">=</span> baseline_path <span class="token keyword">is</span> <span class="token operator">not</span> <span class="token boolean">None</span>
<span class="token keyword">if</span> rescale_with_baseline<span class="token punctuation">:</span>
<span class="token keyword">if</span> baseline_path <span class="token keyword">is</span> <span class="token boolean">None</span><span class="token punctuation">:</span>
baseline_path <span class="token operator">=</span> os<span class="token punctuation">.</span>path<span class="token punctuation">.</span>join<span class="token punctuation">(</span>os<span class="token punctuation">.</span>path<span class="token punctuation">.</span>dirname<span class="token punctuation">(</span>__file__<span class="token punctuation">)</span><span class="token punctuation">,</span> f<span class="token string">"rescale_baseline/{lang}/{model_type}.tsv"</span><span class="token punctuation">)</span>
<span class="token keyword">if</span> os<span class="token punctuation">.</span>path<span class="token punctuation">.</span>isfile<span class="token punctuation">(</span>baseline_path<span class="token punctuation">)</span><span class="token punctuation">:</span>
<span class="token keyword">if</span> <span class="token operator">not</span> all_layers<span class="token punctuation">:</span>
baselines <span class="token operator">=</span> torch<span class="token punctuation">.</span>from_numpy<span class="token punctuation">(</span>pd<span class="token punctuation">.</span>read_csv<span class="token punctuation">(</span>baseline_path<span class="token punctuation">)</span><span class="token punctuation">.</span>iloc<span class="token punctuation">[</span>num_layers<span class="token punctuation">]</span><span class="token punctuation">.</span>to_numpy<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">:</span><span class="token punctuation">]</span><span class="token punctuation">.</span><span class="token builtin">float</span><span class="token punctuation">(</span><span class="token punctuation">)</span>
<span class="token keyword">else</span><span class="token punctuation">:</span>
baselines <span class="token operator">=</span> torch<span class="token punctuation">.</span>from_numpy<span class="token punctuation">(</span>pd<span class="token punctuation">.</span>read_csv<span class="token punctuation">(</span>baseline_path<span class="token punctuation">)</span><span class="token punctuation">.</span>to_numpy<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">:</span><span class="token punctuation">]</span><span class="token punctuation">.</span>unsqueeze<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">float</span><span class="token punctuation">(</span><span class="token punctuation">)</span>
all_preds <span class="token operator">=</span> <span class="token punctuation">(</span>all_preds <span class="token operator">-</span> baselines<span class="token punctuation">)</span> <span class="token operator">/</span> <span class="token punctuation">(</span><span class="token number">1</span> <span class="token operator">-</span> baselines<span class="token punctuation">)</span>
<span class="token keyword">else</span><span class="token punctuation">:</span>
<span class="token keyword">print</span><span class="token punctuation">(</span>
f<span class="token string">"Warning: Baseline not found for {model_type} on {lang} at {baseline_path}"</span><span class="token punctuation">,</span> <span class="token builtin">file</span><span class="token operator">=</span>sys<span class="token punctuation">.</span>stderr<span class="token punctuation">,</span>
<span class="token punctuation">)</span>
out <span class="token operator">=</span> all_preds<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span> all_preds<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">,</span> all_preds<span class="token punctuation">[</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">.</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">]</span> <span class="token comment"># P, R, F</span>
<span class="token keyword">if</span> verbose<span class="token punctuation">:</span>
time_diff <span class="token operator">=</span> time<span class="token punctuation">.</span>perf_counter<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token operator">-</span> start
<span class="token keyword">print</span><span class="token punctuation">(</span>f<span class="token string">"done in {time_diff:.2f} seconds, {len(refs) / time_diff:.2f} sentences/sec"</span><span class="token punctuation">)</span>
<span class="token keyword">if</span> return_hash<span class="token punctuation">:</span>
<span class="token keyword">return</span> <span class="token builtin">tuple</span><span class="token punctuation">(</span>
<span class="token punctuation">[</span>
out<span class="token punctuation">,</span>
get_hash<span class="token punctuation">(</span>model_type<span class="token punctuation">,</span> num_layers<span class="token punctuation">,</span> idf<span class="token punctuation">,</span> rescale_with_baseline<span class="token punctuation">,</span> use_custom_baseline<span class="token operator">=</span>use_custom_baseline<span class="token punctuation">,</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
<span class="token punctuation">]</span>
<span class="token punctuation">)</span>
<span class="token keyword">return</span> out
</code></pre>
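<p>The rescaling step in the listing above, <code>(all_preds - baselines) / (1 - baselines)</code>, is worth seeing in isolation: raw BERTScore values tend to cluster in a narrow band near 1, and subtracting a per-metric baseline then renormalizing spreads them over a more readable range without changing their ordering. A minimal sketch with made-up numbers (the real baselines come from the bundled <code>rescale_baseline/{lang}/{model_type}.tsv</code> files):</p>

```python
import torch

# Two illustrative sentence pairs; columns are (P, R, F) as in the listing.
all_preds = torch.tensor([[0.950, 0.930, 0.940],
                          [0.990, 0.980, 0.985]])

# One illustrative baseline value per metric (NOT from a real baseline file).
baselines = torch.tensor([0.900, 0.880, 0.890])

# Same affine rescaling as in score(): maps the baseline to 0 and 1 to 1.
rescaled = (all_preds - baselines) / (1 - baselines)

# Same unpacking as `out = all_preds[..., 0], ...` in the listing.
P, R, F = rescaled[..., 0], rescaled[..., 1], rescaled[..., 2]
```

<p>Note the map is monotone per column, so rankings between candidates are preserved; only the scale changes.</p>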
<hr>
<p align="center">
<iframe src="https://giphy.com/embed/3nbxypT20Ulmo" width="480" height="355" class="giphy-embed" allowfullscreen=""></iframe></p>
<p align="center"><a href="https://open.spotify.com/track/1lIYP8fDGGnp91OMTUnwjV">🎶 Waqt Ki Baatein</a></p>
<hr>
<p align="center">
:wq satyajit
</p>
</div>
</body>
</html>