-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Analyze precision and recall for initial eval dataset (#150)
- Loading branch information
1 parent
2b5ae2e
commit 44253d7
Showing
4 changed files
with
806 additions
and
237 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"id": "7e86fe3f-cc7f-4f67-8f7e-be9ea7f4a121", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "542e5fe2614042e2b10e88bd496a6326", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"Text(value='question_answer_pairs.csv', description='File Name:')" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "da60c554b0654e7fab1c0d5eaae47ae8", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"BoundedIntText(value=10, description='K:', max=50, min=1)" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "5bf2e7389ac146d99de00ea9c2f26213", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"Button(description='Compute metrics', style=ButtonStyle())" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "93fd9340779f49f48ed6350475eabe77", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"Output()" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"import ipywidgets as widgets\n", | ||
"from IPython.display import display\n", | ||
"import csv\n", | ||
"from hashlib import md5\n", | ||
"\n", | ||
"from src.retrieve import retrieve_with_scores\n", | ||
"\n", | ||
"file_name_widget = widgets.Text(\n", | ||
" value='question_answer_pairs.csv',\n", | ||
" description='File Name:',\n", | ||
")\n", | ||
"\n", | ||
"load_button = widgets.Button(description=\"Compute metrics\")\n", | ||
"\n", | ||
"k_widget = widgets.BoundedIntText(\n", | ||
" value=10,\n", | ||
" min=1,\n", | ||
" max=50,\n", | ||
" step=1,\n", | ||
" description='K:',\n", | ||
" disabled=False\n", | ||
")\n", | ||
"\n", | ||
"output = widgets.Output()\n", | ||
"\n", | ||
"def load_csv(file_name):\n", | ||
" try:\n", | ||
" with open(file_name, mode='r', encoding='utf-8') as csv_file:\n", | ||
" reader = csv.DictReader(csv_file)\n", | ||
" data = [row for row in reader]\n", | ||
" return data\n", | ||
" except FileNotFoundError:\n", | ||
" return f\"Error: File '{file_name}' not found.\"\n", | ||
" except Exception as e:\n", | ||
" return f\"Error: {e}\"\n", | ||
"\n", | ||
"def on_button_click(b):\n", | ||
" with output:\n", | ||
" output.clear_output()\n", | ||
" file_name = file_name_widget.value\n", | ||
" questions = load_csv(file_name)\n", | ||
" precision, recall = compute_precision_recall(questions, k_widget.value)\n", | ||
" print(\"Precision: \", precision)\n", | ||
" print(\"Recall: \", recall)\n", | ||
"\n", | ||
"\n", | ||
"def compute_precision_recall(questions, k):\n", | ||
" precision = 0\n", | ||
" recall = 0\n", | ||
"\n", | ||
" for question in questions:\n", | ||
" results = retrieve_with_scores(question['question'], k, -1)\n", | ||
" content_hashes = [md5(r.chunk.content.encode('utf-8'), usedforsecurity=False).hexdigest() for r in results]\n", | ||
"\n", | ||
" # This calculation assumes there is exactly one expected chunk\n", | ||
" # to retrieve\n", | ||
" if question['content_hash'] in content_hashes:\n", | ||
" recall += 1\n", | ||
" precision += 1/k\n", | ||
"\n", | ||
" precision /= len(questions)\n", | ||
" recall /= len(questions)\n", | ||
" return precision, recall\n", | ||
"\n", | ||
"\n", | ||
"load_button.on_click(on_button_click)\n", | ||
"display(file_name_widget, k_widget, load_button, output)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "330e8663-505a-4838-a635-e7c37abb5d86", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.7" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.