Commit 7e4e567
wordnet query expansion works pretty well
1 parent 91f74d2
Showing 8 changed files with 252 additions and 16 deletions.
4 files renamed without changes.
@@ -1,5 +1,9 @@
 from functools import cache
 from huggingface_hub import snapshot_download
+from typing import List
+import nltk
 
-def setup(repo_id: str, cache_dir: str):
-    snapshot_download(repo_id = repo_id, cache_dir = cache_dir)
+def setup(repo_ids: List[str], cache_dir: str):
+    for id in repo_ids:
+        snapshot_download(repo_id = id, cache_dir = cache_dir)
+    nltk.download("wordnet")
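For reference, a minimal usage sketch of the new signature (the model ID and cache directory below simply mirror the ones used in the notebook):

    setup(['bert-base-uncased'], cache_dir = 'hf_cache')  # fetches each HF snapshot, then the WordNet corpus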
@@ -0,0 +1,216 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# WordNet-based Query Expansion Playground"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The goal of WordNet-based query expansion is the same as that of BERT-based query expansion, and the strategy is almost identical; only the way similar tokens are generated differs.\n",
    "Here the candidate tokens are selected from the set of synonyms of the word to be expanded."
   ]
  },
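  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the idea (added for illustration, not executed): the candidate set for a word is just the union of the lemma names of its synsets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nltk.corpus import wordnet as wn\n",
    "\n",
    "# Every synset of 'room' contributes its lemma names as expansion candidates.\n",
    "sorted({lemma for s in wn.synsets('room') for lemma in s.lemma_names()})"
   ]
  },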
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pydash\n",
    "from nltk.wsd import lesk\n",
    "from nltk.corpus import wordnet as wn, wordnet_ic\n",
    "\n",
    "def naive_wsd(list_of_synsets, term_dis):\n",
    "    \"\"\"\n",
    "    list_of_synsets list of lists containing the synsets of each word\n",
    "    term_dis        synsets of the term to be disambiguated\n",
    "    \"\"\"\n",
    "    brown_ic = wordnet_ic.ic(\"ic-brown.dat\")\n",
    "    # Resnik similarity over the Brown information content: lower scores\n",
    "    # mean the two concepts are less likely to be associated.\n",
    "\n",
    "    sense_confidence = float('-inf')\n",
    "    disambiguated_sense = None\n",
    "\n",
    "    # Score each sense of the target by its best similarity to every other word's senses.\n",
    "    for sense_dis in term_dis:\n",
    "        confidence = 0\n",
    "        for term_other in list_of_synsets:\n",
    "            if term_dis != term_other:\n",
    "                confidence += max([sense_dis.res_similarity(sense_other, brown_ic) for sense_other in term_other])\n",
    "        if confidence > sense_confidence:\n",
    "            disambiguated_sense = sense_dis\n",
    "            sense_confidence = confidence\n",
    "\n",
    "    return disambiguated_sense, sense_confidence"
   ]
  },
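  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A toy call (added for illustration, not executed): disambiguating 'bank' in a context containing 'river' should favour one of its physical senses, e.g. the sloping-land one, over the financial institution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ctx = [wn.synsets('bank', wn.NOUN), wn.synsets('river', wn.NOUN)]\n",
    "naive_wsd(ctx, ctx[0])  # expected to favour a physical sense of 'bank' over the financial one"
   ]
  },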
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below, some experiments were made to understand the best way to get the candidates.\n",
    "An empirical test showed that Lesk WSD underperforms the naive strategy above. The next idea would have been to take the synset's hyponyms and hypernyms, but the overhead of WSD and POS tagging (needed for a more accurate WSD) is not worth the effort.\n",
    "Taking the synonyms of a word instead turns out to be a much more consistent method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lesk_synset=Synset('room.n.04')\n",
      "naive_synset=(Synset('room.n.01'), 0.5962292078977726)\n",
      "[['room'], ['room', 'way', 'elbow_room'], ['room'], ['room'], ['board', 'room']]\n",
      "Synset('room.n.01') an area within a building enclosed by walls and floor and ceiling\n",
      "Synset('room.n.02') space for movement\n",
      "Synset('room.n.03') opportunity for\n",
      "Synset('room.n.04') the people who are present in a room\n",
      "Synset('board.v.02') live and take one's meals at or in\n"
     ]
    }
   ],
   "source": [
    "TOKEN_ID = 2\n",
    "\n",
    "original_sentence_tokens = 'modern shared room near Harvard'.split()\n",
    "\n",
    "# Build a format string with a placeholder where the target token sits.\n",
    "tmp = original_sentence_tokens[:]\n",
    "tmp[TOKEN_ID] = '{}'\n",
    "original_sentence_fmt = ' '.join(tmp)\n",
    "token = original_sentence_tokens[TOKEN_ID]\n",
    "\n",
    "lesk_synset = lesk(original_sentence_tokens, token)\n",
    "print(f'{lesk_synset=}')\n",
    "\n",
    "# Keep only the tokens that lemmatize to a noun, with their noun synsets.\n",
    "nouns_synsets = (\n",
    "    pydash.chain(original_sentence_tokens)\n",
    "    .map(lambda n: wn.morphy(n, wn.NOUN))\n",
    "    .filter(lambda n: n is not None)\n",
    "    .map(lambda n: wn.synsets(n, wn.NOUN))\n",
    "    .value()\n",
    ")\n",
    "\n",
    "# After the noun filter, index 1 corresponds to 'room'.\n",
    "naive_synset = naive_wsd(nouns_synsets, nouns_synsets[1])\n",
    "print(f'{naive_synset=}')\n",
    "print([s.lemma_names() for s in wn.synsets(token)])\n",
    "\n",
    "for s in wn.synsets(token):\n",
    "    print(s, s.definition())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import BertTokenizer\n",
    "from transformers import BertModel\n",
    "import torch\n",
    "\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir = 'hf_cache')\n",
    "encoder = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True, cache_dir = 'hf_cache')\n",
    "\n",
    "def get_meaned_embeddings(sentence: str):\n",
    "    \"\"\"Embed a sentence as the mean of its BERT token embeddings.\"\"\"\n",
    "    tokens = tokenizer.tokenize(sentence)\n",
    "    input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "\n",
    "    input_ids = torch.tensor(input_ids).unsqueeze(0)\n",
    "    with torch.no_grad():\n",
    "        outputs = encoder(input_ids)\n",
    "    embedding = outputs.last_hidden_state[0]\n",
    "\n",
    "    return embedding.mean(dim = 0)"
   ]
  },
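  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quick sanity check (added for illustration, not executed): for bert-base-uncased the helper returns a single 768-dimensional vector regardless of sentence length."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_meaned_embeddings('modern shared room near Harvard').shape  # torch.Size([768])"
   ]
  },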
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['board', 'elbow_room', 'room', 'way']\n",
      "[tensor(0.9377), tensor(0.9285), tensor(1.), tensor(0.9333)]\n",
      "['room', 'board', 'way', 'elbow_room']\n"
     ]
    }
   ],
   "source": [
    "from operator import itemgetter\n",
    "\n",
    "# Candidates: unique lemma names across all synsets of the target token.\n",
    "candidates = (\n",
    "    pydash.chain([s.lemma_names() for s in wn.synsets(token)])\n",
    "    .flatten_deep()\n",
    "    .sorted_uniq()\n",
    "    .value()\n",
    ")\n",
    "\n",
    "print(candidates)\n",
    "\n",
    "original_sentence_embedding = get_meaned_embeddings(original_sentence_fmt.format(token))\n",
    "cos_sim = torch.nn.CosineSimilarity(dim = 0)\n",
    "\n",
    "# Score each candidate by the similarity of the sentence with the candidate swapped in.\n",
    "similarities = (\n",
    "    pydash.chain(candidates)\n",
    "    .map(lambda c: c.replace('_', ' '))\n",
    "    .map(lambda c: original_sentence_fmt.format(c))\n",
    "    .map(get_meaned_embeddings)\n",
    "    .map(lambda x: cos_sim(x, original_sentence_embedding))\n",
    "    .value()\n",
    ")\n",
    "print(similarities)\n",
    "\n",
    "THRESHOLD = 0.8\n",
    "SLICE = 5\n",
    "\n",
    "# Keep candidates above the threshold, best first, at most SLICE of them.\n",
    "expansions = (\n",
    "    pydash.chain(candidates)\n",
    "    .zip(similarities)\n",
    "    .filter(lambda t: t[1] > THRESHOLD)\n",
    "    .sort(key = itemgetter(1), reverse = True)\n",
    "    .map(itemgetter(0))\n",
    "    .take(SLICE)\n",
    "    .value()\n",
    ")\n",
    "\n",
    "print(expansions)"
   ]
  },
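  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting it together (added for illustration, not executed): the whole expansion step wrapped as one function, with the same logic as the cells above; the name `expand` and its defaults are ours."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def expand(sentence_tokens, token_id, threshold = 0.8, top_k = 5):\n",
    "    \"\"\"Return up to top_k WordNet synonyms of sentence_tokens[token_id]\n",
    "    whose in-context BERT similarity exceeds threshold.\"\"\"\n",
    "    tokens = sentence_tokens[:]\n",
    "    token = tokens[token_id]\n",
    "    tokens[token_id] = '{}'\n",
    "    fmt = ' '.join(tokens)\n",
    "\n",
    "    # Candidate synonyms, deduplicated, as in the cells above.\n",
    "    cands = (\n",
    "        pydash.chain([s.lemma_names() for s in wn.synsets(token)])\n",
    "        .flatten_deep()\n",
    "        .sorted_uniq()\n",
    "        .value()\n",
    "    )\n",
    "    original = get_meaned_embeddings(fmt.format(token))\n",
    "    scored = [(c, cos_sim(get_meaned_embeddings(fmt.format(c.replace('_', ' '))), original)) for c in cands]\n",
    "    scored = sorted((t for t in scored if t[1] > threshold), key = itemgetter(1), reverse = True)\n",
    "    return [c for c, _ in scored[:top_k]]\n",
    "\n",
    "expand('modern shared room near Harvard'.split(), 2)"
   ]
  }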
 ],
 "metadata": {
  "interpreter": {
   "hash": "cc551502c4709d65c35225dab174e15b6f215f4b9bca0aec7618bac23f51ade6"
  },
  "kernelspec": {
   "display_name": "Python 3.11.6 ('venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}