Commit
Merge pull request #13 from gomate-community/pipeline
features@add DenseRetriever
Showing 5 changed files with 192 additions and 5 deletions.
@@ -0,0 +1 @@
https://github.com/chen700564/RGB/blob/master/data/zh_refine.json
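The link above is the source of data/zh_refine.json, the Chinese evaluation set from the RGB benchmark repository. Judging from how the scripts below consume it, the file is in JSON Lines format and each record carries at least list-valued `positive` and `negative` passage fields; that schema is inferred from this diff, not from upstream documentation. A minimal sketch for inspecting one record:

import json

# Hypothetical path; adjust to wherever zh_refine.json was downloaded.
with open('data/zh_refine.json', encoding='utf-8') as f:
    record = json.loads(f.readline())
print(record.keys())        # expect 'positive' and 'negative' among the fields
print(record['positive'])   # apparently a list of supporting passages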
@@ -0,0 +1,88 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
# https://deepnote.com/blog/semantic-search-using-faiss-and-mpnet
@author: yanqiangmiffy
@contact: [email protected]
@license: Apache Licence
@time: 2024/6/5 22:37
"""
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import faiss
import numpy as np
import os
from tqdm import tqdm

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


class SemanticEmbedding:

    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    # Mean pooling - take the attention mask into account for correct averaging
    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def get_embedding(self, sentences):
        # Tokenize sentences
        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        # Perform pooling
        sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
        # Normalize embeddings so inner-product search ranks by cosine similarity
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.detach().numpy()


class FaissIdx:

    def __init__(self, model, dim=768):
        self.index = faiss.IndexFlatIP(dim)
        # Maintain a mapping from FAISS ids to the original document text
        self.doc_map = dict()
        self.model = model
        self.ctr = 0

    def add_doc(self, document_text):
        self.index.add(self.model.get_embedding(document_text))
        self.doc_map[self.ctr] = document_text  # store the original document text
        self.ctr += 1

    def search_doc(self, query, k=3):
        D, I = self.index.search(self.model.get_embedding(query), k)
        return [{self.doc_map[idx]: score} for idx, score in zip(I[0], D[0]) if idx in self.doc_map]


if __name__ == '__main__':
    model = SemanticEmbedding(r'I:\pretrained_models\bert\english\paraphrase-multilingual-mpnet-base-v2')
    a = model.get_embedding("我喜欢打篮球")  # "I like playing basketball"
    print(a)
    print(a.shape)

    index = FaissIdx(model)
    index.add_doc("笔记本电脑")  # "laptop"
    index.add_doc("医生的办公室")  # "doctor's office"
    result = index.search_doc("个人电脑")  # "personal computer"
    print(result)

    # Load the test documents
    data = pd.read_json('../../data/zh_refine.json', lines=True)[:50]
    print(data)
    print(data.columns)

    for documents in tqdm(data['positive'], total=len(data)):
        for document in documents:
            index.add_doc(document)

    for documents in tqdm(data['negative'], total=len(data)):
        for document in documents:
            index.add_doc(document)

    result = index.search_doc("2022年特斯拉交付量")  # "Tesla's 2022 deliveries"
    print(result)
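Because get_embedding L2-normalizes its output, the scores IndexFlatIP returns are plain cosine similarities. A minimal standalone check of that equivalence (a hypothetical snippet using random vectors, not part of this commit):

import faiss
import numpy as np

vecs = np.random.rand(5, 768).astype('float32')
faiss.normalize_L2(vecs)      # in-place L2 normalization, mirroring F.normalize above
index = faiss.IndexFlatIP(768)
index.add(vecs)
D, I = index.search(vecs[:1].copy(), 3)
print(D[0])                   # D[0][0] ~= 1.0: a vector's cosine similarity with itself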
@@ -0,0 +1,45 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
@author: yanqiangmiffy
@contact: [email protected]
@license: Apache Licence
@time: 2024/6/5 22:36
"""
import json

from gomate.modules.retrieval.embedding import SBertEmbeddingModel
from gomate.modules.retrieval.faiss_retriever import FaissRetriever, FaissRetrieverConfig

if __name__ == '__main__':
    from transformers import AutoTokenizer

    embedding_model_path = "/home/test/pretrained_models/bge-large-zh-v1.5"
    embedding_model = SBertEmbeddingModel(embedding_model_path)
    tokenizer = AutoTokenizer.from_pretrained(embedding_model_path)
    retriever_config = FaissRetrieverConfig(
        max_tokens=100,
        max_context_tokens=3500,
        use_top_k=True,
        embedding_model=embedding_model,
        top_k=5,
        tokenizer=tokenizer,
        embedding_model_string="bge-large-zh-v1.5",
        index_path="faiss_index.bin",
        rebuild_index=True
    )

    faiss_retriever = FaissRetriever(config=retriever_config)

    # Collect positive and negative passages from the JSON Lines test set
    documents = []
    with open('/home/test/codes/GoMate/data/zh_refine.json', 'r', encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            documents.extend(data['positive'])
            documents.extend(data['negative'])
    print(len(documents))
    faiss_retriever.build_from_texts(documents[:200])

    # "Who was the chief director of the 2022 Winter Olympics opening ceremony?"
    contexts = faiss_retriever.retrieve("2022年冬奥会开幕式总导演是谁")
    print(contexts)
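The config's index_path and rebuild_index options suggest FaissRetriever can persist its index between runs instead of re-embedding the corpus each time; the implementation is not shown in this diff. For reference, a generic sketch of the faiss persistence calls such a feature typically rests on (not GoMate's actual code):

import faiss

index = faiss.IndexFlatIP(1024)                # bge-large-zh-v1.5 embeddings are 1024-dimensional
# ... index.add(embeddings) after encoding the corpus ...
faiss.write_index(index, "faiss_index.bin")    # serialize the index to disk
index = faiss.read_index("faiss_index.bin")    # later: reload instead of rebuilding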
@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
@author: yanqiangmiffy
@contact: [email protected]
@license: Apache Licence
@time: 2024/6/5 23:07
"""
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import faiss
import numpy as np
import os
from tqdm import tqdm

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


class SemanticEmbedding:

    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    # Mean pooling - take the attention mask into account for correct averaging
    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def get_embedding(self, sentences):
        # Tokenize sentences
        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        # Perform pooling
        sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
        # Normalize embeddings so inner-product search ranks by cosine similarity
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.detach().numpy()


class FaissIdx:

    def __init__(self, model, dim=768):
        self.index = faiss.IndexFlatIP(dim)
        # Maintain a mapping from FAISS ids to the original document text
        self.doc_map = dict()
        self.model = model
        self.ctr = 0

    def add_doc(self, document_text):
        self.index.add(self.model.get_embedding(document_text))
        self.doc_map[self.ctr] = document_text  # store the original document text
        self.ctr += 1

    def search_doc(self, query, k=3):
        D, I = self.index.search(self.model.get_embedding(query), k)
        return [{self.doc_map[idx]: score} for idx, score in zip(I[0], D[0]) if idx in self.doc_map]
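Note the `if idx in self.doc_map` guard in search_doc: when k exceeds the number of vectors stored, FAISS pads the missing hits with id -1, and the guard silently drops them. A standalone illustration of that padding (hypothetical snippet, not part of this commit):

import faiss
import numpy as np

index = faiss.IndexFlatIP(4)
index.add(np.eye(1, 4, dtype='float32'))   # store a single 4-d vector
D, I = index.search(np.eye(1, 4, dtype='float32'), 3)
print(I)                                   # [[ 0 -1 -1]]: missing hits padded with -1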