Skip to content

Commit

Permalink
Fix equality of bound expressions (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
Fokko authored Oct 22, 2023
1 parent b6c1b02 commit c46e4bf
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 10 deletions.
8 changes: 4 additions & 4 deletions pyiceberg/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def __init__(self, term: BoundTerm[L]):

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the BoundPredicate class."""
if isinstance(other, BoundPredicate):
if isinstance(other, self.__class__):
return self.term == other.term
return False

Expand Down Expand Up @@ -567,7 +567,7 @@ def __repr__(self) -> str:

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the BoundSetPredicate class."""
return self.term == other.term and self.literals == other.literals if isinstance(other, BoundSetPredicate) else False
return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False

def __getnewargs__(self) -> Tuple[BoundTerm[L], Set[Literal[L]]]:
"""Pickle the BoundSetPredicate class."""
Expand Down Expand Up @@ -595,7 +595,7 @@ def __invert__(self) -> BoundNotIn[L]:

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the BoundIn class."""
return self.term == other.term and self.literals == other.literals if isinstance(other, BoundIn) else False
return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False

@property
def as_unbound(self) -> Type[In[L]]:
Expand Down Expand Up @@ -725,7 +725,7 @@ def __init__(self, term: BoundTerm[L], literal: Literal[L]): # pylint: disable=

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the BoundLiteralPredicate class."""
if isinstance(other, BoundLiteralPredicate):
if isinstance(other, self.__class__):
return self.term == other.term and self.literal == other.literal
return False

Expand Down
8 changes: 7 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from pyiceberg import schema
from pyiceberg.catalog import Catalog
from pyiceberg.catalog.noop import NoopCatalog
from pyiceberg.expressions import BoundReference
from pyiceberg.io import (
GCS_ENDPOINT,
GCS_PROJECT_ID,
Expand All @@ -69,7 +70,7 @@
)
from pyiceberg.io.fsspec import FsspecFileIO
from pyiceberg.manifest import DataFile, FileFormat
from pyiceberg.schema import Schema
from pyiceberg.schema import Accessor, Schema
from pyiceberg.serializers import ToOutputFile
from pyiceberg.table import FileScanTask, Table
from pyiceberg.table.metadata import TableMetadataV2
Expand Down Expand Up @@ -1659,3 +1660,8 @@ def table(example_table_metadata_v2: Dict[str, Any]) -> Table:
io=load_file_io(),
catalog=NoopCatalog("NoopCatalog"),
)


@pytest.fixture
def bound_reference_str() -> BoundReference[str]:
return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None))
9 changes: 9 additions & 0 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,15 @@ def test_above_long_bounds_greater_than_or_equal(
assert GreaterThanOrEqual[int]("a", below_long_min).bind(long_schema) is AlwaysTrue()


def test_eq_bound_expression(bound_reference_str: BoundReference[str]) -> None:
assert BoundEqualTo(term=bound_reference_str, literal=literal('a')) != BoundGreaterThanOrEqual(
term=bound_reference_str, literal=literal('a')
)
assert BoundEqualTo(term=bound_reference_str, literal=literal('a')) == BoundEqualTo(
term=bound_reference_str, literal=literal('a')
)


# __ __ ___
# | \/ |_ _| _ \_ _
# | |\/| | || | _/ || |
Expand Down
5 changes: 0 additions & 5 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,11 +559,6 @@ def test_datetime_transform_repr(transform: TimeTransform[Any], transform_repr:
assert repr(transform) == transform_repr


@pytest.fixture
def bound_reference_str() -> BoundReference[str]:
return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None))


@pytest.fixture
def bound_reference_date() -> BoundReference[int]:
return BoundReference(field=NestedField(1, "field", DateType(), required=False), accessor=Accessor(position=0, inner=None))
Expand Down

0 comments on commit c46e4bf

Please sign in to comment.