Skip to content

Commit

Permalink
query expansion
Browse files Browse the repository at this point in the history
  • Loading branch information
mc-cat-tty committed Jan 29, 2024
1 parent 95ad476 commit 1b2d198
Showing 1 changed file with 33 additions and 29 deletions.
62 changes: 33 additions & 29 deletions query_expansion.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
},
{
"cell_type": "code",
"execution_count": 90,
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -90,58 +90,62 @@
},
{
"cell_type": "code",
"execution_count": 91,
"execution_count": 195,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0.9997) modern shared is near harvard.\n",
"tensor(0.9992) modern shared campus near harvard.\n",
"tensor(0.9712) modern shared houses near harvard.\n",
"tensor(0.9996) modern shared property near harvard.\n",
"tensor(0.9993) modern shared buildings near harvard.\n",
"tensor(0.9992) modern shared building near harvard.\n",
"tensor(0.9996) modern shared house near harvard.\n",
"tensor(0.9990) modern shared residence near harvard.\n",
"tensor(0.9984) modern shared, near harvard.\n",
"tensor(0.9996) modern shared was near harvard.\n",
"tensor(0.9990) modern shared it near harvard.\n",
"tensor(0.9704) modern shared housing near harvard.\n",
"tensor(0.9995) modern shared located near harvard.\n",
"tensor(0.9802) modern shared space near harvard.\n",
"tensor(0.9996) modern shared apartments near harvard.\n",
"tensor(0.9996) modern shared lived near harvard.\n",
"tensor(0.9994) modern shared school near harvard.\n",
"tensor(0.9997) modern shared history near harvard.\n",
"tensor(0.9999) modern shared lives near harvard.\n",
"tensor(0.9997) modern shared rooms near harvard.\n"
"['house', 'property', 'is', 'building', 'residence']\n"
]
}
],
"source": [
"from transformers import BertModel\n",
"import torch\n",
"\n",
"embeddings_model = BertModel.from_pretrained('bert-large-uncased-whole-word-masking', cache_dir = 'hf_cache')\n",
"encoder = BertModel.from_pretrained('bert-large-uncased-whole-word-masking', output_hidden_states = True, cache_dir = 'hf_cache')\n",
"\n",
"def get_meaned_embeddings(sentence: str):\n",
" tokens = tokenizer.tokenize(sentence)\n",
" input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
"\n",
" input_ids = torch.tensor(input_ids).unsqueeze(0)\n",
" with torch.no_grad():\n",
" outputs = embeddings_model(input_ids)\n",
" embeddings = outputs.last_hidden_state[0]\n",
" outputs = encoder(input_ids)\n",
" embedding = outputs.last_hidden_state[0]\n",
"\n",
" return embeddings.mean(1)\n",
" return embedding.mean(dim = 0)\n",
"\n",
"cos_sim = torch.nn.CosineSimilarity(dim = 0)\n",
"\n",
"for candidate in candidates:\n",
" sim = cos_sim(get_meaned_embeddings(candidate['sequence']), get_meaned_embeddings(original_sentence))\n",
" print(sim, candidate['sequence'])"
"from functools import partial\n",
"from operator import itemgetter\n",
"import pydash\n",
"\n",
"original_sentence_embedding = get_meaned_embeddings(original_sentence)\n",
"\n",
"\n",
"similarities = (\n",
" pydash.chain(candidates)\n",
" .map(itemgetter('sequence')) # Get complete sentence\n",
" .map(get_meaned_embeddings) # Get context vectors\n",
" .map(lambda x: cos_sim(x, original_sentence_embedding)) # Compute the similarity\n",
" .value()\n",
")\n",
"\n",
"synonyms = (\n",
" pydash.chain(candidates)\n",
" .map(itemgetter('token_str'))\n",
" .zip(similarities)\n",
" .sort(key = itemgetter(1), reverse = True)\n",
" .map(itemgetter(0))\n",
" .take(5)\n",
" .value()\n",
")\n",
"\n",
"print(synonyms)"
]
}
],
Expand Down

0 comments on commit 1b2d198

Please sign in to comment.