Skip to content

Commit

Permalink
Fix literal predicate equality check (#94)
Browse files Browse the repository at this point in the history
* Fix literal predicate equality check

* Fix the tests

* Some more fixes

---------

Co-authored-by: Fokko Driesprong <[email protected]>
  • Loading branch information
danielcweeks and Fokko authored Oct 24, 2023
1 parent 4616d03 commit 66be1eb
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 15 deletions.
12 changes: 3 additions & 9 deletions pyiceberg/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def __init__(self, term: Union[str, UnboundTerm[Any]]):

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the UnboundPredicate class."""
return self.term == other.term if isinstance(other, UnboundPredicate) else False
return self.term == other.term if isinstance(other, self.__class__) else False

@abstractmethod
def bind(self, schema: Schema, case_sensitive: bool = True) -> BooleanExpression:
Expand Down Expand Up @@ -531,7 +531,7 @@ def __repr__(self) -> str:

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the SetPredicate class."""
return self.term == other.term and self.literals == other.literals if isinstance(other, SetPredicate) else False
return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False

def __getnewargs__(self) -> Tuple[UnboundTerm[L], Set[Literal[L]]]:
"""Pickle the SetPredicate class."""
Expand Down Expand Up @@ -664,12 +664,6 @@ def __invert__(self) -> In[L]:
"""Transform the Expression into its negated version."""
return In[L](self.term, self.literals)

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the NotIn class."""
if isinstance(other, NotIn):
return self.term == other.term and self.literals == other.literals
return False

@property
def as_bound(self) -> Type[BoundNotIn[L]]:
return BoundNotIn[L]
Expand Down Expand Up @@ -701,7 +695,7 @@ def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredi

def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the LiteralPredicate class."""
if isinstance(other, LiteralPredicate):
if isinstance(other, self.__class__):
return self.term == other.term and self.literal == other.literal
return False

Expand Down
4 changes: 2 additions & 2 deletions tests/expressions/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def test_greater_than() -> None:


def test_greater_than_or_equal() -> None:
assert GreaterThanOrEqual("foo", 5) == parser.parse("foo <= 5")
assert GreaterThanOrEqual("foo", "a") == parser.parse("'a' >= foo")
assert GreaterThanOrEqual("foo", 5) == parser.parse("foo >= 5")
assert GreaterThanOrEqual("foo", "a") == parser.parse("'a' <= foo")


def test_equal_to() -> None:
Expand Down
20 changes: 16 additions & 4 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,15 +825,27 @@ def test_projection_truncate_string_literal_eq(bound_reference_str: BoundReferen


def test_projection_truncate_string_literal_gt(bound_reference_str: BoundReference[str]) -> None:
assert TruncateTransform(2).project("name", BoundGreaterThan(term=bound_reference_str, literal=literal("data"))) == EqualTo(
term="name", literal=literal("da")
)
assert TruncateTransform(2).project(
"name", BoundGreaterThan(term=bound_reference_str, literal=literal("data"))
) == GreaterThanOrEqual(term="name", literal=literal("da"))


def test_projection_truncate_string_literal_gte(bound_reference_str: BoundReference[str]) -> None:
assert TruncateTransform(2).project(
"name", BoundGreaterThanOrEqual(term=bound_reference_str, literal=literal("data"))
) == EqualTo(term="name", literal=literal("da"))
) == GreaterThanOrEqual(term="name", literal=literal("da"))


def test_projection_truncate_string_literal_lt(bound_reference_str: BoundReference[str]) -> None:
assert TruncateTransform(2).project(
"name", BoundLessThan(term=bound_reference_str, literal=literal("data"))
) == LessThanOrEqual(term="name", literal=literal("da"))


def test_projection_truncate_string_literal_lte(bound_reference_str: BoundReference[str]) -> None:
assert TruncateTransform(2).project(
"name", BoundLessThanOrEqual(term=bound_reference_str, literal=literal("data"))
) == LessThanOrEqual(term="name", literal=literal("da"))


def test_projection_truncate_string_set_same_result(bound_reference_str: BoundReference[str]) -> None:
Expand Down

0 comments on commit 66be1eb

Please sign in to comment.