From ad97b1bbd09e5578f9648cb82b94a55d1b064843 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Sun, 22 Oct 2023 10:28:50 +0200 Subject: [PATCH] Fix equality of bound expressions --- pyiceberg/expressions/__init__.py | 8 ++++---- tests/conftest.py | 8 +++++++- tests/expressions/test_expressions.py | 9 +++++++++ tests/test_transforms.py | 5 ----- 4 files changed, 20 insertions(+), 10 deletions(-) diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index 2dd2fee1bd..d627efaad6 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -346,7 +346,7 @@ def __init__(self, term: BoundTerm[L]): def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the BoundPredicate class.""" - if isinstance(other, BoundPredicate): + if isinstance(other, self.__class__): return self.term == other.term return False @@ -567,7 +567,7 @@ def __repr__(self) -> str: def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the BoundSetPredicate class.""" - return self.term == other.term and self.literals == other.literals if isinstance(other, BoundSetPredicate) else False + return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False def __getnewargs__(self) -> Tuple[BoundTerm[L], Set[Literal[L]]]: """Pickle the BoundSetPredicate class.""" @@ -595,7 +595,7 @@ def __invert__(self) -> BoundNotIn[L]: def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the BoundIn class.""" - return self.term == other.term and self.literals == other.literals if isinstance(other, BoundIn) else False + return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False @property def as_unbound(self) -> Type[In[L]]: @@ -725,7 +725,7 @@ def __init__(self, term: BoundTerm[L], literal: Literal[L]): # pylint: disable= def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the BoundLiteralPredicate class.""" - if isinstance(other, BoundLiteralPredicate): + if isinstance(other, self.__class__): return self.term == other.term and self.literal == other.literal return False diff --git a/tests/conftest.py b/tests/conftest.py index ed7f1caa21..79c01dc747 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,6 +57,7 @@ from pyiceberg import schema from pyiceberg.catalog import Catalog from pyiceberg.catalog.noop import NoopCatalog +from pyiceberg.expressions import BoundReference from pyiceberg.io import ( GCS_ENDPOINT, GCS_PROJECT_ID, @@ -69,7 +70,7 @@ ) from pyiceberg.io.fsspec import FsspecFileIO from pyiceberg.manifest import DataFile, FileFormat -from pyiceberg.schema import Schema +from pyiceberg.schema import Accessor, Schema from pyiceberg.serializers import ToOutputFile from pyiceberg.table import FileScanTask, Table from pyiceberg.table.metadata import TableMetadataV2 @@ -1659,3 +1660,8 @@ def table(example_table_metadata_v2: Dict[str, Any]) -> Table: io=load_file_io(), catalog=NoopCatalog("NoopCatalog"), ) + + +@pytest.fixture +def bound_reference_str() -> BoundReference[str]: + return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None)) diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index bd3a14165e..bce23bf39c 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -1149,6 +1149,15 @@ def test_above_long_bounds_greater_than_or_equal( assert GreaterThanOrEqual[int]("a", below_long_min).bind(long_schema) is AlwaysTrue() +def test_eq_bound_expression(bound_reference_str: BoundReference[str]) -> None: + assert BoundEqualTo(term=bound_reference_str, literal=literal('a')) != BoundGreaterThanOrEqual( + term=bound_reference_str, literal=literal('a') + ) + assert BoundEqualTo(term=bound_reference_str, literal=literal('a')) == BoundEqualTo( + term=bound_reference_str, literal=literal('a') + ) + + # __ __ ___ # | \/ |_ _| _ \_ _ # | |\/| | || | _/ || | diff --git a/tests/test_transforms.py b/tests/test_transforms.py index d8a2151752..797bb48112 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -559,11 +559,6 @@ def test_datetime_transform_repr(transform: TimeTransform[Any], transform_repr: assert repr(transform) == transform_repr -@pytest.fixture -def bound_reference_str() -> BoundReference[str]: - return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None)) - - @pytest.fixture def bound_reference_date() -> BoundReference[int]: return BoundReference(field=NestedField(1, "field", DateType(), required=False), accessor=Accessor(position=0, inner=None))