From 2031e05a0526671dab334faccf09c9b34d0e1589 Mon Sep 17 00:00:00 2001
From: yanqiangmiffy <1185918903@qq.com>
Date: Wed, 5 Jun 2024 23:42:45 +0800
Subject: [PATCH] features@add DenseRetriever

---
 data/README.md                                |  1 +
 examples/retrievers/faiss_example.py          | 88 +++++++++++++++++++
 examples/retrievers/faissretriever_example.py | 45 ++++++++++
 gomate/modules/retrieval/dense_retriever.py   | 58 ++++++++++++
 gomate/modules/retrieval/faiss_retriever.py   |  5 --
 5 files changed, 192 insertions(+), 5 deletions(-)
 create mode 100644 data/README.md
 create mode 100644 examples/retrievers/faiss_example.py
 create mode 100644 examples/retrievers/faissretriever_example.py
 create mode 100644 gomate/modules/retrieval/dense_retriever.py

diff --git a/data/README.md b/data/README.md
new file mode 100644
index 0000000..093a209
--- /dev/null
+++ b/data/README.md
@@ -0,0 +1 @@
+https://github.com/chen700564/RGB/blob/master/data/zh_refine.json
\ No newline at end of file
diff --git a/examples/retrievers/faiss_example.py b/examples/retrievers/faiss_example.py
new file mode 100644
index 0000000..c62b674
--- /dev/null
+++ b/examples/retrievers/faiss_example.py
@@ -0,0 +1,88 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+# https://deepnote.com/blog/semantic-search-using-faiss-and-mpnet
+@author: yanqiangmiffy
+@contact:1185918903@qq.com
+@license: Apache Licence
+@time: 2024/6/5 22:37
+
+"""
+import pandas as pd
+from transformers import AutoTokenizer, AutoModel
+import torch
+import torch.nn.functional as F
+import faiss
+import numpy as np
+import os
+from tqdm import tqdm
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+class SemanticEmbedding:
+
+    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModel.from_pretrained(model_name)
+
+    # Mean Pooling - Take attention mask into account for correct averaging
+    def mean_pooling(self, model_output, attention_mask):
+        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+    def get_embedding(self, sentences):
+        # Tokenize sentences
+        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
+        with torch.no_grad():
+            model_output = self.model(**encoded_input)
+        # Perform pooling
+        sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
+
+        # Normalize embeddings
+        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
+        return sentence_embeddings.detach().numpy()
+class FaissIdx:
+
+    def __init__(self, model, dim=768):
+        self.index = faiss.IndexFlatIP(dim)
+        # Maintaining the document data
+        self.doc_map = dict()
+        self.model = model
+        self.ctr = 0
+
+    def add_doc(self, document_text):
+        self.index.add(self.model.get_embedding(document_text))
+        self.doc_map[self.ctr] = document_text  # store the original document text
+        self.ctr += 1
+
+    def search_doc(self, query, k=3):
+        D, I = self.index.search(self.model.get_embedding(query), k)
+        return [{self.doc_map[idx]: score} for idx, score in zip(I[0], D[0]) if idx in self.doc_map]
+if __name__ == '__main__':
+    model = SemanticEmbedding(r'I:\pretrained_models\bert\english\paraphrase-multilingual-mpnet-base-v2')
+    a = model.get_embedding("我喜欢打篮球")
+    print(a)
+    print(a.shape)
+
+    index = FaissIdx(model)
+    index.add_doc("笔记本电脑")
+    index.add_doc("医生的办公室")
+    result = index.search_doc("个人电脑")
+    print(result)
+
+
+    # Load the test documents
+    data = pd.read_json('../../data/zh_refine.json', lines=True)[:50]
+    print(data)
+    print(data.columns)
+
+    for documents in tqdm(data['positive'], total=len(data)):
+        for document in documents:
+            index.add_doc(document)
+
+    for documents in tqdm(data['negative'], total=len(data)):
+        for document in documents:
+            index.add_doc(document)
+
+    result = index.search_doc("2022年特斯拉交付量")
+    print(result)
diff --git a/examples/retrievers/faissretriever_example.py b/examples/retrievers/faissretriever_example.py
new file mode 100644
index 0000000..637b33c
--- /dev/null
+++ b/examples/retrievers/faissretriever_example.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+@author: yanqiangmiffy
+@contact:1185918903@qq.com
+@license: Apache Licence
+@time: 2024/6/5 22:36
+"""
+import json
+
+from gomate.modules.retrieval.embedding import SBertEmbeddingModel
+from gomate.modules.retrieval.faiss_retriever import FaissRetriever, FaissRetrieverConfig
+
+if __name__ == '__main__':
+    from transformers import AutoTokenizer
+
+    embedding_model_path = "/home/test/pretrained_models/bge-large-zh-v1.5"
+    embedding_model = SBertEmbeddingModel(embedding_model_path)
+    tokenizer = AutoTokenizer.from_pretrained(embedding_model_path)
+    retriever_config = FaissRetrieverConfig(
+        max_tokens=100,
+        max_context_tokens=3500,
+        use_top_k=True,
+        embedding_model=embedding_model,
+        top_k=5,
+        tokenizer=tokenizer,
+        embedding_model_string="bge-large-zh-v1.5",
+        index_path="faiss_index.bin",
+        rebuild_index=True
+    )
+
+    faiss_retriever = FaissRetriever(config=retriever_config)
+
+    documents = []
+    with open('/home/test/codes/GoMate/data/zh_refine.json', 'r', encoding="utf-8") as f:
+        for line in f.readlines():
+            data = json.loads(line)
+            documents.extend(data['positive'])
+            documents.extend(data['negative'])
+    print(len(documents))
+    faiss_retriever.build_from_texts(documents[:200])
+
+    contexts = faiss_retriever.retrieve("2022年冬奥会开幕式总导演是谁")
+    print(contexts)
diff --git a/gomate/modules/retrieval/dense_retriever.py b/gomate/modules/retrieval/dense_retriever.py
new file mode 100644
index 0000000..aca6e79
--- /dev/null
+++ b/gomate/modules/retrieval/dense_retriever.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+@author: yanqiangmiffy
+@contact:1185918903@qq.com
+@license: Apache Licence
+@time: 2024/6/5 23:07
+"""
+import pandas as pd
+from transformers import AutoTokenizer, AutoModel
+import torch
+import torch.nn.functional as F
+import faiss
+import numpy as np
+import os
+from tqdm import tqdm
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+class SemanticEmbedding:
+
+    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModel.from_pretrained(model_name)
+
+    # Mean Pooling - Take attention mask into account for correct averaging
+    def mean_pooling(self, model_output, attention_mask):
+        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+    def get_embedding(self, sentences):
+        # Tokenize sentences
+        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
+        with torch.no_grad():
+            model_output = self.model(**encoded_input)
+        # Perform pooling
+        sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
+
+        # Normalize embeddings
+        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
+        return sentence_embeddings.detach().numpy()
+class FaissIdx:
+
+    def __init__(self, model, dim=768):
+        self.index = faiss.IndexFlatIP(dim)
+        # Maintaining the document data
+        self.doc_map = dict()
+        self.model = model
+        self.ctr = 0
+
+    def add_doc(self, document_text):
+        self.index.add(self.model.get_embedding(document_text))
+        self.doc_map[self.ctr] = document_text  # store the original document text
+        self.ctr += 1
+
+    def search_doc(self, query, k=3):
+        D, I = self.index.search(self.model.get_embedding(query), k)
+        return [{self.doc_map[idx]: score} for idx, score in zip(I[0], D[0]) if idx in self.doc_map]
\ No newline at end of file
diff --git a/gomate/modules/retrieval/faiss_retriever.py b/gomate/modules/retrieval/faiss_retriever.py
index c6b635c..898650e 100644
--- a/gomate/modules/retrieval/faiss_retriever.py
+++ b/gomate/modules/retrieval/faiss_retriever.py
@@ -166,11 +166,6 @@ def build_from_texts(self, documents):
         if self.index is None and self.all_embeddings:
             self.index = faiss.IndexFlatIP(self.all_embeddings[0].shape[1])
 
-        # first_shape = self.all_embeddings[0].shape
-        # for embedding in self.all_embeddings:
-        #     if embedding.shape != first_shape:
-        #         print("Found an embedding with a different shape:", embedding.shape)
-
         self.all_embeddings = np.vstack(self.all_embeddings)
         print(self.all_embeddings.shape)
         print(len(self.context_chunks))
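
Below is a minimal usage sketch for the new gomate.modules.retrieval.dense_retriever module added by this patch, assuming the gomate package is importable; the model name is a placeholder, and any 768-dimensional sentence-transformers checkpoint fits the default dim=768. Because IndexFlatIP is searched over L2-normalized embeddings, the returned scores are cosine similarities.

# Usage sketch (assumptions: gomate is on PYTHONPATH; the model name below is a placeholder).
from gomate.modules.retrieval.dense_retriever import SemanticEmbedding, FaissIdx

if __name__ == '__main__':
    model = SemanticEmbedding('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
    index = FaissIdx(model, dim=768)

    # Each add_doc call embeds the text and keeps the original string in doc_map.
    index.add_doc("笔记本电脑")
    index.add_doc("医生的办公室")

    # Top-k search returns [{document: score}, ...]; scores are cosine similarities.
    print(index.search_doc("个人电脑", k=2))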