diff --git a/haystack/components/preprocessors/recursive_splitter.py b/haystack/components/preprocessors/recursive_splitter.py index c517c7075c..b6e3dc4306 100644 --- a/haystack/components/preprocessors/recursive_splitter.py +++ b/haystack/components/preprocessors/recursive_splitter.py @@ -121,7 +121,7 @@ def _split_chunk(self, current_chunk: str) -> Tuple[str, str]: """ if self.split_units == "word": - words = current_chunk.split(" ") + words = current_chunk.split() current_chunk = " ".join(words[: self.split_length]) remaining_words = words[self.split_length :] return current_chunk, " ".join(remaining_words) @@ -201,12 +201,12 @@ def _apply_overlap(self, chunks: List[str]) -> List[str]: return overlapped_chunks - def _get_overlap(self, overlapped_chunks): + def _get_overlap(self, overlapped_chunks: List[str]) -> Tuple[str, str]: """Get the previous overlapped chunk instead of the original chunk.""" prev_chunk = overlapped_chunks[-1] overlap_start = max(0, self._chunk_length(prev_chunk) - self.split_overlap) if self.split_units == "word": - word_chunks = prev_chunk.split(" ") + word_chunks = prev_chunk.split() overlap = " ".join(word_chunks[overlap_start:]) else: overlap = prev_chunk[overlap_start:] @@ -214,32 +214,17 @@ def _chunk_length(self, text: str) -> int: """ - Get the length of the chunk in words or characters. + Split the text by spaces and count non-empty elements. - :param text: The text to be split into chunks. - :returns: - The length of the chunk in words or characters. + :param text: The text to be split. + :return: The number of words in the text. 
""" + if self.split_units == "word": - return len(text.split(" ")) - else: - return len(text) - - # def _chunk_length(self, text: str) -> int: - # """ - # Split the text by whitespace and count non-empty elements - # Count newline and form feed characters - # - # :param text: - # :return: - # """ - # - # if self.split_units == "word": - # words = [word for word in text.split() if word] - # special_chars = text.count('\n') + text.count('\f') + text.count('\x0c') - # return len(words) + special_chars - # - # return len(text) + words = [word for word in text.split(" ") if word] + return len(words) + + return len(text) def _chunk_text(self, text: str) -> List[str]: """ @@ -299,10 +284,13 @@ def _chunk_text(self, text: str) -> List[str]: # recursively handle splits that are too large if self._chunk_length(split_text) > self.split_length: if curr_separator == self.separators[-1]: - # tried the last separator, can't split further, break the loop and fall back to - # word- or character-level chunking - return self.fall_back_to_fixed_chunking(text, self.split_units) - chunks.extend(self._chunk_text(split_text)) + # tried last separator, can't split further, do a fixed-split based on word/character + fall_back_chunks = self._fall_back_to_fixed_chunking(split_text, self.split_units) + chunks.extend(fall_back_chunks) + else: + chunks.extend(self._chunk_text(split_text)) + current_length += self._chunk_length(split_text) + else: current_chunk.append(split_text) current_length += self._chunk_length(split_text) @@ -320,9 +308,9 @@ def _chunk_text(self, text: str) -> List[str]: return chunks # if no separator worked, fall back to word- or character-level chunking - return self.fall_back_to_fixed_chunking(text, self.split_units) + return self._fall_back_to_fixed_chunking(text, self.split_units) - def fall_back_to_fixed_chunking(self, text: str, split_units: Literal["word", "char"]) -> List[str]: + def _fall_back_to_fixed_chunking(self, text: str, split_units: Literal["word", 
"char"]) -> List[str]: """ Fall back to a fixed chunking approach if no separator works for the text. @@ -336,7 +324,7 @@ def fall_back_to_fixed_chunking(self, text: str, split_units: Literal["word", "c if split_units == "word": words = text.split(" ") - for i in range(0, self._chunk_length(text), step): + for idx, i in enumerate(range(0, self._chunk_length(text), step)): chunks.append(" ".join(words[i : i + self.split_length])) else: for i in range(0, self._chunk_length(text), step): diff --git a/test/components/preprocessors/test_recursive_splitter.py b/test/components/preprocessors/test_recursive_splitter.py index d3c86f417d..3ea73e09fa 100644 --- a/test/components/preprocessors/test_recursive_splitter.py +++ b/test/components/preprocessors/test_recursive_splitter.py @@ -1,3 +1,5 @@ +import re + import pytest from pytest import LogCaptureFixture @@ -401,11 +403,12 @@ def test_run_split_document_with_overlap_character_unit(): def test_run_separator_exists_but_split_length_too_small_fall_back_to_character_chunking(): splitter = RecursiveDocumentSplitter(separators=[" "], split_length=2, split_unit="char") - doc = Document(content="This is some text. This is some more text.") + doc = Document(content="This is some text") result = splitter.run(documents=[doc]) - assert len(result["documents"]) == 21 + assert len(result["documents"]) == 10 for doc in result["documents"]: - assert len(doc.content) == 2 + if re.escape(doc.content) not in ["\ "]: + assert len(doc.content) == 2 def test_run_fallback_to_character_chunking_by_default_length_too_short(): @@ -475,7 +478,7 @@ def test_run_split_by_dot_count_page_breaks_word_unit() -> None: documents = document_splitter.run(documents=[Document(content=text)])["documents"] - assert len(documents) == 7 + assert len(documents) == 8 assert documents[0].content == "Sentence on page 1." 
assert documents[0].meta["page_number"] == 1 assert documents[0].meta["split_id"] == 0 @@ -506,11 +509,16 @@ def test_run_split_by_dot_count_page_breaks_word_unit() -> None: assert documents[5].meta["split_id"] == 5 assert documents[5].meta["split_idx_start"] == text.index(documents[5].content) - assert documents[6].content == "\f\f Sentence on page 5." + assert documents[6].content == "\f\f Sentence on page" assert documents[6].meta["page_number"] == 5 assert documents[6].meta["split_id"] == 6 assert documents[6].meta["split_idx_start"] == text.index(documents[6].content) + assert documents[7].content == " 5." + assert documents[7].meta["page_number"] == 5 + assert documents[7].meta["split_id"] == 7 + assert documents[7].meta["split_idx_start"] == text.index(documents[7].content) + def test_run_split_by_word_count_page_breaks_word_unit(): splitter = RecursiveDocumentSplitter(split_length=4, split_overlap=0, separators=[" "], split_unit="word") @@ -687,21 +695,20 @@ def test_run_split_by_sentence_tokenizer_document_and_overlap_word_unit_no_overl chunks = splitter.run([Document(content=text)])["documents"] assert len(chunks) == 3 assert chunks[0].content == "This is sentence one." - assert chunks[1].content == "This is sentence two." - assert chunks[2].content == "This is sentence three." + assert chunks[1].content == " This is sentence two." + assert chunks[2].content == " This is sentence three." def test_run_split_by_dot_and_overlap_1_word_unit(): splitter = RecursiveDocumentSplitter(split_length=4, split_overlap=1, separators=["."], split_unit="word") text = "This is sentence one. This is sentence two. This is sentence three. This is sentence four." chunks = splitter.run([Document(content=text)])["documents"] - assert len(chunks) == 6 + assert len(chunks) == 5 assert chunks[0].content == "This is sentence one." assert chunks[1].content == "one. This is sentence" assert chunks[2].content == "sentence two. This is" assert chunks[3].content == "is sentence three. 
This" assert chunks[4].content == "This is sentence four." - assert chunks[5].content == "four." def test_run_trigger_dealing_with_remaining_word_larger_than_split_length(): @@ -709,7 +716,7 @@ def test_run_trigger_dealing_with_remaining_word_larger_than_split_length(): text = """A simple sentence1. A bright sentence2. A clever sentence3""" doc = Document(content=text) chunks = splitter.run([doc])["documents"] - assert len(chunks) == 9 + assert len(chunks) == 7 assert chunks[0].content == "A simple sentence1." assert chunks[1].content == "simple sentence1. A" assert chunks[2].content == "sentence1. A bright" @@ -717,8 +724,6 @@ def test_run_trigger_dealing_with_remaining_word_larger_than_split_length(): assert chunks[4].content == "bright sentence2. A" assert chunks[5].content == "sentence2. A clever" assert chunks[6].content == "A clever sentence3" - assert chunks[7].content == "clever sentence3" - assert chunks[8].content == "sentence3" def test_run_trigger_dealing_with_remaining_char_larger_than_split_length():