diff --git a/BERTembeddings.ipynb b/notebooks/BERTembeddings.ipynb similarity index 100% rename from BERTembeddings.ipynb rename to notebooks/BERTembeddings.ipynb diff --git a/finetuneBERT.ipynb b/notebooks/finetuneBERT.ipynb similarity index 100% rename from finetuneBERT.ipynb rename to notebooks/finetuneBERT.ipynb diff --git a/goEmotions.ipynb b/notebooks/goEmotions.ipynb similarity index 100% rename from goEmotions.ipynb rename to notebooks/goEmotions.ipynb diff --git a/roberta-finetuned-50split-dataset.ipynb b/notebooks/roberta-finetuned-50split-dataset.ipynb similarity index 100% rename from roberta-finetuned-50split-dataset.ipynb rename to notebooks/roberta-finetuned-50split-dataset.ipynb diff --git a/placerank/query_expansion.py b/placerank/query_expansion.py index a8c7073..89dad31 100644 --- a/placerank/query_expansion.py +++ b/placerank/query_expansion.py @@ -1,5 +1,9 @@ from functools import cache from huggingface_hub import snapshot_download +from typing import List +import nltk -def setup(repo_id: str, cache_dir: str): - snapshot_download(repo_id = repo_id, cache_dir = cache_dir) +def setup(repo_ids: List[str], cache_dir: str): + for repo_id in repo_ids: + snapshot_download(repo_id = repo_id, cache_dir = cache_dir) + nltk.download("wordnet") diff --git a/query_expansion.ipynb b/query_expansion_bert.ipynb similarity index 85% rename from query_expansion.ipynb rename to query_expansion_bert.ipynb index f0b4c27..8670db4 100644 --- a/query_expansion.ipynb +++ b/query_expansion_bert.ipynb @@ -1,8 +1,15 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# BERT-based Query Expansion Playground" + ] + }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -18,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -45,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -58,12 +65,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "With the following line we are getting all the candidates to the masked word proposed by BERT. Each substitute has a confidence level associated with the token." + "With the following line we get all the candidate substitutions that BERT proposes for the masked word. Each substitute has a confidence level associated with the token."
] }, { "cell_type": "code", - "execution_count": 94, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -78,7 +85,7 @@ "original_sentence_tokens = tokenizer.tokenize(original_sentence)\n", "\n", "masked_sentence = mask_token(original_sentence_tokens, 2)\n", - "candidates = unmasker(masked_sentence, top_k = 20)" + "candidates = unmasker(masked_sentence, top_k = 50)" ] }, { @@ -90,14 +97,23 @@ }, { "cell_type": "code", - "execution_count": 199, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['room', 'house', 'residence', 'land', 'hall']\n" + ] + } + ], "source": [ "from transformers import BertModel\n", "import torch\n", "\n", - "encoder = BertModel.from_pretrained('bert-large-uncased-whole-word-masking', output_hidden_states = True, cache_dir = 'hf_cache')\n", + "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir = 'hf_cache')\n", + "encoder = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True, cache_dir = 'hf_cache')\n", "\n", "def get_meaned_embeddings(sentence: str):\n", " tokens = tokenizer.tokenize(sentence)\n", @@ -112,7 +128,6 @@ "\n", "cos_sim = torch.nn.CosineSimilarity(dim = 0)\n", "\n", - "from functools import partial\n", "from operator import itemgetter\n", "import pydash\n", "\n", @@ -130,7 +145,7 @@ "THRESHOLD = 0.8\n", "SLICE = 5\n", "\n", - "synonyms = (\n", + "expansions = (\n", " pydash.chain(candidates)\n", " .map(itemgetter('token_str'))\n", " .zip(similarities)\n", @@ -141,7 +156,7 @@ " .value()\n", ")\n", "\n", - "print(synonyms)" + "print(expansions)" ] } ], diff --git a/query_expansion_wordnet.ipynb b/query_expansion_wordnet.ipynb new file mode 100644 index 0000000..9c12e66 --- /dev/null +++ b/query_expansion_wordnet.ipynb @@ -0,0 +1,216 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# WordNet-based Query Expansion Playground" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The goal of WN-based query expansion is the same as BERT-based query expansion; furthermore, the strategy is almost the same, except for how similar tokens are generated.\n", + "In this case the candidate tokens are selected from the set of synonyms of the word that has to be expanded." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import pydash\n", + "from nltk.wsd import lesk\n", + "from nltk.corpus import wordnet as wn, wordnet_ic\n", + "\n", + "def naive_wsd(list_of_synsets, term_dis):\n", + " \"\"\"\n", + " list_of_synsets list of lists containig synsets of each word\n", + " term_dis term to be disambiguated\n", + " \"\"\"\n", + " brown_ic = wordnet_ic.ic(\"ic-brown.dat\")\n", + " # Lower res_similarity -> low probability of associated concepts\n", + "\n", + " sense_confidence = float('-inf')\n", + " disambiguated_sense = None\n", + "\n", + " for sense_dis in term_dis:\n", + " confidence = 0\n", + " for term_other in list_of_synsets:\n", + " if term_dis != term_other:\n", + " confidence += max([sense_dis.res_similarity(sense_other, brown_ic) for sense_other in term_other])\n", + " if confidence > sense_confidence:\n", + " disambiguated_sense = sense_dis\n", + " sense_confidence = confidence\n", + " \n", + " return disambiguated_sense, confidence" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Below, some experiments have been made in order to understand which is the best way to get the candidates.\n", + "An empirical test showed that Lesk WSD underperforms against the naive strategy. The idea from that point would have been to take synset's hyponyms and hyperonyms, but the overhead caused by WSD and POS tagging (for a more accurate WSD) is not worth the effort.\n", + "Instead, taking the synonyms of a word seems to be a much more consistent method." + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "lesk_synset=Synset('room.n.04')\n", + "naive_synset=(Synset('room.n.01'), 0.5962292078977726)\n", + "[['room'], ['room', 'way', 'elbow_room'], ['room'], ['room'], ['board', 'room']]\n", + "Synset('room.n.01') an area within a building enclosed by walls and floor and ceiling\n", + "Synset('room.n.02') space for movement\n", + "Synset('room.n.03') opportunity for\n", + "Synset('room.n.04') the people who are present in a room\n", + "Synset('board.v.02') live and take one's meals at or in\n" + ] + } + ], + "source": [ + "TOKEN_ID = 2\n", + "\n", + "original_sentence_tokens = 'modern shared room near Harvard'.split()\n", + "\n", + "tmp = original_sentence_tokens[:]\n", + "tmp[TOKEN_ID] = '{}'\n", + "original_sentence_fmt = ' '.join(tmp)\n", + "token = original_sentence_tokens[TOKEN_ID]\n", + "\n", + "lesk_synset = lesk(original_sentence_tokens, token)\n", + "print(f'{lesk_synset=}')\n", + "\n", + "nouns_synsets = (\n", + " pydash.chain(original_sentence_tokens)\n", + " .map(lambda n: wn.morphy(n, wn.NOUN))\n", + " .filter(lambda n: n is not None)\n", + " .map(lambda n: wn.synsets(n, wn.NOUN))\n", + " .value()\n", + " )\n", + "\n", + "naive_synset = naive_wsd(nouns_synsets, nouns_synsets[1])\n", + "print(f'{naive_synset=}')\n", + "print([s.lemma_names() for s in wn.synsets(token)])\n", + "\n", + "for s in wn.synsets(token):\n", + " print(s, s.definition())" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import BertTokenizer\n", + "from transformers import BertModel\n", + "import torch\n", + "\n", + "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir = 'hf_cache')\n", + "encoder = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True, cache_dir = 
'hf_cache')\n", + "\n", + "def get_meaned_embeddings(sentence: str):\n", + " tokens = tokenizer.tokenize(sentence)\n", + " input_ids = tokenizer.convert_tokens_to_ids(tokens)\n", + "\n", + " input_ids = torch.tensor(input_ids).unsqueeze(0)\n", + " with torch.no_grad():\n", + " outputs = encoder(input_ids)\n", + " embedding = outputs.last_hidden_state[0]\n", + "\n", + " return embedding.mean(dim = 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['board', 'elbow_room', 'room', 'way']\n", + "[tensor(0.9377), tensor(0.9285), tensor(1.), tensor(0.9333)]\n", + "['room', 'board', 'way', 'elbow_room']\n" + ] + } + ], + "source": [ + "from operator import itemgetter\n", + "\n", + "candidates = (\n", + " pydash.chain([s.lemma_names() for s in wn.synsets(token)])\n", + " .flatten_deep()\n", + " .sorted_uniq()\n", + " .value()\n", + ")\n", + "\n", + "print(candidates)\n", + "\n", + "original_sentence_embedding = get_meaned_embeddings(original_sentence_fmt.format(token))\n", + "cos_sim = torch.nn.CosineSimilarity(dim = 0)\n", + "\n", + "similarities = (\n", + " pydash.chain(candidates)\n", + " .map(lambda c: c.replace('_', ' '))\n", + " .map(lambda c: original_sentence_fmt.format(c))\n", + " .map(get_meaned_embeddings)\n", + " .map(lambda x: cos_sim(x, original_sentence_embedding))\n", + " .value()\n", + ")\n", + "print(similarities)\n", + "\n", + "THRESHOLD = 0.8\n", + "SLICE = 5\n", + "\n", + "expansions = (\n", + " pydash.chain(candidates)\n", + " .zip(similarities)\n", + " .filter(lambda t: t[1] > THRESHOLD)\n", + " .sort(key = itemgetter(1), reverse = True)\n", + " .map(itemgetter(0))\n", + " .take(SLICE)\n", + " .value()\n", + ")\n", + "\n", + "print(expansions)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "cc551502c4709d65c35225dab174e15b6f215f4b9bca0aec7618bac23f51ade6" + }, + "kernelspec": { + "display_name": "Python 3.11.6 ('venv': venv)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/setup.py b/setup.py index 3b9af55..09a5496 100644 --- a/setup.py +++ b/setup.py @@ -7,13 +7,14 @@ DATASET_URL = 'http://data.insideairbnb.com/united-states/ny/new-york-city/2024-01-05/data/listings.csv.gz' INDEX_DIR = 'index/' DATASET_CACHE_FILE = "datasets/listings.csv" -HF_MODEL = 'bert-large-uncased-whole-word-masking' +HF_MODEL_MASKING = 'bert-large-uncased-whole-word-masking' +HF_MODEL_ENCODING = 'bert-base-uncased' HF_CACHE = 'hf_cache' def main(): preprocessing.setup() dataset.populate_index(INDEX_DIR, DATASET_CACHE_FILE, DATASET_URL) - query_expansion.setup(HF_MODEL, HF_CACHE) + query_expansion.setup([HF_MODEL_MASKING, HF_MODEL_ENCODING], HF_CACHE) if __name__ == "__main__": main() \ No newline at end of file