Merge pull request #12 from gomate-community/pipeline
feature@add faiss retriever
yanqiangmiffy authored Jun 5, 2024
2 parents 999ae1d + df15172 commit b2cdd86
Showing 8 changed files with 652 additions and 50 deletions.
300 changes: 300 additions & 0 deletions data/zh_refine.json

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion examples/retrievers/bm5retriever_example.py
@@ -7,6 +7,7 @@
@license: Apache Licence
@time: 2024/6/1 15:48
"""
import os
from gomate.modules.retrieval.bm25_retriever import BM25RetrieverConfig, BM25Retriever, tokenizer

if __name__ == '__main__':
@@ -22,6 +23,8 @@
corpus = [

]
root_dir = os.path.abspath(os.path.dirname(__file__))
print(root_dir)
new_files = [
r'H:\Projects\GoMate\data\伊朗.txt',
r'H:\Projects\GoMate\data\伊朗总统罹难事件.txt',
@@ -31,7 +34,7 @@
for filename in new_files:
with open(filename, 'r', encoding="utf-8") as file:
corpus.append(file.read())
bm25_retriever.fit_bm25(corpus)
bm25_retriever.build_from_texts(corpus)
query = "伊朗总统莱希"
search_docs = bm25_retriever.retrieve(query)
print(search_docs)
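The example still hard-codes absolute Windows paths even though it now computes root_dir. A portable alternative would build the same list relative to the repository; the data/ layout below is an assumption based on the data/zh_refine.json path in this commit, not something the script itself does:

import os

root_dir = os.path.abspath(os.path.dirname(__file__))
# Assumed layout: <repo>/examples/retrievers/<this script> and <repo>/data/*.txt
data_dir = os.path.join(root_dir, "..", "..", "data")
new_files = [
    os.path.join(data_dir, "伊朗.txt"),
    os.path.join(data_dir, "伊朗总统罹难事件.txt"),
]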
2 changes: 1 addition & 1 deletion gomate/modules/retrieval/bm25_retriever.py
@@ -252,7 +252,7 @@ def __init__(self, config):
self.delta = config.delta
self.algorithm = config.algorithm

def fit_bm25(self, corpus):
def build_from_texts(self, corpus):
self.corpus=corpus
if self.algorithm == 'Okapi':
self.bm25 = BM25Okapi(corpus=corpus, tokenizer=self.tokenizer, k1=self.k1, b=self.b, epsilon=self.epsilon)
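For the Okapi algorithm, the renamed build_from_texts delegates to rank_bm25, as the BM25Okapi call above shows. A minimal standalone sketch of that underlying call — the character-level tokenization is only a stand-in for the project's configurable tokenizer, and the k1/b/epsilon values are rank_bm25 defaults, not values read from BM25RetrieverConfig:

from rank_bm25 import BM25Okapi

# Character-level tokens as a stand-in for the retriever's Chinese tokenizer
corpus = ["伊朗总统莱希乘坐的直升机发生事故", "2022年冬奥会开幕式在北京举行"]
tokenized_corpus = [list(doc) for doc in corpus]

bm25 = BM25Okapi(tokenized_corpus, k1=1.5, b=0.75, epsilon=0.25)
query_tokens = list("伊朗总统莱希")
print(bm25.get_scores(query_tokens))              # one BM25 score per corpus document
print(bm25.get_top_n(query_tokens, corpus, n=1))  # best-matching original document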
2 changes: 1 addition & 1 deletion gomate/modules/retrieval/embedding.py
@@ -53,7 +53,7 @@ def __init__(self, model_name="sentence-transformers/multi-qa-mpnet-base-cos-v1"
self.model = SentenceTransformer(model_name)

def create_embedding(self, text):
return self.model.encode(text)
return self.model.encode(text,show_progress_bar=False)


class BaseEmbeddings:
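show_progress_bar=False only silences the progress bar that sentence-transformers would otherwise print on every encode call, which matters now that the Faiss retriever below embeds documents chunk by chunk. A minimal sketch of the same call outside the class, using the default model name from this file (a local path such as bge-large-zh-v1.5 works the same way):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/multi-qa-mpnet-base-cos-v1")
embedding = model.encode("伊朗总统莱希", show_progress_bar=False)
print(embedding.shape)  # a 1-D numpy vector, (768,) for this model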
157 changes: 112 additions & 45 deletions gomate/modules/retrieval/faiss_retriever.py
@@ -1,12 +1,16 @@
import json
import os
import random
from concurrent.futures import ProcessPoolExecutor
from typing import List, Any

import faiss
import numpy as np
import tiktoken
from tqdm import tqdm

from gomate.modules.retrieval.embedding import BaseEmbeddingModel, OpenAIEmbeddingModel
from gomate.modules.retrieval.embedding import SBertEmbeddingModel
from gomate.modules.retrieval.retrievers import BaseRetriever
from gomate.modules.retrieval.utils import split_text

@@ -16,12 +20,14 @@ def __init__(
self,
max_tokens=100,
max_context_tokens=3500,
use_top_k=False,
use_top_k=True,
embedding_model=None,
question_embedding_model=None,
top_k=5,
tokenizer=tiktoken.get_encoding("cl100k_base"),
tokenizer=None,
embedding_model_string=None,
index_path=None,
rebuild_index=True
):
if max_tokens < 1:
raise ValueError("max_tokens must be at least 1")
@@ -54,6 +60,8 @@ def __init__(self, config):
self.question_embedding_model = question_embedding_model or self.embedding_model
self.tokenizer = tokenizer
self.embedding_model_string = embedding_model_string or "OpenAI"
self.index_path = index_path
self.rebuild_index=rebuild_index

def log_config(self):
config_summary = """
@@ -66,6 +74,8 @@ def log_config(self):
Top K: {top_k}
Tokenizer: {tokenizer}
Embedding Model String: {embedding_model_string}
Index Path: {index_path}
Rebuild Index: {rebuild_index}
""".format(
max_tokens=self.max_tokens,
max_context_tokens=self.max_context_tokens,
@@ -75,6 +85,8 @@
top_k=self.top_k,
tokenizer=self.tokenizer,
embedding_model_string=self.embedding_model_string,
index_path=self.index_path,
rebuild_index=self.rebuild_index
)
return config_summary

@@ -90,60 +102,81 @@ def __init__(self, config):
self.embedding_model = config.embedding_model
self.question_embedding_model = config.question_embedding_model
self.index = None
self.context_chunks = None
self.context_chunks = []
self.max_tokens = config.max_tokens
self.max_context_tokens = config.max_context_tokens
self.use_top_k = config.use_top_k
self.tokenizer = config.tokenizer
self.top_k = config.top_k
self.embedding_model_string = config.embedding_model_string
self.index_path = config.index_path
self.rebuild_index=config.rebuild_index
        # Load an existing index from index_path unless a rebuild was requested
        if not self.rebuild_index:
            if self.index_path and os.path.exists(self.index_path):
                self.load_index(self.index_path)
        # Otherwise drop any stale index file so it is rebuilt from scratch
        elif self.index_path and os.path.exists(self.index_path):
            os.remove(self.index_path)

    def build_from_text(self, doc_text):
        """
        Builds the index from a given text.
        :param doc_text: A string containing the document text.
        :param tokenizer: A tokenizer used to split the text into chunks.
        :param max_tokens: An integer representing the maximum number of tokens per chunk.
        """
        self.context_chunks = np.array(
            split_text(doc_text, self.tokenizer, self.max_tokens)
        )

        with ProcessPoolExecutor() as executor:
            futures = [
                executor.submit(self.embedding_model.create_embedding, context_chunk)
                for context_chunk in self.context_chunks
            ]

            self.embeddings = []
            for future in tqdm(futures, total=len(futures), desc="Building embeddings"):
                self.embeddings.append(future.result())

        self.embeddings = np.array(self.embeddings, dtype=np.float32)

        self.index = faiss.IndexFlatIP(self.embeddings.shape[1])
        self.index.add(self.embeddings)

    def load_index(self, index_path):
        """
        Loads a Faiss index from a specified path.
        :param index_path: Path to the Faiss index file.
        """
        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
            print("Index loaded successfully.")
        else:
            print("Index path does not exist.")

    def build_from_leaf_nodes(self, leaf_nodes):
        self.context_chunks = [node.text for node in leaf_nodes]

        self.embeddings = np.array(
            [node.embeddings[self.embedding_model_string] for node in leaf_nodes],
            dtype=np.float32,
        )

        self.index = faiss.IndexFlatIP(self.embeddings.shape[1])
        self.index.add(self.embeddings)

    def encode_document(self, doc_text):
        """
        Encodes a single document: splits the text into chunks and embeds each chunk.
        :param doc_text: A string containing the document text.
        :return: A tuple of (embeddings, context_chunks) for the document.
        """
        # Split the text into context chunks
        context_chunks = np.array(
            split_text(doc_text, self.tokenizer, self.max_tokens)
        )
        # Collect embeddings using a for loop
        embeddings = []
        for context_chunk in context_chunks:
            embedding = self.embedding_model.create_embedding(context_chunk)
            embeddings.append(embedding)

        embeddings = np.array(embeddings, dtype=np.float32)
        return embeddings, context_chunks.tolist()

    def build_from_texts(self, documents):
        """
        Processes the documents one by one, builds the Faiss index, and saves it to self.index_path.

        :param documents: List of document texts to process.
        """
self.all_embeddings = []
self.context_chunks=[]
for i in tqdm(range(0, len(documents))):
doc_embeddings,context_chunks = self.encode_document(documents[i])
self.all_embeddings.append(doc_embeddings)
self.context_chunks.extend(context_chunks)
# Initialize the index only once
if self.index is None and self.all_embeddings:
self.index = faiss.IndexFlatIP(self.all_embeddings[0].shape[1])

# first_shape = self.all_embeddings[0].shape
# for embedding in self.all_embeddings:
# if embedding.shape != first_shape:
# print("Found an embedding with a different shape:", embedding.shape)

self.all_embeddings = np.vstack(self.all_embeddings)
print(self.all_embeddings.shape)
print(len(self.context_chunks))
self.index.add(self.all_embeddings)
# Save the index to disk
faiss.write_index(self.index, self.index_path)
def sanity_check(self, num_samples=4):
"""
Perform a sanity check by recomputing embeddings of a few randomly-selected chunks.
@@ -153,7 +186,7 @@ def sanity_check(self, num_samples=4):
indices = random.sample(range(len(self.context_chunks)), num_samples)

for i in indices:
original_embedding = self.embeddings[i]
original_embedding = self.all_embeddings[i]
recomputed_embedding = self.embedding_model.create_embedding(
self.context_chunks[i]
)
@@ -163,7 +196,7 @@ def sanity_check(self, num_samples=4):

print(f"Sanity check passed for {num_samples} random samples.")

def retrieve(self, query: str) -> str:
def retrieve(self, query: str) -> list[Any]:
"""
Retrieves the k most similar context chunks for a given query.
@@ -180,22 +213,56 @@ def retrieve(self, query: str) -> str:
]
)

context = ""
context = []

if self.use_top_k:
_, indices = self.index.search(query_embedding, self.top_k)
distances, indices = self.index.search(query_embedding, self.top_k)
for i in range(self.top_k):
context += self.context_chunks[indices[0][i]]

context.append({'text':self.context_chunks[indices[0][i]],'score':distances[0][i]})
else:
range_ = int(self.max_context_tokens / self.max_tokens)
_, indices = self.index.search(query_embedding, range_)
total_tokens = 0
for i in range(range_):
tokens = len(self.tokenizer.encode(self.context_chunks[indices[0][i]]))
context += self.context_chunks[indices[0][i]]
context.append(self.context_chunks[indices[0][i]])
if total_tokens + tokens > self.max_context_tokens:
break
total_tokens += tokens

return context


if __name__ == '__main__':
from transformers import AutoTokenizer

embedding_model_path = "/home/test/pretrained_models/bge-large-zh-v1.5"
embedding_model = SBertEmbeddingModel(embedding_model_path)
tokenizer = AutoTokenizer.from_pretrained(embedding_model_path)
retriever_config = FaissRetrieverConfig(
max_tokens=100,
max_context_tokens=3500,
use_top_k=True,
embedding_model=embedding_model,
top_k=5,
tokenizer=tokenizer,
embedding_model_string="bge-large-zh-v1.5",
index_path="faiss_index.bin",
rebuild_index=True
)

faiss_retriever=FaissRetriever(config=retriever_config)

documents=[]
with open('/home/test/codes/GoMate/data/zh_refine.json','r',encoding="utf-8") as f:
for line in f.readlines():
data=json.loads(line)
documents.extend(data['positive'])
documents.extend(data['negative'])
print(len(documents))
faiss_retriever.build_from_texts(documents[:200])

contexts=faiss_retriever.retrieve("2022年冬奥会开幕式总导演是谁")
print(contexts)
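With use_top_k=True, retrieve() now returns a list of {'text', 'score'} dicts instead of a concatenated string, where score is the inner-product similarity from IndexFlatIP. Continuing the __main__ block above, the results could be consumed like this (the field names come from the retrieve() code; the formatting itself is only illustrative):

for rank, hit in enumerate(contexts, start=1):
    print(f"{rank}. score={hit['score']:.4f}  {hit['text'][:50]}")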
0 comments on commit b2cdd86