wordnet query expansion works pretty well
mc-cat-tty committed Jan 30, 2024
1 parent 91f74d2 commit 7e4e567
Showing 8 changed files with 252 additions and 16 deletions.
4 files renamed without changes.
8 changes: 6 additions & 2 deletions placerank/query_expansion.py
@@ -1,5 +1,9 @@
 from functools import cache
 from huggingface_hub import snapshot_download
+from typing import List
+import nltk
 
-def setup(repo_id: str, cache_dir: str):
-    snapshot_download(repo_id = repo_id, cache_dir = cache_dir)
+def setup(repo_ids: List[str], cache_dir: str):
+    for id in repo_ids:
+        snapshot_download(repo_id = id, cache_dir = cache_dir)
+    nltk.download("wordnet")
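
For context, a minimal sketch of how the updated setup might be called (assuming the placerank package is importable; the model IDs and cache directory mirror the constants added to setup.py further down):

# Illustrative call site for the new setup() signature; values mirror setup.py below.
from placerank import query_expansion

query_expansion.setup(
    ['bert-large-uncased-whole-word-masking', 'bert-base-uncased'],  # HF repos to snapshot
    'hf_cache',                                                      # local snapshot cache
)
# Besides the model snapshots, setup() now also fetches the WordNet corpus
# via nltk.download("wordnet").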
39 changes: 27 additions & 12 deletions query_expansion.ipynb → query_expansion_bert.ipynb
@@ -1,8 +1,15 @@
 {
  "cells": [
   {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# BERT-based Query Expansion Playground"
+   ]
+  },
+  {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -18,7 +25,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
@@ -45,7 +52,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -58,12 +65,12 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "With the following line we are getting all the candidates to the masked word proposed by BERT. Each substitute has a confidence level associated with the token."
+    "With the following line we get all the candidate substitutes for the masked word proposed by BERT. Each substitute has a confidence level associated with the token."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 94,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -78,7 +85,7 @@
     "original_sentence_tokens = tokenizer.tokenize(original_sentence)\n",
     "\n",
     "masked_sentence = mask_token(original_sentence_tokens, 2)\n",
-    "candidates = unmasker(masked_sentence, top_k = 20)"
+    "candidates = unmasker(masked_sentence, top_k = 50)"
    ]
   },
   {
@@ -90,14 +97,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 199,
+   "execution_count": 14,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "['room', 'house', 'residence', 'land', 'hall']\n"
+     ]
+    }
+   ],
    "source": [
     "from transformers import BertModel\n",
     "import torch\n",
     "\n",
-    "encoder = BertModel.from_pretrained('bert-large-uncased-whole-word-masking', output_hidden_states = True, cache_dir = 'hf_cache')\n",
+    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir = 'hf_cache')\n",
+    "encoder = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True, cache_dir = 'hf_cache')\n",
     "\n",
     "def get_meaned_embeddings(sentence: str):\n",
     "    tokens = tokenizer.tokenize(sentence)\n",
@@ -112,7 +128,6 @@
     "\n",
     "cos_sim = torch.nn.CosineSimilarity(dim = 0)\n",
     "\n",
-    "from functools import partial\n",
     "from operator import itemgetter\n",
     "import pydash\n",
     "\n",
@@ -130,7 +145,7 @@
     "THRESHOLD = 0.8\n",
     "SLICE = 5\n",
     "\n",
-    "synonyms = (\n",
+    "expansions = (\n",
     "    pydash.chain(candidates)\n",
     "    .map(itemgetter('token_str'))\n",
     "    .zip(similarities)\n",
@@ -141,7 +156,7 @@
     "    .value()\n",
     ")\n",
     "\n",
-    "print(synonyms)"
+    "print(expansions)"
    ]
   }
  ],
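
For orientation, here is a condensed, self-contained sketch of the pipeline these cells build up. It is a reading aid under stated assumptions, not the repository's code: the fill-mask pipeline construction and the notebook's mask_token helper live in collapsed lines, so masking is approximated here by substituting '[MASK]' into a format string, and bert_expand is an illustrative name.

# Hedged sketch of the BERT-based expansion flow shown above.
import torch
from operator import itemgetter
from transformers import BertModel, BertTokenizer, pipeline

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir='hf_cache')
encoder = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True, cache_dir='hf_cache')
unmasker = pipeline('fill-mask', model='bert-large-uncased-whole-word-masking')

def get_meaned_embeddings(sentence: str) -> torch.Tensor:
    # Crude sentence embedding: mean of the last hidden state over all tokens.
    ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentence))
    with torch.no_grad():
        out = encoder(torch.tensor(ids).unsqueeze(0))
    return out.last_hidden_state[0].mean(dim=0)

def bert_expand(sentence_fmt: str, token: str, threshold: float = 0.8, top_n: int = 5) -> list:
    # 1. Ask BERT for substitutes of the masked token.
    candidates = [c['token_str'] for c in unmasker(sentence_fmt.format('[MASK]'), top_k=50)]
    # 2. Keep candidates whose substituted sentence stays close to the original.
    cos_sim = torch.nn.CosineSimilarity(dim=0)
    reference = get_meaned_embeddings(sentence_fmt.format(token))
    scored = [(c, cos_sim(get_meaned_embeddings(sentence_fmt.format(c)), reference).item())
              for c in candidates]
    ranked = sorted((t for t in scored if t[1] > threshold), key=itemgetter(1), reverse=True)
    return [c for c, _ in ranked[:top_n]]

print(bert_expand('modern shared {} near Harvard', 'room'))  # e.g. ['room', 'house', ...]

The 0.8 cut-off and top-5 slice mirror the THRESHOLD and SLICE constants in the cells above.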
216 changes: 216 additions & 0 deletions query_expansion_wordnet.ipynb
@@ -0,0 +1,216 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# WordNet-based Query Expansion Playground"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The goal of WN-based query expansion is the same as BERT-based query expansion; furthermore, the strategy is almost the same, except for how similar tokens are generated.\n",
"In this case the candidate tokens are selected from the set of synonyms of the word that has to be expanded."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import pydash\n",
"from nltk.wsd import lesk\n",
"from nltk.corpus import wordnet as wn, wordnet_ic\n",
"\n",
"def naive_wsd(list_of_synsets, term_dis):\n",
" \"\"\"\n",
" list_of_synsets list of lists containig synsets of each word\n",
" term_dis term to be disambiguated\n",
" \"\"\"\n",
" brown_ic = wordnet_ic.ic(\"ic-brown.dat\")\n",
" # Lower res_similarity -> low probability of associated concepts\n",
"\n",
" sense_confidence = float('-inf')\n",
" disambiguated_sense = None\n",
"\n",
" for sense_dis in term_dis:\n",
" confidence = 0\n",
" for term_other in list_of_synsets:\n",
" if term_dis != term_other:\n",
" confidence += max([sense_dis.res_similarity(sense_other, brown_ic) for sense_other in term_other])\n",
" if confidence > sense_confidence:\n",
" disambiguated_sense = sense_dis\n",
" sense_confidence = confidence\n",
" \n",
" return disambiguated_sense, confidence"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below, some experiments have been made in order to understand which is the best way to get the candidates.\n",
"An empirical test showed that Lesk WSD underperforms against the naive strategy. The idea from that point would have been to take synset's hyponyms and hyperonyms, but the overhead caused by WSD and POS tagging (for a more accurate WSD) is not worth the effort.\n",
"Instead, taking the synonyms of a word seems to be a much more consistent method."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"lesk_synset=Synset('room.n.04')\n",
"naive_synset=(Synset('room.n.01'), 0.5962292078977726)\n",
"[['room'], ['room', 'way', 'elbow_room'], ['room'], ['room'], ['board', 'room']]\n",
"Synset('room.n.01') an area within a building enclosed by walls and floor and ceiling\n",
"Synset('room.n.02') space for movement\n",
"Synset('room.n.03') opportunity for\n",
"Synset('room.n.04') the people who are present in a room\n",
"Synset('board.v.02') live and take one's meals at or in\n"
]
}
],
"source": [
"TOKEN_ID = 2\n",
"\n",
"original_sentence_tokens = 'modern shared room near Harvard'.split()\n",
"\n",
"tmp = original_sentence_tokens[:]\n",
"tmp[TOKEN_ID] = '{}'\n",
"original_sentence_fmt = ' '.join(tmp)\n",
"token = original_sentence_tokens[TOKEN_ID]\n",
"\n",
"lesk_synset = lesk(original_sentence_tokens, token)\n",
"print(f'{lesk_synset=}')\n",
"\n",
"nouns_synsets = (\n",
" pydash.chain(original_sentence_tokens)\n",
" .map(lambda n: wn.morphy(n, wn.NOUN))\n",
" .filter(lambda n: n is not None)\n",
" .map(lambda n: wn.synsets(n, wn.NOUN))\n",
" .value()\n",
" )\n",
"\n",
"naive_synset = naive_wsd(nouns_synsets, nouns_synsets[1])\n",
"print(f'{naive_synset=}')\n",
"print([s.lemma_names() for s in wn.synsets(token)])\n",
"\n",
"for s in wn.synsets(token):\n",
" print(s, s.definition())"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"from transformers import BertTokenizer\n",
"from transformers import BertModel\n",
"import torch\n",
"\n",
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir = 'hf_cache')\n",
"encoder = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True, cache_dir = 'hf_cache')\n",
"\n",
"def get_meaned_embeddings(sentence: str):\n",
" tokens = tokenizer.tokenize(sentence)\n",
" input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
"\n",
" input_ids = torch.tensor(input_ids).unsqueeze(0)\n",
" with torch.no_grad():\n",
" outputs = encoder(input_ids)\n",
" embedding = outputs.last_hidden_state[0]\n",
"\n",
" return embedding.mean(dim = 0)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['board', 'elbow_room', 'room', 'way']\n",
"[tensor(0.9377), tensor(0.9285), tensor(1.), tensor(0.9333)]\n",
"['room', 'board', 'way', 'elbow_room']\n"
]
}
],
"source": [
"from operator import itemgetter\n",
"\n",
"candidates = (\n",
" pydash.chain([s.lemma_names() for s in wn.synsets(token)])\n",
" .flatten_deep()\n",
" .sorted_uniq()\n",
" .value()\n",
")\n",
"\n",
"print(candidates)\n",
"\n",
"original_sentence_embedding = get_meaned_embeddings(original_sentence_fmt.format(token))\n",
"cos_sim = torch.nn.CosineSimilarity(dim = 0)\n",
"\n",
"similarities = (\n",
" pydash.chain(candidates)\n",
" .map(lambda c: c.replace('_', ' '))\n",
" .map(lambda c: original_sentence_fmt.format(c))\n",
" .map(get_meaned_embeddings)\n",
" .map(lambda x: cos_sim(x, original_sentence_embedding))\n",
" .value()\n",
")\n",
"print(similarities)\n",
"\n",
"THRESHOLD = 0.8\n",
"SLICE = 5\n",
"\n",
"expansions = (\n",
" pydash.chain(candidates)\n",
" .zip(similarities)\n",
" .filter(lambda t: t[1] > THRESHOLD)\n",
" .sort(key = itemgetter(1), reverse = True)\n",
" .map(itemgetter(0))\n",
" .take(SLICE)\n",
" .value()\n",
")\n",
"\n",
"print(expansions)"
]
}
],
"metadata": {
"interpreter": {
"hash": "cc551502c4709d65c35225dab174e15b6f215f4b9bca0aec7618bac23f51ade6"
},
"kernelspec": {
"display_name": "Python 3.11.6 ('venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
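
Pulling the cells of this notebook together, here is a minimal end-to-end sketch of the WordNet variant: candidates are the lemma names of every synset of the token, then the same BERT cosine-similarity filter ranks them. This is an illustrative reconstruction (wordnet_expand is not a name from the repository), not the notebook's literal code.

# Hedged sketch of the WordNet-based expansion shown above.
import pydash
import torch
from operator import itemgetter
from nltk.corpus import wordnet as wn
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir='hf_cache')
encoder = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True, cache_dir='hf_cache')

def get_meaned_embeddings(sentence: str) -> torch.Tensor:
    # Same crude mean-pooled sentence embedding as in the notebook.
    ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentence))
    with torch.no_grad():
        return encoder(torch.tensor(ids).unsqueeze(0)).last_hidden_state[0].mean(dim=0)

def wordnet_expand(sentence_fmt: str, token: str, threshold: float = 0.8, top_n: int = 5) -> list:
    # Candidates: every lemma name of every synset of the token, deduplicated.
    candidates = (
        pydash.chain([s.lemma_names() for s in wn.synsets(token)])
        .flatten_deep()
        .sorted_uniq()
        .value()
    )
    cos_sim = torch.nn.CosineSimilarity(dim=0)
    reference = get_meaned_embeddings(sentence_fmt.format(token))
    # Rank candidates by how close the substituted sentence stays to the original.
    scored = [(c, cos_sim(get_meaned_embeddings(sentence_fmt.format(c.replace('_', ' '))), reference).item())
              for c in candidates]
    ranked = sorted((t for t in scored if t[1] > threshold), key=itemgetter(1), reverse=True)
    return [c for c, _ in ranked[:top_n]]

print(wordnet_expand('modern shared {} near Harvard', 'room'))  # notebook output: ['room', 'board', 'way', 'elbow_room']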
5 changes: 3 additions & 2 deletions setup.py
@@ -7,13 +7,14 @@
 DATASET_URL = 'http://data.insideairbnb.com/united-states/ny/new-york-city/2024-01-05/data/listings.csv.gz'
 INDEX_DIR = 'index/'
 DATASET_CACHE_FILE = "datasets/listings.csv"
-HF_MODEL = 'bert-large-uncased-whole-word-masking'
+HF_MODEL_MASKING = 'bert-large-uncased-whole-word-masking'
+HF_MODEL_ENCODING = 'bert-base-uncased'
 HF_CACHE = 'hf_cache'
 
 def main():
     preprocessing.setup()
     dataset.populate_index(INDEX_DIR, DATASET_CACHE_FILE, DATASET_URL)
-    query_expansion.setup(HF_MODEL, HF_CACHE)
+    query_expansion.setup([HF_MODEL_MASKING, HF_MODEL_ENCODING], HF_CACHE)
 
 if __name__ == "__main__":
     main()
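
As a quick sanity check after running setup, one could verify that both snapshots landed in the cache. A hypothetical sketch, assuming huggingface_hub's standard models--<repo-id> directory naming under cache_dir:

# Hypothetical post-setup check; assumes huggingface_hub's standard
# 'models--<repo-id>' directory naming inside the chosen cache_dir.
import os

HF_CACHE = 'hf_cache'
for repo in ('bert-large-uncased-whole-word-masking', 'bert-base-uncased'):
    path = os.path.join(HF_CACHE, 'models--' + repo.replace('/', '--'))
    print(repo, '->', 'cached' if os.path.isdir(path) else 'missing')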
