feature@chunk: token char sentence semantic #88

Merged 1 commit on Jan 17, 2025
37 changes: 32 additions & 5 deletions trustrag/modules/chunks/char_chunk.py
@@ -2,18 +2,45 @@
from typing import List

class CharChunker(BaseChunker):
def __init__(self,chunk_size:int=64):
"""
A character-based chunker that splits input texts into fixed-size chunks of characters.

This class inherits from `BaseChunker` and implements the `get_chunks` method to divide
input texts into smaller chunks, where each chunk contains a specified number of characters.
This is useful for processing long texts in smaller, manageable pieces.

Attributes:
chunk_size (int): The number of characters per chunk. Defaults to 64.
"""

def __init__(self, chunk_size: int = 64) -> None:
"""
Initializes the CharChunker with a specified chunk size.

Args:
chunk_size (int): The number of characters per chunk. Defaults to 64.
"""
super().__init__()
self.chunk_size = chunk_size
def get_chunks(self, texts:List[str]) -> List[str]:
chunks= []
for paragraph in texts:

def get_chunks(self, paragraphs: List[str]) -> List[str]:
"""
Splits the input paragraphs into chunks of characters based on the specified chunk size.

Args:
paragraphs (List[str]): A list of strings (paragraphs) to be chunked.

Returns:
List[str]: A list of chunks, where each chunk is a string of characters.
"""
chunks = []
for paragraph in paragraphs:
for i in range(0, len(paragraph), self.chunk_size):
chunk = paragraph[i:i + self.chunk_size]
chunks.append(chunk)
return chunks


if __name__ == "__main__":
cc=CharChunker(chunk_size=64)
cc = CharChunker(chunk_size=64)
print(cc.get_chunks(["我喜欢北京。"]))
197 changes: 197 additions & 0 deletions trustrag/modules/chunks/semantic_chunk.py
@@ -0,0 +1,197 @@
import re
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from trustrag.modules.document import rag_tokenizer
from trustrag.modules.chunks.base import BaseChunker
from sentence_transformers import SentenceTransformer
from langchain.embeddings import OpenAIEmbeddings

class SemanticChunker(BaseChunker):
"""
A class for splitting text into chunks based on semantic similarity between sentences.

This class uses sentence embeddings to calculate the semantic similarity between sentences
and groups them into chunks based on a similarity threshold. It ensures that each chunk
contains semantically related sentences.

Attributes:
tokenizer (callable): A tokenizer function used to count tokens in sentences.
chunk_size (int): The maximum number of tokens allowed per chunk.
similarity_threshold (float): The threshold for semantic similarity to group sentences.
embeddings_model: The embedding model used to generate sentence embeddings.
Can be either OpenAIEmbeddings or SentenceTransformer.
"""

def __init__(self, chunk_size=512, similarity_threshold=0.8, embedding_model="sentence-transformers", model_name="all-MiniLM-L6-v2"):
"""
Initializes the SemanticChunker with a tokenizer, chunk size, similarity threshold, and embedding model.

Args:
chunk_size (int, optional): The maximum number of tokens allowed per chunk. Defaults to 512.
similarity_threshold (float, optional): The threshold for semantic similarity to group sentences. Defaults to 0.8.
embedding_model (str, optional): The embedding model to use. Options: "sentence-transformers" or "openai". Defaults to "sentence-transformers".
model_name (str, optional): The name of the model to use. For "sentence-transformers", it's the model name (e.g., "all-MiniLM-L6-v2").
For "openai", it's the model name (e.g., "text-embedding-ada-002"). Defaults to "all-MiniLM-L6-v2".
"""
super().__init__()
self.tokenizer = rag_tokenizer
self.chunk_size = chunk_size
self.similarity_threshold = similarity_threshold

if embedding_model == "sentence-transformers":
self.embeddings_model = SentenceTransformer(model_name)
elif embedding_model == "openai":
self.embeddings_model = OpenAIEmbeddings(model=model_name)
else:
raise ValueError("Invalid embedding_model. Choose 'sentence-transformers' or 'openai'.")

def split_sentences(self, text: str) -> list[str]:
"""
Splits the input text into sentences based on Chinese and English punctuation marks.

Args:
text (str): The input text to be split into sentences.

Returns:
list[str]: A list of sentences extracted from the input text.
"""
# Use regex to split text by sentence-ending punctuation marks
sentence_endings = re.compile(r'([。!?.!?])')
sentences = sentence_endings.split(text)

# Merge punctuation marks with their preceding sentences
result = []
for i in range(0, len(sentences) - 1, 2):
if sentences[i]:
result.append(sentences[i] + sentences[i + 1])

# Handle the last sentence if it lacks punctuation
if sentences[-1]:
result.append(sentences[-1])

# Remove whitespace and filter out empty sentences
result = [sentence.strip() for sentence in result if sentence.strip()]

return result

def get_sentence_embeddings(self, sentences: list[str]) -> list[list[float]]:
"""
Generates embeddings for a list of sentences using the selected embedding model.

Args:
sentences (list[str]): A list of sentences to generate embeddings for.

Returns:
list[list[float]]: A list of sentence embeddings.
"""
if isinstance(self.embeddings_model, SentenceTransformer):
return self.embeddings_model.encode(sentences)
elif isinstance(self.embeddings_model, OpenAIEmbeddings):
return self.embeddings_model.embed_documents(sentences)
else:
raise ValueError("Unsupported embedding model.")

def calculate_cosine_distances(self, embeddings: list[list[float]]) -> list[float]:
"""
Calculates the cosine distances between consecutive sentence embeddings.

Args:
embeddings (list[list[float]]): A list of sentence embeddings.

Returns:
list[float]: A list of cosine distances between consecutive sentences.
"""
distances = []
for i in range(len(embeddings) - 1):
similarity = cosine_similarity([embeddings[i]], [embeddings[i + 1]])[0][0]
distance = 1 - similarity
distances.append(distance)
return distances

def get_chunks(self, paragraphs: list[str]) -> list[str]:
"""
Splits a list of paragraphs into chunks based on semantic similarity and token size.

Args:
paragraphs (list[str]|str): A list of paragraphs to be chunked.

Returns:
list[str]: A list of text chunks, each containing semantically related sentences.
"""
# Combine paragraphs into a single text
text = ''.join(paragraphs)

# Split the text into sentences
sentences = self.split_sentences(text)

# If no sentences are found, treat paragraphs as sentences
if len(sentences) == 0:
sentences = paragraphs

# Generate embeddings for sentences
embeddings = self.get_sentence_embeddings(sentences)

# Calculate cosine distances between consecutive sentences
distances = self.calculate_cosine_distances(embeddings)

# Determine breakpoints based on the similarity threshold
breakpoint_indices = [i for i, distance in enumerate(distances) if distance > (1 - self.similarity_threshold)]

# Combine sentences into chunks
chunks = []
start_index = 0
for index in breakpoint_indices:
end_index = index
group = sentences[start_index:end_index + 1]
combined_text = ' '.join(group)
chunks.append(combined_text)
start_index = index + 1

# Add the last chunk if it contains any sentences
if start_index < len(sentences):
combined_text = ' '.join(sentences[start_index:])
chunks.append(combined_text)

# Preprocess the chunks to normalize formatting
chunks = self.process_text_chunks(chunks)
return chunks

def process_text_chunks(self, chunks: list[str]) -> list[str]:
"""
Preprocesses text chunks by normalizing excessive newlines and spaces.

Args:
chunks (list[str]): A list of text chunks to be processed.

Returns:
list[str]: A list of processed text chunks with normalized formatting.
"""
processed_chunks = []
for chunk in chunks:
# Normalize four or more consecutive newlines
while '\n\n\n\n' in chunk:
chunk = chunk.replace('\n\n\n\n', '\n\n')

# Normalize four or more consecutive spaces
while '    ' in chunk:
chunk = chunk.replace('    ', '  ')

processed_chunks.append(chunk)

return processed_chunks

if __name__ == '__main__':
with open("../../../data/docs/news.txt", "r", encoding="utf-8") as f:
content = f.read()

# Example 1: Use SentenceTransformer
sc_st = SemanticChunker(embedding_model="sentence-transformers", model_name="all-MiniLM-L6-v2")
chunks_st = sc_st.get_chunks([content])
for chunk in chunks_st:
print(f"SentenceTransformer Chunk:\n{chunk}")

# # Example 2: Use OpenAIEmbeddings
# sc_openai = SemanticChunker(embedding_model="openai", model_name="text-embedding-ada-002")
# chunks_openai = sc_openai.get_chunks([content])
# for chunk in chunks_openai:
# print(f"OpenAIEmbeddings Chunk:\n{chunk}")
46 changes: 38 additions & 8 deletions trustrag/modules/chunks/sentence_chunk.py
@@ -1,18 +1,41 @@
import re

from chonkie import BaseChunker
from tqdm import tqdm
from trustrag.modules.document import rag_tokenizer
from trustrag.modules.chunks.base import BaseChunker

class SentenceChunker(BaseChunker):
def __init__(self,chunk_size=512):
"""
A class for splitting text into chunks based on sentences, ensuring each chunk does not exceed a specified token size.

This class is designed to handle both Chinese and English text, splitting it into sentences using punctuation marks.
It then groups these sentences into chunks, ensuring that the total number of tokens in each chunk does not exceed
the specified `chunk_size`. The class also provides methods to preprocess the text chunks by normalizing excessive
newlines and spaces.

Attributes:
tokenizer (callable): A tokenizer function used to count tokens in sentences.
chunk_size (int): The maximum number of tokens allowed per chunk.

Methods:
split_sentences(text: str) -> list[str]:
Splits the input text into sentences based on Chinese and English punctuation marks.

process_text_chunks(chunks: list[str]) -> list[str]:
Preprocesses text chunks by normalizing excessive newlines and spaces.

get_chunks(paragraphs: list[str]) -> list[str]:
Splits a list of paragraphs into chunks based on a specified token size.
"""

def __init__(self, chunk_size=512):
"""
Initializes the TextChunker with a tokenizer.
Initializes the SentenceChunker with a tokenizer and a specified chunk size.

Args:
chunk_size (int, optional): The maximum number of tokens allowed per chunk. Defaults to 512.
"""
super().__init__()
self.tokenizer = rag_tokenizer
self.chunk_size=chunk_size
self.chunk_size = chunk_size

def split_sentences(self, text: str) -> list[str]:
"""
@@ -73,7 +96,6 @@ def get_chunks(self, paragraphs: list[str]) -> list[str]:

Args:
paragraphs (list[str]|str): A list of paragraphs to be chunked.
chunk_size (int): The maximum number of tokens allowed per chunk.

Returns:
list[str]: A list of text chunks, each containing sentences that fit within the token limit.
@@ -111,4 +133,12 @@ def get_chunks(self, paragraphs: list[str]) -> list[str]:

# Preprocess the chunks to normalize formatting
chunks = self.process_text_chunks(chunks)
return chunks
return chunks

if __name__ == '__main__':
with open("../../../data/docs/news.txt","r",encoding="utf-8") as f:
content=f.read()
tc=SentenceChunker(chunk_size=128)
chunks = tc.get_chunks([content])
for chunk in chunks:
print(f"Chunk Content:\n{chunk}")