Skip to content

Commit

Permalink
fixing fallback whitespaces count to fixed word/char split based on split size
Browse files Browse the repository at this point in the history
  • Loading branch information
davidsbatista committed Jan 10, 2025
1 parent e1464eb commit 3eb532c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
21 changes: 18 additions & 3 deletions haystack/components/preprocessors/recursive_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,24 @@ def _fall_back_to_fixed_chunking(self, text: str, split_units: Literal["word", "
step = self.split_length - self.split_overlap

if split_units == "word":
words = text.split(" ")
for idx, i in enumerate(range(0, self._chunk_length(text), step)):
chunks.append(" ".join(words[i : i + self.split_length]))
words = re.findall(r"\S+|\s+", text)
current_chunk = []
current_length = 0

for word in words:
if word != " ":
current_chunk.append(word)
current_length += 1
if current_length == step and current_chunk:
chunks.append("".join(current_chunk))
current_chunk = []
current_length = 0
else:
current_chunk.append(word)

if current_chunk:
chunks.append("".join(current_chunk))

else:
for i in range(0, self._chunk_length(text), step):
chunks.append(text[i : i + self.split_length])
Expand Down
7 changes: 7 additions & 0 deletions test/components/preprocessors/test_recursive_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,13 @@ def test_run_trigger_dealing_with_remaining_char_larger_than_split_length():
def test_run_custom_split_by_dot_and_overlap_3_char_unit():
    """Check word-based fixed chunking on text that starts with form feeds.

    With ``split_length=4`` and no overlap, the fallback chunker is expected
    to return two chunks: the first holds the leading form-feed characters
    and the first four words, the second holds the remaining text together
    with the space that precedes it.
    """
    # NOTE(review): the test name mentions "overlap_3" and "char_unit", but the
    # body exercises split_unit="word" with split_overlap=0 -- confirm whether
    # the name or the parameters reflect the intended scenario.
    document_splitter = RecursiveDocumentSplitter(
        separators=["."], split_length=4, split_overlap=0, split_unit="word"
    )
    text = "\x0c\x0c Sentence on page 5."
    chunks = document_splitter._fall_back_to_fixed_chunking(text, split_units="word")
    assert len(chunks) == 2
    # Leading form feeds and inter-word whitespace must be preserved verbatim.
    assert chunks[0] == "\x0c\x0c Sentence on page"
    assert chunks[1] == " 5."


def test_run_serialization_in_pipeline():
pipeline = Pipeline()
Expand Down

0 comments on commit 3eb532c

Please sign in to comment.