
Commit

Merge pull request #13 from gomate-community/pipeline
features@add DenseRetriever
yanqiangmiffy authored Jun 5, 2024
2 parents b2cdd86 + 2031e05 commit 78896d6
Showing 5 changed files with 192 additions and 5 deletions.
1 change: 1 addition & 0 deletions data/README.md
@@ -0,0 +1 @@
https://github.com/chen700564/RGB/blob/master/data/zh_refine.json
88 changes: 88 additions & 0 deletions examples/retrievers/faiss_example.py
@@ -0,0 +1,88 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
# https://deepnote.com/blog/semantic-search-using-faiss-and-mpnet
@author: yanqiangmiffy
@contact:[email protected]
@license: Apache Licence
@time: 2024/6/5 22:37
"""
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import faiss
import numpy as np
import os
from tqdm import tqdm
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


class SemanticEmbedding:

def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name)

# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(self, model_output, attention_mask):
token_embeddings = model_output[0] # First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def get_embedding(self, sentences):
# Tokenize sentences
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
model_output = self.model(**encoded_input)
# Perform pooling
sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])

# Normalize embeddings
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
return sentence_embeddings.detach().numpy()


class FaissIdx:

def __init__(self, model, dim=768):
self.index = faiss.IndexFlatIP(dim)
# Maintaining the document data
self.doc_map = dict()
self.model = model
self.ctr = 0

def add_doc(self, document_text):
self.index.add(self.model.get_embedding(document_text))
self.doc_map[self.ctr] = document_text # store the original document text
self.ctr += 1

def search_doc(self, query, k=3):
D, I = self.index.search(self.model.get_embedding(query), k)
return [{self.doc_map[idx]: score} for idx, score in zip(I[0], D[0]) if idx in self.doc_map]


if __name__ == '__main__':
model = SemanticEmbedding(r'I:\pretrained_models\bert\english\paraphrase-multilingual-mpnet-base-v2')
a = model.get_embedding("我喜欢打篮球")
print(a)
print(a.shape)

index = FaissIdx(model)
index.add_doc("笔记本电脑")
index.add_doc("医生的办公室")
result=index.search_doc("个人电脑")
print(result)


    # Load the test documents
data=pd.read_json('../../data/zh_refine.json', lines=True)[:50]
print(data)
print(data.columns)

for documents in tqdm(data['positive'],total=len(data)):
for document in documents:
index.add_doc(document)

for documents in tqdm(data['negative'],total=len(data)):
for document in documents:
index.add_doc(document)

result=index.search_doc("2022年特斯拉交付量")
print(result)
45 changes: 45 additions & 0 deletions examples/retrievers/faissretriever_example.py
@@ -0,0 +1,45 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
@author: yanqiangmiffy
@contact:[email protected]
@license: Apache Licence
@time: 2024/6/5 22:36
"""
import json

from gomate.modules.retrieval.embedding import SBertEmbeddingModel
from gomate.modules.retrieval.faiss_retriever import FaissRetriever, FaissRetrieverConfig

if __name__ == '__main__':
from transformers import AutoTokenizer

embedding_model_path = "/home/test/pretrained_models/bge-large-zh-v1.5"
embedding_model = SBertEmbeddingModel(embedding_model_path)
tokenizer = AutoTokenizer.from_pretrained(embedding_model_path)
retriever_config = FaissRetrieverConfig(
max_tokens=100,
max_context_tokens=3500,
use_top_k=True,
embedding_model=embedding_model,
top_k=5,
tokenizer=tokenizer,
embedding_model_string="bge-large-zh-v1.5",
index_path="faiss_index.bin",
rebuild_index=True
)

faiss_retriever = FaissRetriever(config=retriever_config)

documents = []
with open('/home/test/codes/GoMate/data/zh_refine.json', 'r', encoding="utf-8") as f:
for line in f.readlines():
data = json.loads(line)
documents.extend(data['positive'])
documents.extend(data['negative'])
print(len(documents))
faiss_retriever.build_from_texts(documents[:200])

contexts = faiss_retriever.retrieve("2022年冬奥会开幕式总导演是谁")
print(contexts)
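
A note on the config above: index_path points at faiss_index.bin and rebuild_index=True forces a fresh build. For orientation only, here is a minimal sketch of how FAISS index persistence is typically done with the standard faiss API (write_index / read_index); it illustrates the general pattern, not the internals of FaissRetriever. The dimension 1024 matches bge-large-zh-v1.5, and the random embeddings are stand-ins for real document vectors.

import os

import faiss
import numpy as np

dim = 1024                      # embedding size of bge-large-zh-v1.5
index_path = "faiss_index.bin"

if os.path.exists(index_path):
    # Reuse the index written by a previous run.
    index = faiss.read_index(index_path)
else:
    # Build a fresh inner-product index and persist it for later runs.
    index = faiss.IndexFlatIP(dim)
    embeddings = np.random.rand(200, dim).astype("float32")  # stand-in for real document embeddings
    faiss.normalize_L2(embeddings)  # normalize so inner product behaves like cosine similarity
    index.add(embeddings)
    faiss.write_index(index, index_path)

print(index.ntotal)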
58 changes: 58 additions & 0 deletions gomate/modules/retrieval/dense_retriever.py
@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
@author: yanqiangmiffy
@contact:[email protected]
@license: Apache Licence
@time: 2024/6/5 23:07
"""
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import faiss
import numpy as np
import os
from tqdm import tqdm
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


class SemanticEmbedding:

def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name)

# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(self, model_output, attention_mask):
token_embeddings = model_output[0] # First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def get_embedding(self, sentences):
# Tokenize sentences
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
model_output = self.model(**encoded_input)
# Perform pooling
sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])

# Normalize embeddings
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
return sentence_embeddings.detach().numpy()


class FaissIdx:

def __init__(self, model, dim=768):
self.index = faiss.IndexFlatIP(dim)
# Maintaining the document data
self.doc_map = dict()
self.model = model
self.ctr = 0

def add_doc(self, document_text):
self.index.add(self.model.get_embedding(document_text))
self.doc_map[self.ctr] = document_text # store the original document text
self.ctr += 1

def search_doc(self, query, k=3):
D, I = self.index.search(self.model.get_embedding(query), k)
return [{self.doc_map[idx]: score} for idx, score in zip(I[0], D[0]) if idx in self.doc_map]
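
FaissIdx.add_doc embeds and indexes one document per call. Since get_embedding already accepts a list of sentences and IndexFlatIP.add takes an (n, dim) array, a batched variant can reduce per-call overhead when loading many documents. A rough sketch follows; add_docs is a hypothetical helper, not part of this commit:

    def add_docs(self, document_texts, batch_size=32):
        # Hypothetical batched counterpart to add_doc: embed several documents
        # per forward pass and add them to the index in a single call.
        for start in range(0, len(document_texts), batch_size):
            batch = document_texts[start:start + batch_size]
            embeddings = self.model.get_embedding(batch)  # shape: (len(batch), dim)
            self.index.add(embeddings)
            for text in batch:
                self.doc_map[self.ctr] = text  # keep the original document text
                self.ctr += 1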
5 changes: 0 additions & 5 deletions gomate/modules/retrieval/faiss_retriever.py
@@ -166,11 +166,6 @@ def build_from_texts(self, documents):
if self.index is None and self.all_embeddings:
self.index = faiss.IndexFlatIP(self.all_embeddings[0].shape[1])

# first_shape = self.all_embeddings[0].shape
# for embedding in self.all_embeddings:
# if embedding.shape != first_shape:
# print("Found an embedding with a different shape:", embedding.shape)

self.all_embeddings = np.vstack(self.all_embeddings)
print(self.all_embeddings.shape)
print(len(self.context_chunks))
