Ita grounding analysis #64

Merged · 2 commits · Nov 29, 2024
79 changes: 79 additions & 0 deletions src/malco/analysis/ita_grounding_analysis.py
@@ -0,0 +1,79 @@
from malco.post_process.post_process_results_format import read_raw_result_yaml
from pathlib import Path
import pandas as pd
import os
# Each row of the curation table has the columns:
# PMID (str) | GPT diagnosis (str) | MONDO ID (if applicable) | rank | ita_reply (bool) |
# correct_result OMIM ID | correct_result OMIM label | correct? 0/1 (added later, in Excel)

# Correct results
file = "/Users/leonardo/git/malco/in_ita_reply/correct_results.tsv"
answers = pd.read_csv(
    file, sep="\t", header=None, names=["description", "term", "label"]
)

# Map each label to its correct term and description.
# DataFrame.to_dict() defaults to a column-oriented dict of dicts, {column: {label: value}},
# hence the cres['term'][label] lookups below.
cres = answers.set_index("label").to_dict()

# Populate the rows with two for loops, then sort alphabetically
data = []

# Load Italian replies
ita_file = Path("/Users/leonardo/git/malco/out_itanoeng/raw_results/multilingual/it/results.yaml")
ita_result = read_raw_result_yaml(ita_file)

# For the Italian replies, extract each extracted_object's label and ranked terms from the YAML
for ppkt_out in ita_result:
    extracted_object = ppkt_out.get("extracted_object")
    if extracted_object:
        label = extracted_object.get("label").replace('_it-prompt', '_en-prompt')
        terms = extracted_object.get("terms")
        if terms:
            for rank, term in enumerate(terms, start=1):
                data.append({"pubmedid": label, "term": term, "mondo_label": float('nan'),
                             "rank": rank, "ita_reply": True,
                             "correct_omim_id": cres['term'][label],
                             "correct_omim_description": cres['description'][label]})


# Load English replies
eng_file = Path("/Users/leonardo/git/malco/out_itanoeng/raw_results/multilingual/it_w_en/results.yaml")
eng_result = read_raw_result_yaml(eng_file)

# For the English replies, extract label and terms as above, and additionally look up
# each grounded MONDO ID's label in named_entities
for ppkt_out in eng_result:
    extracted_object = ppkt_out.get("extracted_object")
    if extracted_object:
        label = extracted_object.get("label").replace('_it-prompt', '_en-prompt')
        terms = extracted_object.get("terms")
        if terms:
            for rank, term in enumerate(terms, start=1):
                # Default to NaN so a failed lookup cannot reuse a stale label
                mlab = float('nan')
                if term.startswith("MONDO"):
                    for entity in ppkt_out.get("named_entities"):
                        if entity.get('id') == term:
                            mlab = entity.get('label')
                            break
                data.append({"pubmedid": label, "term": mlab, "mondo_label": term,
                             "rank": rank, "ita_reply": False,
                             "correct_omim_id": cres["term"][label],
                             "correct_omim_description": cres['description'][label]})

# Create DataFrame
column_names = [
    "PMID",
    "GPT Diagnosis",
    "MONDO ID",
    "rank",
    "ita_reply",
    "correct_OMIMid",
    "correct_OMIMlabel",
]

df = pd.DataFrame(data)
df.columns = column_names
df.sort_values(by=["PMID", "ita_reply", "rank"], inplace=True)
# df.to_excel(os.path.join(os.getcwd(), "ita_replies2curate.xlsx"))  # optional export for manual curation
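For reviewers, a minimal sketch of the record shape the two loops above assume, with fabricated IDs and labels (the real records come from read_raw_result_yaml): each entry carries an extracted_object with a label and ranked terms, and the English replies additionally carry a named_entities list mapping grounded MONDO IDs to their labels.

ppkt_out = {
    "extracted_object": {
        "label": "PMID_123_it-prompt",
        "terms": ["MONDO:0000001", "MONDO:0000002"],
    },
    "named_entities": [
        {"id": "MONDO:0000001", "label": "disease A"},
        {"id": "MONDO:0000002", "label": "disease B"},
    ],
}

# The same lookup the English-reply loop performs for a grounded term
term = ppkt_out["extracted_object"]["terms"][0]
mlab = next(
    (e["label"] for e in ppkt_out["named_entities"] if e["id"] == term),
    float("nan"),
)
assert mlab == "disease A"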
9 changes: 8 additions & 1 deletion src/malco/post_process/post_process_results_format.py
@@ -56,7 +56,14 @@ def create_standardised_results(
         )
         # terms will now ONLY contain MONDO IDs OR 'N/A'.
         # The latter should be dealt with downstream
-        terms = [i[1][0][0] for i in result]  # MONDO_ID
+        new_terms = []
+        for i in result:
+            if i[1] == [("N/A", "No grounding found")]:
+                new_terms.append(i[0])
+            else:
+                new_terms.append(i[1][0][0])
+        terms = new_terms
+        # terms = [i[1][0][0] for i in result]  # MONDO_ID
         if terms:
             # Note, the if allows for rerunning ppkts that failed due to connection issues
             # We can have multiple identical ppkts/prompts in results.yaml
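For context, a quick sketch of what the new fallback does, run on a made-up result list shaped like the grounder output this function consumes, i.e. pairs of original reply text and a list of (ID, label) groundings; the disease names here are illustrative only.

result = [
    ("sindrome di Marfan", [("MONDO:0007947", "Marfan syndrome")]),
    ("malattia inventata", [("N/A", "No grounding found")]),
]

new_terms = []
for i in result:
    if i[1] == [("N/A", "No grounding found")]:
        new_terms.append(i[0])  # keep the raw reply text instead of dropping the case
    else:
        new_terms.append(i[1][0][0])  # keep the grounded MONDO ID, as before

assert new_terms == ["MONDO:0007947", "malattia inventata"]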
23 changes: 15 additions & 8 deletions src/malco/post_process/ranking_utils.py
@@ -98,6 +98,7 @@ def compute_mrr_and_ranks(
         "n10p",
         "nf",
         "num_cases",
+        "grounding_failed",  # and no correct reply elsewhere in the differential
     ]
     rank_df = pd.DataFrame(0, index=np.arange(len(results_files)), columns=header)
@@ -143,6 +144,7 @@ def compute_mrr_and_ranks(
         )

         df.dropna(subset=["correct_term"])
+
         # Save full data frame
         full_df_path = output_dir / results_files[i].split("/")[0]
         full_df_filename = "full_df_results.tsv"
@@ -155,14 +157,17 @@ def compute_mrr_and_ranks(
         # Calculate top<n> of each rank
         rank_df.loc[i, comparing] = results_files[i].split("/")[0]

-        ppkts = df.groupby("label")[["rank", "is_correct"]]
+        ppkts = df.groupby("label")[["term", "rank", "is_correct"]]

         # for each group
         for ppkt in ppkts:
             # is there a true? ppkt is tuple ("filename", dataframe) --> ppkt[1] is a dataframe
             if not any(ppkt[1]["is_correct"]):
-                # no --> increase nf = "not found"
-                rank_df.loc[i, "nf"] += 1
+                if all(ppkt[1]["term"].str.startswith("MONDO")):
+                    # no --> increase nf = "not found"
+                    rank_df.loc[i, "nf"] += 1
+                else:
+                    rank_df.loc[i, "grounding_failed"] += 1
             else:
                 # yes --> what's it rank? It's <j>
                 jind = ppkt[1].index[ppkt[1]["is_correct"]]
@@ -204,10 +209,12 @@ def compute_mrr_and_ranks(
         writer.writerow(results_files)
         writer.writerow(mrr_scores)

+    # TODO this could be moved into an analysis script with the plotting...
     df = pd.read_csv(topn_file, delimiter="\t")
-    df["top1"] = (df["n1"]) / df["num_cases"]
-    df["top3"] = (df["n1"] + df["n2"] + df["n3"]) / df["num_cases"]
-    df["top5"] = (df["n1"] + df["n2"] + df["n3"] + df["n4"] + df["n5"]) / df["num_cases"]
+    valid_cases = df["num_cases"] - df["grounding_failed"]
+    df["top1"] = (df["n1"]) / valid_cases
+    df["top3"] = (df["n1"] + df["n2"] + df["n3"]) / valid_cases
+    df["top5"] = (df["n1"] + df["n2"] + df["n3"] + df["n4"] + df["n5"]) / valid_cases
     df["top10"] = (
         df["n1"]
         + df["n2"]
@@ -219,8 +226,8 @@ def compute_mrr_and_ranks(
         + df["n8"]
         + df["n9"]
         + df["n10"]
-    ) / df["num_cases"]
-    df["not_found"] = (df["nf"]) / df["num_cases"]
+    ) / valid_cases
+    df["not_found"] = (df["nf"]) / valid_cases

     df_aggr = pd.DataFrame()
     df_aggr = pd.melt(
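Taken together, the change splits unsolved cases into two buckets, "nf" (every term grounded, none correct) and "grounding_failed" (at least one ungrounded term, none correct), and shrinks the top-n denominator accordingly. A self-contained toy run of the same bookkeeping, with fabricated labels and terms:

import pandas as pd

df = pd.DataFrame({
    "label":      ["ppkt_a", "ppkt_a", "ppkt_b", "ppkt_b"],
    "term":       ["MONDO:0000001", "MONDO:0000002", "MONDO:0000003", "sindrome sconosciuta"],
    "rank":       [1, 2, 1, 2],
    "is_correct": [False, False, False, False],
})

nf = grounding_failed = 0
for _, group in df.groupby("label")[["term", "rank", "is_correct"]]:
    if not any(group["is_correct"]):
        if all(group["term"].str.startswith("MONDO")):
            nf += 1                # wrong everywhere, but fully grounded
        else:
            grounding_failed += 1  # some term never received a MONDO ID

num_cases = 2
valid_cases = num_cases - grounding_failed  # new denominator for top-n accuracy
assert (nf, grounding_failed, valid_cases) == (1, 1, 1)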