From e6059e632ea4b0f68443920da9f33c6f3c3e6c4c Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Tue, 7 Jan 2025 19:00:52 +0100 Subject: [PATCH 1/3] fix: truncate ByteStream string representation (#8673) * fix: truncate ByteStream string representation * add reno * better reno * add test * Update test_byte_stream.py * apply feedback * update reno --- haystack/dataclasses/byte_stream.py | 14 +++++++++++++- .../notes/fix-bytestream-str-8dd6d5e9a87f6aa4.yaml | 4 ++++ test/dataclasses/test_byte_stream.py | 9 +++++++++ 3 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 releasenotes/notes/fix-bytestream-str-8dd6d5e9a87f6aa4.yaml diff --git a/haystack/dataclasses/byte_stream.py b/haystack/dataclasses/byte_stream.py index 72a2648199..34b66add84 100644 --- a/haystack/dataclasses/byte_stream.py +++ b/haystack/dataclasses/byte_stream.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional -@dataclass +@dataclass(repr=False) class ByteStream: """ Base data class representing a binary object in the Haystack API. @@ -63,3 +63,15 @@ def to_string(self, encoding: str = "utf-8") -> str: :raises: UnicodeDecodeError: If the ByteStream data cannot be decoded with the specified encoding. """ return self.data.decode(encoding) + + def __repr__(self) -> str: + """ + Return a string representation of the ByteStream, truncating the data to 100 bytes. + """ + fields = [] + truncated_data = self.data[:100] + b"..." if len(self.data) > 100 else self.data + fields.append(f"data={truncated_data!r}") + fields.append(f"meta={self.meta!r}") + fields.append(f"mime_type={self.mime_type!r}") + fields_str = ", ".join(fields) + return f"{self.__class__.__name__}({fields_str})" diff --git a/releasenotes/notes/fix-bytestream-str-8dd6d5e9a87f6aa4.yaml b/releasenotes/notes/fix-bytestream-str-8dd6d5e9a87f6aa4.yaml new file mode 100644 index 0000000000..7c3d4a429b --- /dev/null +++ b/releasenotes/notes/fix-bytestream-str-8dd6d5e9a87f6aa4.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + ByteStream now truncates the data to 100 bytes in the string representation to avoid excessive log output. 
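Note: a minimal sketch of the truncated repr introduced above (illustrative only; the 500-byte payload and the exact printed output are assumptions for the example, not part of the patch):

    from haystack.dataclasses import ByteStream

    # Payloads longer than 100 bytes show up truncated in the repr, suffixed
    # with b"..."; the underlying bytes are left untouched.
    stream = ByteStream(data=b"x" * 500, mime_type="text/plain")
    print(repr(stream))  # ByteStream(data=b'xxx...x...', meta={}, mime_type='text/plain')
    assert len(stream.data) == 500  # the full payload is still available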
diff --git a/test/dataclasses/test_byte_stream.py b/test/dataclasses/test_byte_stream.py index 574f1671bf..1858aad83c 100644 --- a/test/dataclasses/test_byte_stream.py +++ b/test/dataclasses/test_byte_stream.py @@ -71,3 +71,12 @@ def test_to_file(tmp_path, request): ByteStream(test_str.encode()).to_file(test_path) with open(test_path, "rb") as fd: assert fd.read().decode() == test_str + + +def test_str_truncation(): + test_str = "1234567890" * 100 + b = ByteStream.from_string(test_str, mime_type="text/plain", meta={"foo": "bar"}) + string_repr = str(b) + assert len(string_repr) < 200 + assert "text/plain" in string_repr + assert "foo" in string_repr From 5539f6c33ffda42bd5eb3155630e84923567b1eb Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 8 Jan 2025 11:28:00 +0100 Subject: [PATCH 2/3] refactor: improve serialization/deserialization of callables (to handle class methods and static methods) (#8683) * progress * refinements * tidy up * release note --- haystack/utils/callable_serialization.py | 58 ++++++++++++----- ...rove-callables-serde-6aa1e23408063247.yaml | 6 ++ .../chat/test_hugging_face_local.py | 3 +- .../components/generators/chat/test_openai.py | 33 +--------- test/components/generators/test_openai.py | 22 ------- .../preprocessors/test_document_splitter.py | 3 - test/utils/test_callable_serialization.py | 63 +++++++++++++++++-- 7 files changed, 112 insertions(+), 76 deletions(-) create mode 100644 releasenotes/notes/improve-callables-serde-6aa1e23408063247.yaml diff --git a/haystack/utils/callable_serialization.py b/haystack/utils/callable_serialization.py index 3c6003135e..3e4f947e8c 100644 --- a/haystack/utils/callable_serialization.py +++ b/haystack/utils/callable_serialization.py @@ -3,9 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 import inspect -from typing import Callable, Optional +from typing import Callable -from haystack import DeserializationError +from haystack.core.errors import DeserializationError, SerializationError from haystack.utils.type_serialization import thread_safe_import @@ -16,17 +16,33 @@ def serialize_callable(callable_handle: Callable) -> str: :param callable_handle: The callable to serialize :return: The full path of the callable """ - module = inspect.getmodule(callable_handle) + try: + full_arg_spec = inspect.getfullargspec(callable_handle) + is_instance_method = bool(full_arg_spec.args and full_arg_spec.args[0] == "self") + except TypeError: + is_instance_method = False + if is_instance_method: + raise SerializationError("Serialization of instance methods is not supported.") + + # __qualname__ contains the fully qualified path we need for classmethods and staticmethods + qualname = getattr(callable_handle, "__qualname__", "") + if "<lambda>" in qualname: + raise SerializationError("Serialization of lambdas is not supported.") + if "<locals>" in qualname: + raise SerializationError("Serialization of nested functions is not supported.") + + name = qualname or callable_handle.__name__ # Get the full package path of the function + module = inspect.getmodule(callable_handle) if module is not None: - full_path = f"{module.__name__}.{callable_handle.__name__}" + full_path = f"{module.__name__}.{name}" else: - full_path = callable_handle.__name__ + full_path = name return full_path -def deserialize_callable(callable_handle: str) -> Optional[Callable]: +def deserialize_callable(callable_handle: str) -> Callable: """ Deserializes a callable given its full import path as a string.
@@ -34,14 +50,26 @@ def deserialize_callable(callable_handle: str) -> Optional[Callable]: :return: The callable :raises DeserializationError: If the callable cannot be found """ - parts = callable_handle.split(".") - module_name = ".".join(parts[:-1]) - function_name = parts[-1] + module_name, *attribute_chain = callable_handle.split(".") + try: - module = thread_safe_import(module_name) + current = thread_safe_import(module_name) except Exception as e: - raise DeserializationError(f"Could not locate the module of the callable: {module_name}") from e - deserialized_callable = getattr(module, function_name, None) - if not deserialized_callable: - raise DeserializationError(f"Could not locate the callable: {function_name}") - return deserialized_callable + raise DeserializationError(f"Could not locate the module: {module_name}") from e + + for attr in attribute_chain: + try: + attr_value = getattr(current, attr) + except AttributeError as e: + raise DeserializationError(f"Could not find attribute '{attr}' in {current.__name__}") from e + + # when the attribute is a classmethod, we need the underlying function + if isinstance(attr_value, (classmethod, staticmethod)): + attr_value = attr_value.__func__ + + current = attr_value + + if not callable(current): + raise DeserializationError(f"The final attribute is not callable: {current}") + + return current diff --git a/releasenotes/notes/improve-callables-serde-6aa1e23408063247.yaml b/releasenotes/notes/improve-callables-serde-6aa1e23408063247.yaml new file mode 100644 index 0000000000..54b0783e3d --- /dev/null +++ b/releasenotes/notes/improve-callables-serde-6aa1e23408063247.yaml @@ -0,0 +1,6 @@ +--- +enhancements: + - | + Improve serialization and deserialization of callables. + We now allow serialization of classmethods and staticmethods + and explicitly prohibit serialization of instance methods, lambdas, and nested functions. 
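Note: a minimal sketch of the new serde rules (the module path "my_module" and the Processor class are hypothetical; the behavior mirrors the tests below, and the class is assumed to live at the top level of an importable module):

    from haystack.core.errors import SerializationError
    from haystack.utils import serialize_callable, deserialize_callable

    class Processor:
        @classmethod
        def from_defaults(cls):
            pass

        def run(self):
            pass

    # Classmethods and staticmethods now serialize via __qualname__, e.g.
    # "my_module.Processor.from_defaults"; deserialization walks the dotted
    # attribute chain (module -> class -> method).
    path = serialize_callable(Processor.from_defaults)
    assert deserialize_callable(path) == Processor.from_defaults

    # Instance methods (first argument "self"), lambdas, and nested functions
    # are rejected explicitly.
    try:
        serialize_callable(Processor().run)
    except SerializationError:
        print("instance methods are not serializable")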
diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 9b01acb134..c953404912 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -142,7 +142,7 @@ def test_to_dict(self, model_info_mock): token=Secret.from_env_var("ENV_VAR", strict=False), generation_kwargs={"n": 5}, stop_words=["stop", "words"], - streaming_callback=lambda x: x, + streaming_callback=streaming_callback_handler, chat_template="irrelevant", ) @@ -155,6 +155,7 @@ def test_to_dict(self, model_info_mock): assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf" assert "token" not in init_params["huggingface_pipeline_kwargs"] assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]} + assert init_params["streaming_callback"] == "chat.test_hugging_face_local.streaming_callback_handler" assert init_params["chat_template"] == "irrelevant" def test_from_dict(self, model_info_mock): diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 243eb36c89..677dfa812b 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -1,13 +1,12 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest -from typing import Iterator + import logging import os -import json from datetime import datetime from openai import OpenAIError @@ -15,12 +14,11 @@ from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_message_tool_call import Function from openai.types.chat import chat_completion_chunk -from openai import Stream from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import StreamingChunk from haystack.utils.auth import Secret -from haystack.dataclasses import ChatMessage, Tool, ToolCall, ChatRole, TextContent +from haystack.dataclasses import ChatMessage, Tool, ToolCall from haystack.components.generators.chat.openai import OpenAIChatGenerator @@ -212,31 +210,6 @@ def test_to_dict_with_parameters(self, monkeypatch): }, } - def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") - component = OpenAIChatGenerator( - model="gpt-4o-mini", - streaming_callback=lambda x: x, - api_base_url="test-base-url", - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - data = component.to_dict() - assert data == { - "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", - "init_parameters": { - "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, - "model": "gpt-4o-mini", - "organization": None, - "api_base_url": "test-base-url", - "max_retries": None, - "timeout": None, - "streaming_callback": "chat.test_openai.<lambda>", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - "tools": None, - "tools_strict": False, - }, - } - def test_from_dict(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") data = { diff --git a/test/components/generators/test_openai.py b/test/components/generators/test_openai.py index 32628f7c45..e1d865c95f 100644 --- a/test/components/generators/test_openai.py +++ b/test/components/generators/test_openai.py @@ -90,28 +90,6 @@ def test_to_dict_with_parameters(self, monkeypatch): }, } - def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") - component = OpenAIGenerator( - model="gpt-4o-mini", - streaming_callback=lambda x: x, - api_base_url="test-base-url", - generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, - ) - data = component.to_dict() - assert data == { - "type": "haystack.components.generators.openai.OpenAIGenerator", - "init_parameters": { - "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, - "model": "gpt-4o-mini", - "system_prompt": None, - "organization": None, - "api_base_url": "test-base-url", - "streaming_callback": "test_openai.<lambda>", - "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, - }, - } - def test_from_dict(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") data = { diff --git a/test/components/preprocessors/test_document_splitter.py b/test/components/preprocessors/test_document_splitter.py index 094c17eeea..f9096239f2 100644 --- a/test/components/preprocessors/test_document_splitter.py +++ b/test/components/preprocessors/test_document_splitter.py @@ -467,9 +467,6 @@ def test_from_dict_with_splitting_function(self): Test the from_dict class method of the DocumentSplitter class when a custom splitting function is provided. """ - def custom_split(text): - return text.split(".") - data = { "type": "haystack.components.preprocessors.document_splitter.DocumentSplitter", "init_parameters": {"split_by": "function", "splitting_function": serialize_callable(custom_split)}, diff --git a/test/utils/test_callable_serialization.py b/test/utils/test_callable_serialization.py index 941aa14cdf..4f75ddd0ad 100644 --- a/test/utils/test_callable_serialization.py +++ b/test/utils/test_callable_serialization.py @@ -3,8 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest import requests - -from haystack import DeserializationError +from haystack.core.errors import DeserializationError, SerializationError from haystack.components.generators.utils import print_streaming_chunk from haystack.utils import serialize_callable, deserialize_callable @@ -13,6 +12,19 @@ def some_random_callable_for_testing(some_ignored_arg: str): pass +class TestClass: + @classmethod + def class_method(cls): + pass + + @staticmethod + def static_method(): + pass + + def my_method(self): + pass + + def test_callable_serialization(): result = serialize_callable(some_random_callable_for_testing) assert result == "test_callable_serialization.some_random_callable_for_testing" @@ -28,6 +40,28 @@ def test_callable_serialization_non_local(): assert result == "requests.api.get" +def test_callable_serialization_instance_methods_fail(): + with pytest.raises(SerializationError): + serialize_callable(TestClass.my_method) + + instance = TestClass() + with pytest.raises(SerializationError): + serialize_callable(instance.my_method) + + +def test_lambda_serialization_fail(): + with pytest.raises(SerializationError): + serialize_callable(lambda x: x) + + +def test_nested_function_serialization_fail(): + def my_fun(): + pass + + with pytest.raises(SerializationError): + serialize_callable(my_fun) + + def test_callable_deserialization(): result = serialize_callable(some_random_callable_for_testing) fn = deserialize_callable(result) @@ -40,8 +74,27 @@ def test_callable_deserialization_non_local(): result = serialize_callable(requests.api.get) fn = deserialize_callable(result) assert fn is requests.api.get
-def test_callable_deserialization_error(): +def test_classmethod_serialization_deserialization(): + result = serialize_callable(TestClass.class_method) + fn = deserialize_callable(result) + assert fn == TestClass.class_method + + +def test_staticmethod_serialization_deserialization(): + result = serialize_callable(TestClass.static_method) + fn = deserialize_callable(result) + assert fn == TestClass.static_method + + +def test_callable_deserialization_errors(): + # module does not exist with pytest.raises(DeserializationError): - deserialize_callable("this.is.not.a.valid.module") + deserialize_callable("nonexistent_module.function") + + # function does not exist + with pytest.raises(DeserializationError): + deserialize_callable("os.nonexistent_function") + + # attribute is not callable with pytest.raises(DeserializationError): - deserialize_callable("sys.foobar") + deserialize_callable("os.name") From bc30105fbcdc3b4316010d9b5a2c48d21751f740 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 8 Jan 2025 15:58:52 +0100 Subject: [PATCH 3/3] test: reorganize docstore test suite to isolate dataframe tests (#8684) * reorganize docstore test suite to isolate dataframe tests * improve docstring * include FilterDocumentsTestWithDataframe in InMemoryDocumentStore tests --- haystack/testing/document_store.py | 249 ++++++++++-------- ...sting-for-dataframes-3825910ade718d51.yaml | 5 + test/document_stores/test_in_memory.py | 4 +- 3 files changed, 152 insertions(+), 106 deletions(-) create mode 100644 releasenotes/notes/reorganize-docstore-testing-for-dataframes-3825910ade718d51.yaml diff --git a/haystack/testing/document_store.py b/haystack/testing/document_store.py index 6d2bda2804..a86823ebf5 100644 --- a/haystack/testing/document_store.py +++ b/haystack/testing/document_store.py @@ -174,74 +174,86 @@ def test_delete_documents_non_existing_document(self, document_store: DocumentSt assert document_store.count_documents() == 1 -class FilterableDocsFixtureMixin: +def create_filterable_docs(include_dataframe_docs: bool = False) -> List[Document]: """ - Mixin class that adds a filterable_docs() fixture to a test class. + Create a list of filterable documents to be used in the filterable_docs and filterable_docs_with_dataframe fixtures. 
""" - @pytest.fixture - def filterable_docs(self) -> List[Document]: - """Fixture that returns a list of Documents that can be used to test filtering.""" - documents = [] - for i in range(3): - documents.append( - Document( - content=f"A Foo Document {i}", - meta={ - "name": f"name_{i}", - "page": "100", - "chapter": "intro", - "number": 2, - "date": "1969-07-21T20:17:40", - }, - embedding=_random_embeddings(768), - ) + documents = [] + for i in range(3): + documents.append( + Document( + content=f"A Foo Document {i}", + meta={ + "name": f"name_{i}", + "page": "100", + "chapter": "intro", + "number": 2, + "date": "1969-07-21T20:17:40", + }, + embedding=_random_embeddings(768), ) - documents.append( - Document( - content=f"A Bar Document {i}", - meta={ - "name": f"name_{i}", - "page": "123", - "chapter": "abstract", - "number": -2, - "date": "1972-12-11T19:54:58", - }, - embedding=_random_embeddings(768), - ) + ) + documents.append( + Document( + content=f"A Bar Document {i}", + meta={ + "name": f"name_{i}", + "page": "123", + "chapter": "abstract", + "number": -2, + "date": "1972-12-11T19:54:58", + }, + embedding=_random_embeddings(768), ) - documents.append( - Document( - content=f"A Foobar Document {i}", - meta={ - "name": f"name_{i}", - "page": "90", - "chapter": "conclusion", - "number": -10, - "date": "1989-11-09T17:53:00", - }, - embedding=_random_embeddings(768), - ) + ) + documents.append( + Document( + content=f"A Foobar Document {i}", + meta={ + "name": f"name_{i}", + "page": "90", + "chapter": "conclusion", + "number": -10, + "date": "1989-11-09T17:53:00", + }, + embedding=_random_embeddings(768), ) - documents.append( - Document( - content=f"Document {i} without embedding", - meta={"name": f"name_{i}", "no_embedding": True, "chapter": "conclusion"}, - ) + ) + documents.append( + Document( + content=f"Document {i} without embedding", + meta={"name": f"name_{i}", "no_embedding": True, "chapter": "conclusion"}, ) + ) + documents.append( + Document(content=f"Doc {i} with zeros emb", meta={"name": "zeros_doc"}, embedding=TEST_EMBEDDING_1) + ) + documents.append( + Document(content=f"Doc {i} with ones emb", meta={"name": "ones_doc"}, embedding=TEST_EMBEDDING_2) + ) + + if include_dataframe_docs: + for i in range(3): documents.append(Document(dataframe=pd.DataFrame([i]), meta={"name": f"table_doc_{i}"})) - documents.append( - Document(content=f"Doc {i} with zeros emb", meta={"name": "zeros_doc"}, embedding=TEST_EMBEDDING_1) - ) - documents.append( - Document(content=f"Doc {i} with ones emb", meta={"name": "ones_doc"}, embedding=TEST_EMBEDDING_2) - ) - return documents + + return documents + + +class FilterableDocsFixtureMixin: + """ + Mixin class that adds a filterable_docs() fixture to a test class. + """ + + @pytest.fixture + def filterable_docs(self) -> List[Document]: + """Fixture that returns a list of Documents that can be used to test filtering.""" + return create_filterable_docs(include_dataframe_docs=False) class FilterDocumentsTest(AssertDocumentsEqualMixin, FilterableDocsFixtureMixin): """ - Utility class to test a Document Store `filter_documents` method using different types of filters. + Utility class to test a Document Store `filter_documents` method using different types of filters. To use it create a custom test class and override the `document_store` fixture to return your Document Store. 
Example usage: @@ -270,16 +282,6 @@ def test_comparison_equal(self, document_store, filterable_docs): result = document_store.filter_documents(filters={"field": "meta.number", "operator": "==", "value": 100}) self.assert_documents_are_equal(result, [d for d in filterable_docs if d.meta.get("number") == 100]) - def test_comparison_equal_with_dataframe(self, document_store, filterable_docs): - """Test filter_documents() with == comparator and dataframe""" - document_store.write_documents(filterable_docs) - result = document_store.filter_documents( - filters={"field": "dataframe", "operator": "==", "value": pd.DataFrame([1])} - ) - self.assert_documents_are_equal( - result, [d for d in filterable_docs if d.dataframe is not None and d.dataframe.equals(pd.DataFrame([1]))] - ) - def test_comparison_equal_with_none(self, document_store, filterable_docs): """Test filter_documents() with == comparator and None""" document_store.write_documents(filterable_docs) @@ -293,16 +295,6 @@ def test_comparison_not_equal(self, document_store, filterable_docs): result = document_store.filter_documents({"field": "meta.number", "operator": "!=", "value": 100}) self.assert_documents_are_equal(result, [d for d in filterable_docs if d.meta.get("number") != 100]) - def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): - """Test filter_documents() with != comparator and dataframe""" - document_store.write_documents(filterable_docs) - result = document_store.filter_documents( - filters={"field": "dataframe", "operator": "!=", "value": pd.DataFrame([1])} - ) - self.assert_documents_are_equal( - result, [d for d in filterable_docs if d.dataframe is None or not d.dataframe.equals(pd.DataFrame([1]))] - ) - def test_comparison_not_equal_with_none(self, document_store, filterable_docs): """Test filter_documents() with != comparator and None""" document_store.write_documents(filterable_docs) @@ -340,12 +332,6 @@ def test_comparison_greater_than_with_string(self, document_store, filterable_do with pytest.raises(FilterError): document_store.filter_documents(filters={"field": "meta.number", "operator": ">", "value": "1"}) - def test_comparison_greater_than_with_dataframe(self, document_store, filterable_docs): - """Test filter_documents() with > comparator and dataframe""" - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"field": "dataframe", "operator": ">", "value": pd.DataFrame([1])}) - def test_comparison_greater_than_with_list(self, document_store, filterable_docs): """Test filter_documents() with > comparator and list""" document_store.write_documents(filterable_docs) @@ -389,14 +375,6 @@ def test_comparison_greater_than_equal_with_string(self, document_store, filtera with pytest.raises(FilterError): document_store.filter_documents(filters={"field": "meta.number", "operator": ">=", "value": "1"}) - def test_comparison_greater_than_equal_with_dataframe(self, document_store, filterable_docs): - """Test filter_documents() with >= comparator and dataframe""" - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents( - filters={"field": "dataframe", "operator": ">=", "value": pd.DataFrame([1])} - ) - def test_comparison_greater_than_equal_with_list(self, document_store, filterable_docs): """Test filter_documents() with >= comparator and list""" document_store.write_documents(filterable_docs) @@ -440,12 +418,6 @@ def test_comparison_less_than_with_string(self, 
document_store, filterable_docs) with pytest.raises(FilterError): document_store.filter_documents(filters={"field": "meta.number", "operator": "<", "value": "1"}) - def test_comparison_less_than_with_dataframe(self, document_store, filterable_docs): - """Test filter_documents() with < comparator and dataframe""" - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"field": "dataframe", "operator": "<", "value": pd.DataFrame([1])}) - def test_comparison_less_than_with_list(self, document_store, filterable_docs): """Test filter_documents() with < comparator and list""" document_store.write_documents(filterable_docs) @@ -489,14 +461,6 @@ def test_comparison_less_than_equal_with_string(self, document_store, filterable with pytest.raises(FilterError): document_store.filter_documents(filters={"field": "meta.number", "operator": "<=", "value": "1"}) - def test_comparison_less_than_equal_with_dataframe(self, document_store, filterable_docs): - """Test filter_documents() with <= comparator and dataframe""" - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents( - filters={"field": "dataframe", "operator": "<=", "value": pd.DataFrame([1])} - ) - def test_comparison_less_than_equal_with_list(self, document_store, filterable_docs): """Test filter_documents() with <= comparator and list""" document_store.write_documents(filterable_docs) @@ -638,6 +602,83 @@ def test_missing_condition_value_key(self, document_store, filterable_docs): ) +class FilterableDocsFixtureMixinWithDataframe: + """ + Mixin class that adds a filterable_docs_with_dataframe() fixture to a test class, including dataframe documents. + """ + + @pytest.fixture + def filterable_docs_with_dataframe(self) -> List[Document]: + """Fixture that returns a list of Documents including dataframe documents.""" + documents = create_filterable_docs(include_dataframe_docs=True) + + return documents + + +class FilterDocumentsTestWithDataframe(AssertDocumentsEqualMixin, FilterableDocsFixtureMixinWithDataframe): + """ + Utility class to test a Document Store `filter_documents` method specifically for DataFrame documents. 
+ """ + + def test_comparison_equal_with_dataframe(self, document_store, filterable_docs_with_dataframe): + """Test filter_documents() with == comparator and dataframe""" + document_store.write_documents(filterable_docs_with_dataframe) + result = document_store.filter_documents( + filters={"field": "dataframe", "operator": "==", "value": pd.DataFrame([1])} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs_with_dataframe + if d.dataframe is not None and d.dataframe.equals(pd.DataFrame([1])) + ], + ) + + def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs_with_dataframe): + """Test filter_documents() with != comparator and dataframe""" + document_store.write_documents(filterable_docs_with_dataframe) + result = document_store.filter_documents( + filters={"field": "dataframe", "operator": "!=", "value": pd.DataFrame([1])} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs_with_dataframe + if d.dataframe is None or not d.dataframe.equals(pd.DataFrame([1])) + ], + ) + + def test_comparison_greater_than_with_dataframe(self, document_store, filterable_docs_with_dataframe): + """Test filter_documents() with > comparator and dataframe""" + document_store.write_documents(filterable_docs_with_dataframe) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "dataframe", "operator": ">", "value": pd.DataFrame([1])}) + + def test_comparison_greater_than_equal_with_dataframe(self, document_store, filterable_docs_with_dataframe): + """Test filter_documents() with >= comparator and dataframe""" + document_store.write_documents(filterable_docs_with_dataframe) + with pytest.raises(FilterError): + document_store.filter_documents( + filters={"field": "dataframe", "operator": ">=", "value": pd.DataFrame([1])} + ) + + def test_comparison_less_than_with_dataframe(self, document_store, filterable_docs_with_dataframe): + """Test filter_documents() with < comparator and dataframe""" + document_store.write_documents(filterable_docs_with_dataframe) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "dataframe", "operator": "<", "value": pd.DataFrame([1])}) + + def test_comparison_less_than_equal_with_dataframe(self, document_store, filterable_docs_with_dataframe): + """Test filter_documents() with <= comparator and dataframe""" + document_store.write_documents(filterable_docs_with_dataframe) + with pytest.raises(FilterError): + document_store.filter_documents( + filters={"field": "dataframe", "operator": "<=", "value": pd.DataFrame([1])} + ) + + class DocumentStoreBaseTests(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest): @pytest.fixture def document_store(self) -> DocumentStore: diff --git a/releasenotes/notes/reorganize-docstore-testing-for-dataframes-3825910ade718d51.yaml b/releasenotes/notes/reorganize-docstore-testing-for-dataframes-3825910ade718d51.yaml new file mode 100644 index 0000000000..935c3cbab0 --- /dev/null +++ b/releasenotes/notes/reorganize-docstore-testing-for-dataframes-3825910ade718d51.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Reorganized the document store test suite to isolate dataframe filter tests. + This change prepares for potential future deprecation of the Document class's dataframe field. 
diff --git a/test/document_stores/test_in_memory.py b/test/document_stores/test_in_memory.py index ba623eedb3..8c2a313d92 100644 --- a/test/document_stores/test_in_memory.py +++ b/test/document_stores/test_in_memory.py @@ -11,10 +11,10 @@ from haystack import Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.in_memory import InMemoryDocumentStore -from haystack.testing.document_store import DocumentStoreBaseTests +from haystack.testing.document_store import DocumentStoreBaseTests, FilterDocumentsTestWithDataframe -class TestMemoryDocumentStore(DocumentStoreBaseTests): # pylint: disable=R0904 +class TestMemoryDocumentStore(DocumentStoreBaseTests, FilterDocumentsTestWithDataframe): # pylint: disable=R0904 """ Test InMemoryDocumentStore's specific features """
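Note: a third-party Document Store can opt into the extracted dataframe filter tests the same way; a minimal sketch (MyDocumentStore and my_package are hypothetical):

    import pytest

    from haystack.testing.document_store import DocumentStoreBaseTests, FilterDocumentsTestWithDataframe
    from my_package.document_store import MyDocumentStore  # hypothetical custom store

    class TestMyDocumentStore(DocumentStoreBaseTests, FilterDocumentsTestWithDataframe):
        """Runs the base suite plus the dataframe filter tests against MyDocumentStore."""

        @pytest.fixture
        def document_store(self) -> MyDocumentStore:
            return MyDocumentStore()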