Skip to content

Commit

Permalink
handling some more edge cases, when the split is still too big and all separators ran
Browse files Browse the repository at this point in the history
  • Loading branch information
davidsbatista committed Jan 9, 2025
1 parent 951956b commit e1464eb
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 45 deletions.
54 changes: 21 additions & 33 deletions haystack/components/preprocessors/recursive_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _split_chunk(self, current_chunk: str) -> Tuple[str, str]:
"""

if self.split_units == "word":
words = current_chunk.split(" ")
words = current_chunk.split()
current_chunk = " ".join(words[: self.split_length])
remaining_words = words[self.split_length :]
return current_chunk, " ".join(remaining_words)
Expand Down Expand Up @@ -201,45 +201,30 @@ def _apply_overlap(self, chunks: List[str]) -> List[str]:

return overlapped_chunks

def _get_overlap(self, overlapped_chunks):
def _get_overlap(self, overlapped_chunks: List[str]) -> Tuple[str, str]:
"""Get the previous overlapped chunk instead of the original chunk."""
prev_chunk = overlapped_chunks[-1]
overlap_start = max(0, self._chunk_length(prev_chunk) - self.split_overlap)
if self.split_units == "word":
word_chunks = prev_chunk.split(" ")
word_chunks = prev_chunk.split()
overlap = " ".join(word_chunks[overlap_start:])
else:
overlap = prev_chunk[overlap_start:]
return overlap, prev_chunk

def _chunk_length(self, text: str) -> int:
"""
Get the length of the chunk in words or characters.
Split the text by whitespace and count non-empty elements.
:param text: The text to be split into chunks.
:returns:
The length of the chunk in words or characters.
:param: The text to be split.
:return: The number of words in the text.
"""

if self.split_units == "word":
return len(text.split(" "))
else:
return len(text)

# def _chunk_length(self, text: str) -> int:
# """
# Split the text by whitespace and count non-empty elements
# Count newline and form feed characters
#
# :param text:
# :return:
# """
#
# if self.split_units == "word":
# words = [word for word in text.split() if word]
# special_chars = text.count('\n') + text.count('\f') + text.count('\x0c')
# return len(words) + special_chars
#
# return len(text)
words = [word for word in text.split(" ") if word]
return len(words)

return len(text)

def _chunk_text(self, text: str) -> List[str]:
"""
Expand Down Expand Up @@ -299,10 +284,13 @@ def _chunk_text(self, text: str) -> List[str]:
# recursively handle splits that are too large
if self._chunk_length(split_text) > self.split_length:
if curr_separator == self.separators[-1]:
# tried the last separator, can't split further, break the loop and fall back to
# word- or character-level chunking
return self.fall_back_to_fixed_chunking(text, self.split_units)
chunks.extend(self._chunk_text(split_text))
# tried last separator, can't split further, do a fixed-split based on word/character
fall_back_chunks = self._fall_back_to_fixed_chunking(split_text, self.split_units)
chunks.extend(fall_back_chunks)
else:
chunks.extend(self._chunk_text(split_text))
current_length += self._chunk_length(split_text)

else:
current_chunk.append(split_text)
current_length += self._chunk_length(split_text)
Expand All @@ -320,9 +308,9 @@ def _chunk_text(self, text: str) -> List[str]:
return chunks

# if no separator worked, fall back to word- or character-level chunking
return self.fall_back_to_fixed_chunking(text, self.split_units)
return self._fall_back_to_fixed_chunking(text, self.split_units)

def fall_back_to_fixed_chunking(self, text: str, split_units: Literal["word", "char"]) -> List[str]:
def _fall_back_to_fixed_chunking(self, text: str, split_units: Literal["word", "char"]) -> List[str]:
"""
Fall back to a fixed chunking approach if no separator works for the text.
Expand All @@ -336,7 +324,7 @@ def fall_back_to_fixed_chunking(self, text: str, split_units: Literal["word", "c

if split_units == "word":
words = text.split(" ")
for i in range(0, self._chunk_length(text), step):
for idx, i in enumerate(range(0, self._chunk_length(text), step)):
chunks.append(" ".join(words[i : i + self.split_length]))
else:
for i in range(0, self._chunk_length(text), step):
Expand Down
29 changes: 17 additions & 12 deletions test/components/preprocessors/test_recursive_splitter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import pytest
from pytest import LogCaptureFixture

Expand Down Expand Up @@ -401,11 +403,12 @@ def test_run_split_document_with_overlap_character_unit():

def test_run_separator_exists_but_split_length_too_small_fall_back_to_character_chunking():
splitter = RecursiveDocumentSplitter(separators=[" "], split_length=2, split_unit="char")
doc = Document(content="This is some text. This is some more text.")
doc = Document(content="This is some text")
result = splitter.run(documents=[doc])
assert len(result["documents"]) == 21
assert len(result["documents"]) == 10
for doc in result["documents"]:
assert len(doc.content) == 2
if re.escape(doc.content) not in ["\ "]:
assert len(doc.content) == 2


def test_run_fallback_to_character_chunking_by_default_length_too_short():
Expand Down Expand Up @@ -475,7 +478,7 @@ def test_run_split_by_dot_count_page_breaks_word_unit() -> None:

documents = document_splitter.run(documents=[Document(content=text)])["documents"]

assert len(documents) == 7
assert len(documents) == 8
assert documents[0].content == "Sentence on page 1."
assert documents[0].meta["page_number"] == 1
assert documents[0].meta["split_id"] == 0
Expand Down Expand Up @@ -506,11 +509,16 @@ def test_run_split_by_dot_count_page_breaks_word_unit() -> None:
assert documents[5].meta["split_id"] == 5
assert documents[5].meta["split_idx_start"] == text.index(documents[5].content)

assert documents[6].content == "\f\f Sentence on page 5."
assert documents[6].content == "\f\f Sentence on page"
assert documents[6].meta["page_number"] == 5
assert documents[6].meta["split_id"] == 6
assert documents[6].meta["split_idx_start"] == text.index(documents[6].content)

assert documents[7].content == " 5."
assert documents[7].meta["page_number"] == 5
assert documents[7].meta["split_id"] == 7
assert documents[7].meta["split_idx_start"] == text.index(documents[7].content)


def test_run_split_by_word_count_page_breaks_word_unit():
splitter = RecursiveDocumentSplitter(split_length=4, split_overlap=0, separators=[" "], split_unit="word")
Expand Down Expand Up @@ -687,38 +695,35 @@ def test_run_split_by_sentence_tokenizer_document_and_overlap_word_unit_no_overl
chunks = splitter.run([Document(content=text)])["documents"]
assert len(chunks) == 3
assert chunks[0].content == "This is sentence one."
assert chunks[1].content == "This is sentence two."
assert chunks[2].content == "This is sentence three."
assert chunks[1].content == " This is sentence two."
assert chunks[2].content == " This is sentence three."


def test_run_split_by_dot_and_overlap_1_word_unit():
splitter = RecursiveDocumentSplitter(split_length=4, split_overlap=1, separators=["."], split_unit="word")
text = "This is sentence one. This is sentence two. This is sentence three. This is sentence four."
chunks = splitter.run([Document(content=text)])["documents"]
assert len(chunks) == 6
assert len(chunks) == 5
assert chunks[0].content == "This is sentence one."
assert chunks[1].content == "one. This is sentence"
assert chunks[2].content == "sentence two. This is"
assert chunks[3].content == "is sentence three. This"
assert chunks[4].content == "This is sentence four."
assert chunks[5].content == "four."


def test_run_trigger_dealing_with_remaining_word_larger_than_split_length():
splitter = RecursiveDocumentSplitter(split_length=3, split_overlap=2, separators=["."], split_unit="word")
text = """A simple sentence1. A bright sentence2. A clever sentence3"""
doc = Document(content=text)
chunks = splitter.run([doc])["documents"]
assert len(chunks) == 9
assert len(chunks) == 7
assert chunks[0].content == "A simple sentence1."
assert chunks[1].content == "simple sentence1. A"
assert chunks[2].content == "sentence1. A bright"
assert chunks[3].content == "A bright sentence2."
assert chunks[4].content == "bright sentence2. A"
assert chunks[5].content == "sentence2. A clever"
assert chunks[6].content == "A clever sentence3"
assert chunks[7].content == "clever sentence3"
assert chunks[8].content == "sentence3"


def test_run_trigger_dealing_with_remaining_char_larger_than_split_length():
Expand Down

0 comments on commit e1464eb

Please sign in to comment.