From 731542e8fe9e170c0757276b1e6ed42a936325a4 Mon Sep 17 00:00:00 2001 From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com> Date: Thu, 28 Nov 2024 19:02:22 +0530 Subject: [PATCH 01/12] Create test_scan_count.py --- pyiceberg/table/test_scan_count.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 pyiceberg/table/test_scan_count.py diff --git a/pyiceberg/table/test_scan_count.py b/pyiceberg/table/test_scan_count.py new file mode 100644 index 0000000000..807d5644dd --- /dev/null +++ b/pyiceberg/table/test_scan_count.py @@ -0,0 +1,11 @@ +from pyiceberg.table import Table, TableMetadata, DataScan +from pyiceberg.catalog.sql import SqlCatalog + +def test_iceberg_count(): + table = _create_iceberg_metadata() + assert len(table.to_arrow()) == 2 + + +def test_iceberg_metadata_only_count(): + table = _create_iceberg_metadata() + assert table.count() == 2 From c6c971e26ce4a8987a23798ba1a3cf942b732212 Mon Sep 17 00:00:00 2001 From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com> Date: Thu, 28 Nov 2024 19:10:25 +0530 Subject: [PATCH 02/12] moved test_scan_count.py to tests --- {pyiceberg => tests}/table/test_scan_count.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {pyiceberg => tests}/table/test_scan_count.py (100%) diff --git a/pyiceberg/table/test_scan_count.py b/tests/table/test_scan_count.py similarity index 100% rename from pyiceberg/table/test_scan_count.py rename to tests/table/test_scan_count.py From da18837049d3a932e006dae0557ec44023e192b6 Mon Sep 17 00:00:00 2001 From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com> Date: Thu, 28 Nov 2024 19:30:33 +0530 Subject: [PATCH 03/12] implemented count in data scan --- pyiceberg/table/__init__.py | 9 +++++++++ tests/table/test_scan_count.py | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 3eb74eee1f..ea3bba3a1f 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1227,6 +1227,9 @@ def filter(self: S, expr: Union[str, BooleanExpression]) -> S: def with_case_sensitive(self: S, case_sensitive: bool = True) -> S: return self.update(case_sensitive=case_sensitive) + @abstractmethod + def count(self) -> int : ... 
+ class ScanTask(ABC): pass @@ -1493,6 +1496,12 @@ def to_ray(self) -> ray.data.dataset.Dataset: return ray.data.from_arrow(self.to_arrow()) + def count(self) -> int: + res = 0 + tasks = self.plan_files() + for task in tasks: + res += task.file.record_count + return res @dataclass(frozen=True) class WriteTask: diff --git a/tests/table/test_scan_count.py b/tests/table/test_scan_count.py index 807d5644dd..a4616f31d2 100644 --- a/tests/table/test_scan_count.py +++ b/tests/table/test_scan_count.py @@ -1,6 +1,10 @@ from pyiceberg.table import Table, TableMetadata, DataScan from pyiceberg.catalog.sql import SqlCatalog +def get_plan(table: Table) -> DataScan: + return table.scan() + + def test_iceberg_count(): table = _create_iceberg_metadata() assert len(table.to_arrow()) == 2 From 3104a2fa2e5fac8ba51707519216c1667eebd64c Mon Sep 17 00:00:00 2001 From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com> Date: Thu, 28 Nov 2024 23:23:21 +0530 Subject: [PATCH 04/12] tested table scan count in test_sql catalog --- tests/catalog/test_sql.py | 52 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py index fcefc597d2..18c11e3cca 100644 --- a/tests/catalog/test_sql.py +++ b/tests/catalog/test_sql.py @@ -1431,6 +1431,58 @@ def test_append_table(catalog: SqlCatalog, table_schema_simple: Schema, table_id assert df == table.scan().to_arrow() +@pytest.mark.parametrize( + "catalog", + [ + lazy_fixture("catalog_memory"), + lazy_fixture("catalog_sqlite"), + lazy_fixture("catalog_sqlite_without_rowcount"), + lazy_fixture("catalog_sqlite_fsspec"), + ], +) +@pytest.mark.parametrize( + "table_identifier", + [ + lazy_fixture("random_table_identifier"), + lazy_fixture("random_hierarchical_identifier"), + lazy_fixture("random_table_identifier_with_catalog"), + ], +) +def test_count_table(catalog: SqlCatalog, table_schema_simple: Schema, table_identifier: Identifier) -> None: + table_identifier_nocatalog = catalog._identifier_to_tuple_without_catalog(table_identifier) + namespace = Catalog.namespace_from(table_identifier_nocatalog) + catalog.create_namespace(namespace) + table = catalog.create_table(table_identifier, table_schema_simple) + + df = pa.Table.from_pydict( + { + "foo": ["a"], + "bar": [1], + "baz": [True], + }, + schema=schema_to_pyarrow(table_schema_simple), + ) + + table.append(df) + + # new snapshot is written in APPEND mode + assert len(table.metadata.snapshots) == 1 + assert table.metadata.snapshots[0].snapshot_id == table.metadata.current_snapshot_id + assert table.metadata.snapshots[0].parent_snapshot_id is None + assert table.metadata.snapshots[0].sequence_number == 1 + assert table.metadata.snapshots[0].summary is not None + assert table.metadata.snapshots[0].summary.operation == Operation.APPEND + assert table.metadata.snapshots[0].summary["added-data-files"] == "1" + assert table.metadata.snapshots[0].summary["added-records"] == "1" + assert table.metadata.snapshots[0].summary["total-data-files"] == "1" + assert table.metadata.snapshots[0].summary["total-records"] == "1" + assert len(table.metadata.metadata_log) == 1 + + # read back the data + assert df == table.scan().to_arrow() + assert len(table.scan().to_arrow()) == table.scan().count() + + @pytest.mark.parametrize( "catalog", [ From c2740eac2d2ffbe06caebf371ae99b218409fe5b Mon Sep 17 00:00:00 2001 From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com> Date: Thu, 28 Nov 2024 23:23:59 +0530 Subject: [PATCH 05/12] refactoring --- 
 tests/table/test_scan_count.py | 15 ---------------
 1 file changed, 15 deletions(-)
 delete mode 100644 tests/table/test_scan_count.py

diff --git a/tests/table/test_scan_count.py b/tests/table/test_scan_count.py
deleted file mode 100644
index a4616f31d2..0000000000
--- a/tests/table/test_scan_count.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from pyiceberg.table import Table, TableMetadata, DataScan
-from pyiceberg.catalog.sql import SqlCatalog
-
-def get_plan(table: Table) -> DataScan:
-    return table.scan()
-
-
-def test_iceberg_count():
-    table = _create_iceberg_metadata()
-    assert len(table.to_arrow()) == 2
-
-
-def test_iceberg_metadata_only_count():
-    table = _create_iceberg_metadata()
-    assert table.count() == 2

From 90bca8496ce00dfb363d9037a8fe9af54687d5bc Mon Sep 17 00:00:00 2001
From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com>
Date: Thu, 28 Nov 2024 23:28:21 +0530
Subject: [PATCH 06/12] make lint

---
 pyiceberg/table/__init__.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index ea3bba3a1f..98fea3c99d 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -1228,7 +1228,7 @@ def with_case_sensitive(self: S, case_sensitive: bool = True) -> S:
         return self.update(case_sensitive=case_sensitive)
 
     @abstractmethod
-    def count(self) -> int : ...
+    def count(self) -> int: ...
 
 
 class ScanTask(ABC):
@@ -1503,6 +1503,7 @@ def count(self) -> int:
             res += task.file.record_count
         return res
 
+
 @dataclass(frozen=True)
 class WriteTask:
     """Task with the parameters for writing a DataFile."""

From 091c0afd96a774db6250fe9379a8e9cfcff74677 Mon Sep 17 00:00:00 2001
From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com>
Date: Tue, 24 Dec 2024 12:46:22 +0530
Subject: [PATCH 07/12] implemented residual_evaluator.py with tests

---
 pyiceberg/expressions/residual_evaluator.py  | 220 +++++++++++++
 tests/expressions/test_residual_evaluator.py | 306 +++++++++++++++++++
 2 files changed, 526 insertions(+)
 create mode 100644 pyiceberg/expressions/residual_evaluator.py
 create mode 100644 tests/expressions/test_residual_evaluator.py

diff --git a/pyiceberg/expressions/residual_evaluator.py b/pyiceberg/expressions/residual_evaluator.py
new file mode 100644
index 0000000000..4aac91e3a6
--- /dev/null
+++ b/pyiceberg/expressions/residual_evaluator.py
@@ -0,0 +1,220 @@
+from abc import ABC
+from pyiceberg.expressions.visitors import (
+    BoundBooleanExpressionVisitor,
+    BooleanExpression,
+    UnboundPredicate,
+    BoundPredicate,
+    visit,
+    BoundTerm,
+    AlwaysFalse,
+    AlwaysTrue
+)
+from pyiceberg.expressions.literals import Literal
+from pyiceberg.expressions import (
+    And,
+    Or
+)
+from pyiceberg.types import L
+from pyiceberg.partitioning import PartitionSpec
+from pyiceberg.schema import Schema
+from typing import Any, List, Set
+from pyiceberg.typedef import Record
+
+
+class ResidualVisitor(BoundBooleanExpressionVisitor[BooleanExpression], ABC):
+    schema: Schema
+    spec: PartitionSpec
+    case_sensitive: bool
+
+    def __init__(self, schema: Schema, spec: PartitionSpec, case_sensitive: bool, expr: BooleanExpression):
+        self.schema = schema
+        self.spec = spec
+        self.case_sensitive = case_sensitive
+        self.expr = expr
+
+
+    def eval(self, partition_data: Record):
+        self.struct = partition_data
+        return visit(self.expr, visitor=self)
+
+
+    def visit_true(self) -> BooleanExpression:
+        return AlwaysTrue()
+
+    def visit_false(self) -> BooleanExpression:
+        return AlwaysFalse()
+
+    def visit_not(self, 
child_result: BooleanExpression) -> BooleanExpression: + return Not(child_result) + def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: + return And(left_result, right_result) + + def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: + return Or(left_result, right_result) + + + def visit_is_null(self, term: BoundTerm[L]) -> bool: + return term.eval(self.struct) is None + + def visit_not_null(self, term: BoundTerm[L]) -> bool: + return term.eval(self.struct) is not None + + def visit_is_nan(self, term: BoundTerm[L]) -> bool: + val = term.eval(self.struct) + if val is None: + return self.visit_true() + else: + return self.visit_false() + + def visit_not_nan(self, term: BoundTerm[L]) -> bool: + val = term.eval(self.struct) + if val is not None: + return self.visit_true() + else: + return self.visit_false() + + def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + if term.eval(self.struct) < literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + if term.eval(self.struct) <= literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + if term.eval(self.struct) > literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + if term.eval(self.struct) >= literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + if term.eval(self.struct) == literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + if term.eval(self.struct) != literal.value: + return self.visit_true() + else: + return self.visit_false() + + + def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: + if term.eval(self.struct) in literals: + return self.visit_true() + else: + return self.visit_false() + def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: + if term.eval(self.struct) not in literals: + return self.visit_true() + else: + return self.visit_false() + + def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + eval_res = term.eval(self.struct) + return eval_res is not None and str(eval_res).startswith(str(literal.value)) + + def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + return not self.visit_starts_with(term, literal) + + def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> List[str]: + """ + called from eval + input + """ + parts = self.spec.fields_by_source_id(predicate.term.ref().field.field_id) + if parts == []: + return predicate + + from pyiceberg.types import StructType + def struct_to_schema(struct: StructType) -> Schema: + return Schema(*[f for f in struct.fields]) + + for part in parts: + + strict_projection = part.transform.strict_project(part.name, predicate) + strict_result = None + + if strict_projection is not None: + bound = strict_projection.bind(struct_to_schema(self.spec.partition_type(self.schema))) + assert isinstance(bound, BoundPredicate) + if isinstance(bound, BoundPredicate): + strict_result = super().visit_bound_predicate(bound) + else: + strict_result = 
bound + + if strict_result is not None and isinstance(strict_result, AlwaysTrue): + return AlwaysTrue() + + inclusive_projection = part.transform.project(part.name, predicate) + inclusive_result = None + if inclusive_projection is not None: + bound_inclusive = inclusive_projection.bind(struct_to_schema(self.spec.partition_type(self.schema))) + if isinstance(bound_inclusive, BoundPredicate): + # using predicate method specific to inclusive + inclusive_result = super().visit_bound_predicate(bound_inclusive) + else: + # if the result is not a predicate, then it must be a constant like alwaysTrue or + # alwaysFalse + inclusive_result = bound_inclusive + if inclusive_result is not None and isinstance(inclusive_result, AlwaysFalse): + return AlwaysFalse() + + return predicate + + def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + bound = predicate.bind(self.schema, case_sensitive=True) + + if isinstance(bound, BoundPredicate): + bound_residual = self.visit_bound_predicate(predicate=bound) + # if isinstance(bound_residual, BooleanExpression): + if bound_residual not in (AlwaysFalse(), AlwaysTrue()): + # replace inclusive original unbound predicate + return predicate + + # use the non-predicate residual (e.g. alwaysTrue) + return bound_residual + + # if binding didn't result in a Predicate, return the expression + return bound + + + + + +class ResidualEvaluator(ResidualVisitor): + def residual_for(self, partition_data): + return self.eval(partition_data) + +class UnpartitionedResidualEvaluator(ResidualEvaluator): + + def __init__(self, schema: Schema,expr: BooleanExpression): + from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC + super().__init__(schema=schema, spec=UNPARTITIONED_PARTITION_SPEC, expr=expr, case_sensitive=False) + self.expr = expr + + def residual_for(self, partition_data): + return self.expr + + +def residual_evaluator_of( + spec: PartitionSpec, + expr: BooleanExpression, + case_sensitive: bool, + schema: Schema +) -> ResidualEvaluator: + if len(spec.fields) != 0: + return ResidualEvaluator(spec=spec, expr=expr, schema=schema, case_sensitive=case_sensitive) + else: + return UnpartitionedResidualEvaluator(schema=schema,expr=expr) \ No newline at end of file diff --git a/tests/expressions/test_residual_evaluator.py b/tests/expressions/test_residual_evaluator.py new file mode 100644 index 0000000000..f47cabac38 --- /dev/null +++ b/tests/expressions/test_residual_evaluator.py @@ -0,0 +1,306 @@ +import pytest +from pyiceberg.expressions import ( + AlwaysTrue, + EqualTo, + LessThan, + AlwaysFalse, + And, + Or, + GreaterThan, + GreaterThanOrEqual, + UnboundPredicate, + BoundPredicate, + BoundReference, + BooleanExpression, + BoundLessThan, + BoundGreaterThan, + NotNull, + IsNull, + In, + NotIn, + NotNaN, + IsNaN, + StartsWith, + NotStartsWith +) +from pyiceberg.expressions.residual_evaluator import residual_evaluator_of +from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.schema import Schema +from pyiceberg.transforms import IdentityTransform, DayTransform +from pyiceberg.typedef import Record +from pyiceberg.types import ( + IntegerType, + DoubleType, + FloatType, + NestedField, + StringType, + TimestampType +) +from pyiceberg.utils.datetime import timestamp_to_micros +from pyiceberg.expressions.literals import literal + + +def test_identity_transform_residual(): + + schema = Schema( + NestedField(50, "dateint", IntegerType()), + NestedField(51, "hour", IntegerType()) + ) + + spec = PartitionSpec( + 
PartitionField(50, 1050, IdentityTransform(), "dateint_part") + ) + + predicate = Or( + Or( + And(EqualTo("dateint", 20170815), LessThan("hour", 12)), + And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)) + ), + And(EqualTo("dateint", 20170801), GreaterThan("hour", 11)) + ) + res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(dateint=20170815)) + + # assert residual == True + assert isinstance(residual, UnboundPredicate) + assert residual.term.name == 'hour' + # assert residual.term.field.name == 'hour' + assert residual.literal.value == 12 + assert type(residual) == LessThan + + residual = res_eval.residual_for(Record(dateint=20170801)) + + assert isinstance(residual, UnboundPredicate) + assert residual.term.name == 'hour' + assert residual.literal.value == 11 + assert type(residual) == GreaterThan + + residual = res_eval.residual_for(Record(dateint=20170812)) + + assert residual == AlwaysTrue() + + residual = res_eval.residual_for(Record(dateint=20170817)) + + assert residual == AlwaysFalse() + + +def test_case_insensitive_identity_transform_residuals(): + + schema = Schema( + NestedField(50, "dateint", IntegerType()), + NestedField(51, "hour", IntegerType()) + ) + + spec = PartitionSpec( + PartitionField(50, 1050, IdentityTransform(), "dateint_part") + ) + + predicate = Or( + Or( + And(EqualTo("DATEINT", 20170815), LessThan("HOUR", 12)), + And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)) + ), + And(EqualTo("Dateint", 20170801), GreaterThan("hOUr", 11)) + ) + res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + + + with pytest.raises(ValueError) as e: + residual = res_eval.residual_for(Record(dateint=20170815)) + assert "Could not find field with name DATEINT, case_sensitive=True" in str(e.value) + + +def test_unpartitioned_residuals(): + + + expressions = [ + AlwaysTrue(), + AlwaysFalse(), + LessThan("a", 5), + GreaterThanOrEqual("b", 16), + NotNull("c"), + IsNull("d"), + In("e",[1, 2, 3]), + NotIn("f", [1, 2, 3]), + NotNaN("g"), + IsNaN("h"), + StartsWith("data", "abcd"), + NotStartsWith("data", "abcd") + ] + + schema = Schema( + NestedField(50, "dateint", IntegerType()), + NestedField(51, "hour", IntegerType()), + NestedField(52, "a", IntegerType()) + ) + for expr in expressions: + from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC + residual_evaluator = residual_evaluator_of( + UNPARTITIONED_PARTITION_SPEC, expr, True, schema=schema + ) + assert residual_evaluator.residual_for(Record()) == expr + + +def test_in(): + + schema = Schema( + NestedField(50, "dateint", IntegerType()), + NestedField(51, "hour", IntegerType()) + ) + + spec = PartitionSpec( + PartitionField(50, 1050, IdentityTransform(), "dateint_part") + ) + + predicate = In("dateint", [20170815, 20170816, 20170817]) + + res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(dateint=20170815)) + + assert residual == AlwaysTrue() + + +def test_in_timestamp(): + + schema = Schema( + NestedField(50, "ts", TimestampType()), + NestedField(51, "hour", IntegerType()) + ) + + + spec = PartitionSpec( + PartitionField(50, 1000, DayTransform(), "ts_part") + ) + + date_20191201 = literal("2019-12-01T00:00:00").to(TimestampType()).value + date_20191202 = literal("2019-12-02T00:00:00").to(TimestampType()).value + + day = DayTransform().transform(TimestampType()) + # assert 
date_20191201 == True + ts_day = day(date_20191201) + + # assert ts_day == True + + pred = In("ts", [ date_20191202, date_20191201]) + + res_eval = residual_evaluator_of(spec=spec, expr=pred, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(ts_day)) + assert residual == pred + + residual = res_eval.residual_for(Record(ts_day+3)) + assert residual == AlwaysFalse() + + +def test_not_in(): + + schema = Schema( + NestedField(50, "dateint", IntegerType()), + NestedField(51, "hour", IntegerType()) + ) + + spec = PartitionSpec( + PartitionField(50, 1050, IdentityTransform(), "dateint_part") + ) + + predicate = NotIn("dateint", [20170815, 20170816, 20170817]) + + res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(dateint=20180815)) + assert residual == AlwaysTrue() + + residual = res_eval.residual_for(Record(dateint=20170815)) + assert residual == AlwaysFalse() + + +def test_is_nan(): + schema = Schema( + NestedField(50, "double", DoubleType()), + NestedField(51, "hour", IntegerType()) + ) + + spec = PartitionSpec( + PartitionField(50, 1050, IdentityTransform(), "double_part") + ) + + predicate = IsNaN("double") + + res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(double=None)) + assert residual == AlwaysTrue() + + residual = res_eval.residual_for(Record(double=2)) + assert residual == AlwaysFalse() + + +def test_is_not_nan(): + schema = Schema( + NestedField(50, "double", DoubleType()), + NestedField(51, "float", FloatType()) + ) + + spec = PartitionSpec( + PartitionField(50, 1050, IdentityTransform(), "double_part") + ) + + predicate = NotNaN("double") + + res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(double=None)) + assert residual == AlwaysFalse() + + + residual = res_eval.residual_for(Record(double=2)) + assert residual == AlwaysTrue() + + + spec = PartitionSpec( + PartitionField(51, 1051, IdentityTransform(), "float_part") + ) + + predicate = NotNaN("float") + + res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(double=None)) + assert residual == AlwaysFalse() + + residual = res_eval.residual_for(Record(double=2)) + assert residual == AlwaysTrue() + + +def test_not_in_timestamp(): + + schema = Schema( + NestedField(50, "ts", TimestampType()), + NestedField(51, "dateint", IntegerType()) + ) + + + spec = PartitionSpec( + PartitionField(50, 1000, DayTransform(), "ts_part") + ) + + date_20191201 = literal("2019-12-01T00:00:00").to(TimestampType()).value + date_20191202 = literal("2019-12-02T00:00:00").to(TimestampType()).value + + day = DayTransform().transform(TimestampType()) + # assert date_20191201 == True + ts_day = day(date_20191201) + + # assert ts_day == True + + pred = NotIn("ts", [ date_20191202, date_20191201]) + + res_eval = residual_evaluator_of(spec=spec, expr=pred, case_sensitive=True, schema=schema) + + residual = res_eval.residual_for(Record(ts_day)) + assert residual == pred + + residual = res_eval.residual_for(Record(ts_day+3)) + assert residual == AlwaysTrue() \ No newline at end of file From 3cd797deb3809570bc17698cbcd504c29473dda6 Mon Sep 17 00:00:00 2001 From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com> Date: Tue, 24 Dec 2024 12:49:00 +0530 Subject: [PATCH 08/12] added 
license --- pyiceberg/expressions/residual_evaluator.py | 16 ++++++++++++++++ tests/expressions/test_residual_evaluator.py | 17 +++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/pyiceberg/expressions/residual_evaluator.py b/pyiceberg/expressions/residual_evaluator.py index 4aac91e3a6..bbd954cb07 100644 --- a/pyiceberg/expressions/residual_evaluator.py +++ b/pyiceberg/expressions/residual_evaluator.py @@ -1,3 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. from abc import ABC from pyiceberg.expressions.visitors import ( BoundBooleanExpressionVisitor, diff --git a/tests/expressions/test_residual_evaluator.py b/tests/expressions/test_residual_evaluator.py index f47cabac38..4af30f9827 100644 --- a/tests/expressions/test_residual_evaluator.py +++ b/tests/expressions/test_residual_evaluator.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint:disable=redefined-outer-name import pytest from pyiceberg.expressions import ( AlwaysTrue, From 6b0924e89863950f81a7c97d41746ecb42d6d2ba Mon Sep 17 00:00:00 2001 From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com> Date: Tue, 24 Dec 2024 12:59:36 +0530 Subject: [PATCH 09/12] fixed lint --- pyiceberg/expressions/residual_evaluator.py | 58 +++--- tests/expressions/test_residual_evaluator.py | 176 ++++++------------- 2 files changed, 82 insertions(+), 152 deletions(-) diff --git a/pyiceberg/expressions/residual_evaluator.py b/pyiceberg/expressions/residual_evaluator.py index bbd954cb07..4d382bf24d 100644 --- a/pyiceberg/expressions/residual_evaluator.py +++ b/pyiceberg/expressions/residual_evaluator.py @@ -15,26 +15,25 @@ # specific language governing permissions and limitations # under the License. 
from abc import ABC +from typing import Any, List, Set + +from pyiceberg.expressions import And, Or +from pyiceberg.expressions.literals import Literal from pyiceberg.expressions.visitors import ( - BoundBooleanExpressionVisitor, + AlwaysFalse, + AlwaysTrue, BooleanExpression, - UnboundPredicate, + BoundBooleanExpressionVisitor, BoundPredicate, - visit, BoundTerm, - AlwaysFalse, - AlwaysTrue -) -from pyiceberg.expressions.literals import Literal -from pyiceberg.expressions import ( - And, - Or + Not, + UnboundPredicate, + visit, ) -from pyiceberg.types import L from pyiceberg.partitioning import PartitionSpec from pyiceberg.schema import Schema -from typing import Any, List, Set from pyiceberg.typedef import Record +from pyiceberg.types import L class ResidualVisitor(BoundBooleanExpressionVisitor[BooleanExpression], ABC): @@ -48,12 +47,10 @@ def __init__(self, schema: Schema, spec: PartitionSpec, case_sensitive: bool, ex self.case_sensitive = case_sensitive self.expr = expr - def eval(self, partition_data: Record): self.struct = partition_data return visit(self.expr, visitor=self) - def visit_true(self) -> BooleanExpression: return AlwaysTrue() @@ -62,13 +59,13 @@ def visit_false(self) -> BooleanExpression: def visit_not(self, child_result: BooleanExpression) -> BooleanExpression: return Not(child_result) + def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return And(left_result, right_result) def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return Or(left_result, right_result) - def visit_is_null(self, term: BoundTerm[L]) -> bool: return term.eval(self.struct) is None @@ -125,12 +122,12 @@ def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: else: return self.visit_false() - def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: if term.eval(self.struct) in literals: return self.visit_true() else: return self.visit_false() + def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: if term.eval(self.struct) not in literals: return self.visit_true() @@ -146,19 +143,26 @@ def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> List[str]: """ - called from eval - input + If there is no strict projection or if it evaluates to false, then return the predicate. + + Get the strict projection and inclusive projection of this predicate in partition data, + then use them to determine whether to return the original predicate. The strict projection + returns true iff the original predicate would have returned true, so the predicate can be + eliminated if the strict projection evaluates to true. Similarly the inclusive projection + returns false iff the original predicate would have returned false, so the predicate can + also be eliminated if the inclusive projection evaluates to false. 
+ """ parts = self.spec.fields_by_source_id(predicate.term.ref().field.field_id) if parts == []: return predicate from pyiceberg.types import StructType + def struct_to_schema(struct: StructType) -> Schema: return Schema(*[f for f in struct.fields]) for part in parts: - strict_projection = part.transform.strict_project(part.name, predicate) strict_result = None @@ -206,17 +210,16 @@ def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpr return bound - - - class ResidualEvaluator(ResidualVisitor): def residual_for(self, partition_data): return self.eval(partition_data) -class UnpartitionedResidualEvaluator(ResidualEvaluator): - def __init__(self, schema: Schema,expr: BooleanExpression): +class UnpartitionedResidualEvaluator(ResidualEvaluator): + # Finds the residuals for an Expression the partitions in the given PartitionSpec + def __init__(self, schema: Schema, expr: BooleanExpression): from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC + super().__init__(schema=schema, spec=UNPARTITIONED_PARTITION_SPEC, expr=expr, case_sensitive=False) self.expr = expr @@ -225,12 +228,9 @@ def residual_for(self, partition_data): def residual_evaluator_of( - spec: PartitionSpec, - expr: BooleanExpression, - case_sensitive: bool, - schema: Schema + spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema ) -> ResidualEvaluator: if len(spec.fields) != 0: return ResidualEvaluator(spec=spec, expr=expr, schema=schema, case_sensitive=case_sensitive) else: - return UnpartitionedResidualEvaluator(schema=schema,expr=expr) \ No newline at end of file + return UnpartitionedResidualEvaluator(schema=schema, expr=expr) diff --git a/tests/expressions/test_residual_evaluator.py b/tests/expressions/test_residual_evaluator.py index 4af30f9827..d49d18eb8d 100644 --- a/tests/expressions/test_residual_evaluator.py +++ b/tests/expressions/test_residual_evaluator.py @@ -16,72 +16,54 @@ # under the License. 
# pylint:disable=redefined-outer-name import pytest + from pyiceberg.expressions import ( - AlwaysTrue, - EqualTo, - LessThan, AlwaysFalse, + AlwaysTrue, And, - Or, + EqualTo, GreaterThan, GreaterThanOrEqual, - UnboundPredicate, - BoundPredicate, - BoundReference, - BooleanExpression, - BoundLessThan, - BoundGreaterThan, - NotNull, - IsNull, In, + IsNaN, + IsNull, + LessThan, NotIn, NotNaN, - IsNaN, + NotNull, + NotStartsWith, + Or, StartsWith, - NotStartsWith + UnboundPredicate, ) +from pyiceberg.expressions.literals import literal from pyiceberg.expressions.residual_evaluator import residual_evaluator_of from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema -from pyiceberg.transforms import IdentityTransform, DayTransform +from pyiceberg.transforms import DayTransform, IdentityTransform from pyiceberg.typedef import Record -from pyiceberg.types import ( - IntegerType, - DoubleType, - FloatType, - NestedField, - StringType, - TimestampType -) -from pyiceberg.utils.datetime import timestamp_to_micros -from pyiceberg.expressions.literals import literal +from pyiceberg.types import DoubleType, FloatType, IntegerType, NestedField, TimestampType def test_identity_transform_residual(): + schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) - schema = Schema( - NestedField(50, "dateint", IntegerType()), - NestedField(51, "hour", IntegerType()) - ) - - spec = PartitionSpec( - PartitionField(50, 1050, IdentityTransform(), "dateint_part") - ) + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) predicate = Or( Or( And(EqualTo("dateint", 20170815), LessThan("hour", 12)), - And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)) + And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)), ), - And(EqualTo("dateint", 20170801), GreaterThan("hour", 11)) + And(EqualTo("dateint", 20170801), GreaterThan("hour", 11)), ) - res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(dateint=20170815)) # assert residual == True assert isinstance(residual, UnboundPredicate) - assert residual.term.name == 'hour' + assert residual.term.name == "hour" # assert residual.term.field.name == 'hour' assert residual.literal.value == 12 assert type(residual) == LessThan @@ -89,7 +71,7 @@ def test_identity_transform_residual(): residual = res_eval.residual_for(Record(dateint=20170801)) assert isinstance(residual, UnboundPredicate) - assert residual.term.name == 'hour' + assert residual.term.name == "hour" assert residual.literal.value == 11 assert type(residual) == GreaterThan @@ -103,25 +85,18 @@ def test_identity_transform_residual(): def test_case_insensitive_identity_transform_residuals(): + schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) - schema = Schema( - NestedField(50, "dateint", IntegerType()), - NestedField(51, "hour", IntegerType()) - ) - - spec = PartitionSpec( - PartitionField(50, 1050, IdentityTransform(), "dateint_part") - ) + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) predicate = Or( Or( And(EqualTo("DATEINT", 20170815), LessThan("HOUR", 12)), - And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)) + And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)), ), - 
And(EqualTo("Dateint", 20170801), GreaterThan("hOUr", 11)) + And(EqualTo("Dateint", 20170801), GreaterThan("hOUr", 11)), ) - res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) - + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) with pytest.raises(ValueError) as e: residual = res_eval.residual_for(Record(dateint=20170815)) @@ -129,8 +104,6 @@ def test_case_insensitive_identity_transform_residuals(): def test_unpartitioned_residuals(): - - expressions = [ AlwaysTrue(), AlwaysFalse(), @@ -138,41 +111,32 @@ def test_unpartitioned_residuals(): GreaterThanOrEqual("b", 16), NotNull("c"), IsNull("d"), - In("e",[1, 2, 3]), + In("e", [1, 2, 3]), NotIn("f", [1, 2, 3]), NotNaN("g"), IsNaN("h"), StartsWith("data", "abcd"), - NotStartsWith("data", "abcd") + NotStartsWith("data", "abcd"), ] schema = Schema( - NestedField(50, "dateint", IntegerType()), - NestedField(51, "hour", IntegerType()), - NestedField(52, "a", IntegerType()) + NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType()), NestedField(52, "a", IntegerType()) ) for expr in expressions: from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC - residual_evaluator = residual_evaluator_of( - UNPARTITIONED_PARTITION_SPEC, expr, True, schema=schema - ) + + residual_evaluator = residual_evaluator_of(UNPARTITIONED_PARTITION_SPEC, expr, True, schema=schema) assert residual_evaluator.residual_for(Record()) == expr def test_in(): + schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) - schema = Schema( - NestedField(50, "dateint", IntegerType()), - NestedField(51, "hour", IntegerType()) - ) - - spec = PartitionSpec( - PartitionField(50, 1050, IdentityTransform(), "dateint_part") - ) + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) predicate = In("dateint", [20170815, 20170816, 20170817]) - res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(dateint=20170815)) @@ -180,16 +144,9 @@ def test_in(): def test_in_timestamp(): + schema = Schema(NestedField(50, "ts", TimestampType()), NestedField(51, "hour", IntegerType())) - schema = Schema( - NestedField(50, "ts", TimestampType()), - NestedField(51, "hour", IntegerType()) - ) - - - spec = PartitionSpec( - PartitionField(50, 1000, DayTransform(), "ts_part") - ) + spec = PartitionSpec(PartitionField(50, 1000, DayTransform(), "ts_part")) date_20191201 = literal("2019-12-01T00:00:00").to(TimestampType()).value date_20191202 = literal("2019-12-02T00:00:00").to(TimestampType()).value @@ -200,31 +157,25 @@ def test_in_timestamp(): # assert ts_day == True - pred = In("ts", [ date_20191202, date_20191201]) + pred = In("ts", [date_20191202, date_20191201]) res_eval = residual_evaluator_of(spec=spec, expr=pred, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(ts_day)) assert residual == pred - residual = res_eval.residual_for(Record(ts_day+3)) + residual = res_eval.residual_for(Record(ts_day + 3)) assert residual == AlwaysFalse() def test_not_in(): + schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) - schema = Schema( - NestedField(50, "dateint", IntegerType()), - NestedField(51, "hour", IntegerType()) - ) - - spec = PartitionSpec( - PartitionField(50, 1050, 
IdentityTransform(), "dateint_part") - ) + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) predicate = NotIn("dateint", [20170815, 20170816, 20170817]) - res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(dateint=20180815)) assert residual == AlwaysTrue() @@ -234,18 +185,13 @@ def test_not_in(): def test_is_nan(): - schema = Schema( - NestedField(50, "double", DoubleType()), - NestedField(51, "hour", IntegerType()) - ) + schema = Schema(NestedField(50, "double", DoubleType()), NestedField(51, "hour", IntegerType())) - spec = PartitionSpec( - PartitionField(50, 1050, IdentityTransform(), "double_part") - ) + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "double_part")) predicate = IsNaN("double") - res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(double=None)) assert residual == AlwaysTrue() @@ -255,34 +201,25 @@ def test_is_nan(): def test_is_not_nan(): - schema = Schema( - NestedField(50, "double", DoubleType()), - NestedField(51, "float", FloatType()) - ) + schema = Schema(NestedField(50, "double", DoubleType()), NestedField(51, "float", FloatType())) - spec = PartitionSpec( - PartitionField(50, 1050, IdentityTransform(), "double_part") - ) + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "double_part")) predicate = NotNaN("double") - res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(double=None)) assert residual == AlwaysFalse() - residual = res_eval.residual_for(Record(double=2)) assert residual == AlwaysTrue() - - spec = PartitionSpec( - PartitionField(51, 1051, IdentityTransform(), "float_part") - ) + spec = PartitionSpec(PartitionField(51, 1051, IdentityTransform(), "float_part")) predicate = NotNaN("float") - res_eval = residual_evaluator_of(spec=spec,expr=predicate, case_sensitive=True, schema=schema) + res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(double=None)) assert residual == AlwaysFalse() @@ -292,16 +229,9 @@ def test_is_not_nan(): def test_not_in_timestamp(): + schema = Schema(NestedField(50, "ts", TimestampType()), NestedField(51, "dateint", IntegerType())) - schema = Schema( - NestedField(50, "ts", TimestampType()), - NestedField(51, "dateint", IntegerType()) - ) - - - spec = PartitionSpec( - PartitionField(50, 1000, DayTransform(), "ts_part") - ) + spec = PartitionSpec(PartitionField(50, 1000, DayTransform(), "ts_part")) date_20191201 = literal("2019-12-01T00:00:00").to(TimestampType()).value date_20191202 = literal("2019-12-02T00:00:00").to(TimestampType()).value @@ -312,12 +242,12 @@ def test_not_in_timestamp(): # assert ts_day == True - pred = NotIn("ts", [ date_20191202, date_20191201]) + pred = NotIn("ts", [date_20191202, date_20191201]) res_eval = residual_evaluator_of(spec=spec, expr=pred, case_sensitive=True, schema=schema) residual = res_eval.residual_for(Record(ts_day)) assert residual == pred - residual = res_eval.residual_for(Record(ts_day+3)) - assert 
residual == AlwaysTrue() \ No newline at end of file + residual = res_eval.residual_for(Record(ts_day + 3)) + assert residual == AlwaysTrue() From 96cb4e9977513e5ecd95722fa644ef84e61bc00f Mon Sep 17 00:00:00 2001 From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com> Date: Tue, 24 Dec 2024 13:42:32 +0530 Subject: [PATCH 10/12] fixed lint errors --- pyiceberg/expressions/residual_evaluator.py | 59 ++++++++++++-------- tests/expressions/test_residual_evaluator.py | 18 +++--- 2 files changed, 44 insertions(+), 33 deletions(-) diff --git a/pyiceberg/expressions/residual_evaluator.py b/pyiceberg/expressions/residual_evaluator.py index 4d382bf24d..025772f627 100644 --- a/pyiceberg/expressions/residual_evaluator.py +++ b/pyiceberg/expressions/residual_evaluator.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from abc import ABC -from typing import Any, List, Set +from typing import Any, Set from pyiceberg.expressions import And, Or from pyiceberg.expressions.literals import Literal @@ -47,7 +47,7 @@ def __init__(self, schema: Schema, spec: PartitionSpec, case_sensitive: bool, ex self.case_sensitive = case_sensitive self.expr = expr - def eval(self, partition_data: Record): + def eval(self, partition_data: Record) -> BooleanExpression: self.struct = partition_data return visit(self.expr, visitor=self) @@ -66,82 +66,94 @@ def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpress def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: return Or(left_result, right_result) - def visit_is_null(self, term: BoundTerm[L]) -> bool: - return term.eval(self.struct) is None + def visit_is_null(self, term: BoundTerm[L]) -> BooleanExpression: + if term.eval(self.struct) is None: + return AlwaysTrue() + else: + return AlwaysFalse() - def visit_not_null(self, term: BoundTerm[L]) -> bool: - return term.eval(self.struct) is not None + def visit_not_null(self, term: BoundTerm[L]) -> BooleanExpression: + if term.eval(self.struct) is not None: + return AlwaysTrue() + else: + return AlwaysFalse() - def visit_is_nan(self, term: BoundTerm[L]) -> bool: + def visit_is_nan(self, term: BoundTerm[L]) -> BooleanExpression: val = term.eval(self.struct) if val is None: return self.visit_true() else: return self.visit_false() - def visit_not_nan(self, term: BoundTerm[L]) -> bool: + def visit_not_nan(self, term: BoundTerm[L]) -> BooleanExpression: val = term.eval(self.struct) if val is not None: return self.visit_true() else: return self.visit_false() - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) < literal.value: return self.visit_true() else: return self.visit_false() - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) <= literal.value: return self.visit_true() else: return self.visit_false() - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) > literal.value: return self.visit_true() else: return self.visit_false() - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def 
visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) >= literal.value: return self.visit_true() else: return self.visit_false() - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) == literal.value: return self.visit_true() else: return self.visit_false() - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: if term.eval(self.struct) != literal.value: return self.visit_true() else: return self.visit_false() - def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: + def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> BooleanExpression: if term.eval(self.struct) in literals: return self.visit_true() else: return self.visit_false() - def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: + def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> BooleanExpression: if term.eval(self.struct) not in literals: return self.visit_true() else: return self.visit_false() - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: + def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: eval_res = term.eval(self.struct) - return eval_res is not None and str(eval_res).startswith(str(literal.value)) + if eval_res is not None and str(eval_res).startswith(str(literal.value)): + return AlwaysTrue() + else: + return AlwaysFalse() - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> bool: - return not self.visit_starts_with(term, literal) + def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if not self.visit_starts_with(term, literal): + return AlwaysTrue() + else: + return AlwaysFalse() - def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> List[str]: + def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression: """ If there is no strict projection or if it evaluates to false, then return the predicate. 
@@ -168,7 +180,6 @@ def struct_to_schema(struct: StructType) -> Schema: if strict_projection is not None: bound = strict_projection.bind(struct_to_schema(self.spec.partition_type(self.schema))) - assert isinstance(bound, BoundPredicate) if isinstance(bound, BoundPredicate): strict_result = super().visit_bound_predicate(bound) else: @@ -211,7 +222,7 @@ def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpr class ResidualEvaluator(ResidualVisitor): - def residual_for(self, partition_data): + def residual_for(self, partition_data: Record) -> BooleanExpression: return self.eval(partition_data) @@ -223,7 +234,7 @@ def __init__(self, schema: Schema, expr: BooleanExpression): super().__init__(schema=schema, spec=UNPARTITIONED_PARTITION_SPEC, expr=expr, case_sensitive=False) self.expr = expr - def residual_for(self, partition_data): + def residual_for(self, partition_data: Record) -> BooleanExpression: return self.expr diff --git a/tests/expressions/test_residual_evaluator.py b/tests/expressions/test_residual_evaluator.py index d49d18eb8d..c7210eaf01 100644 --- a/tests/expressions/test_residual_evaluator.py +++ b/tests/expressions/test_residual_evaluator.py @@ -45,7 +45,7 @@ from pyiceberg.types import DoubleType, FloatType, IntegerType, NestedField, TimestampType -def test_identity_transform_residual(): +def test_identity_transform_residual() -> None: schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) @@ -84,7 +84,7 @@ def test_identity_transform_residual(): assert residual == AlwaysFalse() -def test_case_insensitive_identity_transform_residuals(): +def test_case_insensitive_identity_transform_residuals() -> None: schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) @@ -103,7 +103,7 @@ def test_case_insensitive_identity_transform_residuals(): assert "Could not find field with name DATEINT, case_sensitive=True" in str(e.value) -def test_unpartitioned_residuals(): +def test_unpartitioned_residuals() -> None: expressions = [ AlwaysTrue(), AlwaysFalse(), @@ -129,7 +129,7 @@ def test_unpartitioned_residuals(): assert residual_evaluator.residual_for(Record()) == expr -def test_in(): +def test_in() -> None: schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) @@ -143,7 +143,7 @@ def test_in(): assert residual == AlwaysTrue() -def test_in_timestamp(): +def test_in_timestamp() -> None: schema = Schema(NestedField(50, "ts", TimestampType()), NestedField(51, "hour", IntegerType())) spec = PartitionSpec(PartitionField(50, 1000, DayTransform(), "ts_part")) @@ -168,7 +168,7 @@ def test_in_timestamp(): assert residual == AlwaysFalse() -def test_not_in(): +def test_not_in() -> None: schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) @@ -184,7 +184,7 @@ def test_not_in(): assert residual == AlwaysFalse() -def test_is_nan(): +def test_is_nan() -> None: schema = Schema(NestedField(50, "double", DoubleType()), NestedField(51, "hour", IntegerType())) spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "double_part")) @@ -200,7 +200,7 @@ def test_is_nan(): assert 
residual == AlwaysFalse()
 
 
-def test_is_not_nan():
+def test_is_not_nan() -> None:
     schema = Schema(NestedField(50, "double", DoubleType()), NestedField(51, "float", FloatType()))
 
     spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "double_part"))
@@ -228,7 +228,7 @@ def test_is_not_nan():
     assert residual == AlwaysTrue()
 
 
-def test_not_in_timestamp():
+def test_not_in_timestamp() -> None:
     schema = Schema(NestedField(50, "ts", TimestampType()), NestedField(51, "dateint", IntegerType()))
 
     spec = PartitionSpec(PartitionField(50, 1000, DayTransform(), "ts_part"))

From 8bb039f086244decbdfac3cf9bc43e2fce6afeef Mon Sep 17 00:00:00 2001
From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com>
Date: Tue, 31 Dec 2024 17:08:26 +0530
Subject: [PATCH 11/12] Gh 1223 metadata only row count (#4)

* added residual evaluator in plan files

* tested counts with positional deletes

* merged main

[Editor's note: hedged usage sketches for the residual evaluator and for the
new DataScan.count() are appended at the end of this section.]
---
 pyiceberg/table/__init__.py            |  34 ++++-
 tests/integration/test_delete_count.py | 189 +++++++++++++++++++++++++
 2 files changed, 220 insertions(+), 3 deletions(-)
 create mode 100644 tests/integration/test_delete_count.py

diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 142848ba4b..a590454c19 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -1344,6 +1344,7 @@ class FileScanTask(ScanTask):
     delete_files: Set[DataFile]
     start: int
     length: int
+    residual: BooleanExpression
 
     def __init__(
         self,
@@ -1351,12 +1352,13 @@ def __init__(
         delete_files: Optional[Set[DataFile]] = None,
         start: Optional[int] = None,
         length: Optional[int] = None,
+        residual: BooleanExpression = None
     ) -> None:
         self.file = data_file
         self.delete_files = delete_files or set()
         self.start = start or 0
         self.length = length or data_file.file_size_in_bytes
-
+        self.residual = residual
 
 def _open_manifest(
     io: FileIO,
@@ -1516,13 +1518,23 @@ def plan_files(self) -> Iterable[FileScanTask]:
             else:
                 raise ValueError(f"Unknown DataFileContent ({data_file.content}): {manifest_entry}")
 
+
+
+        from pyiceberg.expressions.residual_evaluator import residual_evaluator_of
+        residual_evaluator = residual_evaluator_of(
+            spec=self.table_metadata.spec(),
+            expr=self.row_filter,
+            case_sensitive=self.case_sensitive,
+            schema=self.table_metadata.schema()
+        )
         return [
             FileScanTask(
-                data_entry.data_file,
+                data_file=data_entry.data_file,
                 delete_files=_match_deletes_to_data_file(
                     data_entry,
                     positional_delete_entries,
                 ),
+                residual=residual_evaluator.residual_for(data_entry.data_file.partition)
             )
             for data_entry in data_entries
         ]
@@ -1598,10 +1610,26 @@ def to_ray(self) -> ray.data.dataset.Dataset:
         return ray.data.from_arrow(self.to_arrow())
 
     def count(self) -> int:
+        """
+        Calculate the total number of records in the scan that have not been positionally deleted.
+        """
         res = 0
+        # every task is a FileScanTask
         tasks = self.plan_files()
+
         for task in tasks:
-            res += task.file.record_count
+            # task.residual is AlwaysTrue if the filter condition is fully satisfied by the
+            # partition value; task.delete_files holds positional deletes that have not been
+            # merged yet, so those files have to be read as a pyarrow table, applying the filter and deletes
+            if task.residual == AlwaysTrue() and not len(task.delete_files):
+                # every data file carries a metadata stat with its record count
+                res += task.file.record_count
+            else:
+                from pyiceberg.io.pyarrow import ArrowScan
+                tbl = ArrowScan(
+                    self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
+
).to_table([task]) + res += len(tbl) return res diff --git a/tests/integration/test_delete_count.py b/tests/integration/test_delete_count.py new file mode 100644 index 0000000000..781f0513c6 --- /dev/null +++ b/tests/integration/test_delete_count.py @@ -0,0 +1,189 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint:disable=redefined-outer-name +from datetime import datetime +from typing import Generator, List + +import pyarrow as pa +import pytest +from pyspark.sql import SparkSession + +from pyiceberg.catalog.rest import RestCatalog +from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.expressions import AlwaysTrue, EqualTo +from pyiceberg.manifest import ManifestEntryStatus +from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.schema import Schema +from pyiceberg.table import Table +from pyiceberg.table.snapshots import Operation, Summary +from pyiceberg.transforms import IdentityTransform +from pyiceberg.types import FloatType, IntegerType, LongType, NestedField, StringType, TimestampType + + +def run_spark_commands(spark: SparkSession, sqls: List[str]) -> None: + for sql in sqls: + spark.sql(sql) + + +@pytest.fixture() +def test_table(session_catalog: RestCatalog) -> Generator[Table, None, None]: + identifier = "default.__test_table" + arrow_table = pa.Table.from_arrays([pa.array([1, 2, 3, 4, 5]), pa.array(["a", "b", "c", "d", "e"])], names=["idx", "value"]) + test_table = session_catalog.create_table( + identifier, + schema=Schema( + NestedField(1, "idx", LongType()), + NestedField(2, "value", StringType()), + ), + ) + test_table.append(arrow_table) + + yield test_table + + session_catalog.drop_table(identifier) + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_partitioned_table_delete_full_file(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None: + identifier = "default.table_partitioned_delete" + run_spark_commands( + spark, + [ + f"DROP TABLE IF EXISTS {identifier}", + f""" + CREATE TABLE {identifier} ( + number_partitioned int, + number int + ) + USING iceberg + PARTITIONED BY (number_partitioned) + TBLPROPERTIES('format-version' = {format_version}) + """, + f""" + INSERT INTO {identifier} VALUES (10, 20), (10, 30) + """, + f""" + INSERT INTO {identifier} VALUES (11, 20), (11, 30) + """, + ], + ) + + tbl = session_catalog.load_table(identifier) + tbl.delete(EqualTo("number_partitioned", 10)) + + # No overwrite operation + assert [snapshot.summary.operation.value for snapshot in tbl.snapshots()] == ["append", "append", "delete"] + assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [11, 11], "number": [20, 30]} + + assert tbl.scan().count() == len(tbl.scan().to_arrow()) + + +@pytest.mark.integration 
+@pytest.mark.filterwarnings("ignore:Merge on read is not yet supported, falling back to copy-on-write") +def test_delete_partitioned_table_positional_deletes(spark: SparkSession, session_catalog: RestCatalog) -> None: + identifier = "default.table_partitioned_delete" + + run_spark_commands( + spark, + [ + f"DROP TABLE IF EXISTS {identifier}", + f""" + CREATE TABLE {identifier} ( + number_partitioned int, + number int + ) + USING iceberg + PARTITIONED BY (number_partitioned) + TBLPROPERTIES( + 'format-version' = 2, + 'write.delete.mode'='merge-on-read', + 'write.update.mode'='merge-on-read', + 'write.merge.mode'='merge-on-read' + ) + """, + f""" + INSERT INTO {identifier} VALUES (10, 20), (10, 30), (10, 40) + """, + # Generate a positional delete + f""" + DELETE FROM {identifier} WHERE number = 30 + """, + ], + ) + + tbl = session_catalog.load_table(identifier) + + assert tbl.scan().count() == len(tbl.scan().to_arrow()) + + # Will rewrite a data file without the positional delete + tbl.delete(EqualTo("number", 40)) + + assert [snapshot.summary.operation.value for snapshot in tbl.snapshots()] == ["append", "overwrite", "overwrite"] + assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [10], "number": [20]} + + assert tbl.scan().count() == len(tbl.scan().to_arrow()) + + run_spark_commands( + spark, + [ + f""" + INSERT INTO {identifier} VALUES (10, 20), (10, 30), (10, 40), (20, 30) + """, + # Generate a positional delete + f""" + DELETE FROM {identifier} WHERE number = 30 + """, + ], + ) + + tbl = session_catalog.load_table(identifier) + + assert tbl.scan().count() == len(tbl.scan().to_arrow()) + + + run_spark_commands( + spark, + [ + # Generate a positional delete + f""" + DELETE FROM {identifier} WHERE number = 30 + """, + f""" + INSERT INTO {identifier} VALUES (10, 20), (10, 30), (10, 40), (20, 30) + """, + # Generate a positional delete + f""" + DELETE FROM {identifier} WHERE number = 20 + """, + ], + ) + + tbl = session_catalog.load_table(identifier) + + assert tbl.scan().count() == len(tbl.scan().to_arrow()) + + + filter_on_partition = "number_partitioned = 10" + scan_on_partition = tbl.scan(row_filter=filter_on_partition) + assert scan_on_partition.count() == len(scan_on_partition.to_arrow()) + + + filter = "number = 10" + scan = tbl.scan(row_filter=filter) + assert scan.count() == len(scan.to_arrow()) + From ab4c000f43364651e279f0b0717989ddbbc2b268 Mon Sep 17 00:00:00 2001 From: Tushar Choudhary <151359025+tusharchou@users.noreply.github.com> Date: Mon, 6 Jan 2025 15:44:34 +0530 Subject: [PATCH 12/12] Gh 1223 metadata only row count (#5) * added residual evaluator in plan files * tested counts with positional deletes * merged main * implemented batch reader in count * breaking integration test * fixed integration test * git pull main * revert * revert * revert test_partitioning_key.py * revert test_parser.py * added residual evaluator in visitor * deleted residual_evaluator.py * removed test count from test_sql.py * ignored lint type * fixed lint * working on plan_files * type ignored * make lint --- pyiceberg/expressions/residual_evaluator.py | 247 ------------------- pyiceberg/expressions/visitors.py | 213 +++++++++++++++- pyiceberg/table/__init__.py | 89 ++++--- tests/catalog/test_sql.py | 52 ---- tests/expressions/test_residual_evaluator.py | 40 ++- tests/integration/test_delete_count.py | 160 +++++------- 6 files changed, 357 insertions(+), 444 deletions(-) delete mode 100644 pyiceberg/expressions/residual_evaluator.py diff --git 
a/pyiceberg/expressions/residual_evaluator.py b/pyiceberg/expressions/residual_evaluator.py deleted file mode 100644 index 025772f627..0000000000 --- a/pyiceberg/expressions/residual_evaluator.py +++ /dev/null @@ -1,247 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from abc import ABC -from typing import Any, Set - -from pyiceberg.expressions import And, Or -from pyiceberg.expressions.literals import Literal -from pyiceberg.expressions.visitors import ( - AlwaysFalse, - AlwaysTrue, - BooleanExpression, - BoundBooleanExpressionVisitor, - BoundPredicate, - BoundTerm, - Not, - UnboundPredicate, - visit, -) -from pyiceberg.partitioning import PartitionSpec -from pyiceberg.schema import Schema -from pyiceberg.typedef import Record -from pyiceberg.types import L - - -class ResidualVisitor(BoundBooleanExpressionVisitor[BooleanExpression], ABC): - schema: Schema - spec: PartitionSpec - case_sensitive: bool - - def __init__(self, schema: Schema, spec: PartitionSpec, case_sensitive: bool, expr: BooleanExpression): - self.schema = schema - self.spec = spec - self.case_sensitive = case_sensitive - self.expr = expr - - def eval(self, partition_data: Record) -> BooleanExpression: - self.struct = partition_data - return visit(self.expr, visitor=self) - - def visit_true(self) -> BooleanExpression: - return AlwaysTrue() - - def visit_false(self) -> BooleanExpression: - return AlwaysFalse() - - def visit_not(self, child_result: BooleanExpression) -> BooleanExpression: - return Not(child_result) - - def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: - return And(left_result, right_result) - - def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: - return Or(left_result, right_result) - - def visit_is_null(self, term: BoundTerm[L]) -> BooleanExpression: - if term.eval(self.struct) is None: - return AlwaysTrue() - else: - return AlwaysFalse() - - def visit_not_null(self, term: BoundTerm[L]) -> BooleanExpression: - if term.eval(self.struct) is not None: - return AlwaysTrue() - else: - return AlwaysFalse() - - def visit_is_nan(self, term: BoundTerm[L]) -> BooleanExpression: - val = term.eval(self.struct) - if val is None: - return self.visit_true() - else: - return self.visit_false() - - def visit_not_nan(self, term: BoundTerm[L]) -> BooleanExpression: - val = term.eval(self.struct) - if val is not None: - return self.visit_true() - else: - return self.visit_false() - - def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: - if term.eval(self.struct) < literal.value: - return self.visit_true() - else: - return self.visit_false() - - def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: 
- if term.eval(self.struct) <= literal.value: - return self.visit_true() - else: - return self.visit_false() - - def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: - if term.eval(self.struct) > literal.value: - return self.visit_true() - else: - return self.visit_false() - - def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: - if term.eval(self.struct) >= literal.value: - return self.visit_true() - else: - return self.visit_false() - - def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: - if term.eval(self.struct) == literal.value: - return self.visit_true() - else: - return self.visit_false() - - def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: - if term.eval(self.struct) != literal.value: - return self.visit_true() - else: - return self.visit_false() - - def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> BooleanExpression: - if term.eval(self.struct) in literals: - return self.visit_true() - else: - return self.visit_false() - - def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> BooleanExpression: - if term.eval(self.struct) not in literals: - return self.visit_true() - else: - return self.visit_false() - - def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: - eval_res = term.eval(self.struct) - if eval_res is not None and str(eval_res).startswith(str(literal.value)): - return AlwaysTrue() - else: - return AlwaysFalse() - - def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: - if not self.visit_starts_with(term, literal): - return AlwaysTrue() - else: - return AlwaysFalse() - - def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression: - """ - If there is no strict projection or if it evaluates to false, then return the predicate. - - Get the strict projection and inclusive projection of this predicate in partition data, - then use them to determine whether to return the original predicate. The strict projection - returns true iff the original predicate would have returned true, so the predicate can be - eliminated if the strict projection evaluates to true. Similarly the inclusive projection - returns false iff the original predicate would have returned false, so the predicate can - also be eliminated if the inclusive projection evaluates to false. 
- - """ - parts = self.spec.fields_by_source_id(predicate.term.ref().field.field_id) - if parts == []: - return predicate - - from pyiceberg.types import StructType - - def struct_to_schema(struct: StructType) -> Schema: - return Schema(*[f for f in struct.fields]) - - for part in parts: - strict_projection = part.transform.strict_project(part.name, predicate) - strict_result = None - - if strict_projection is not None: - bound = strict_projection.bind(struct_to_schema(self.spec.partition_type(self.schema))) - if isinstance(bound, BoundPredicate): - strict_result = super().visit_bound_predicate(bound) - else: - strict_result = bound - - if strict_result is not None and isinstance(strict_result, AlwaysTrue): - return AlwaysTrue() - - inclusive_projection = part.transform.project(part.name, predicate) - inclusive_result = None - if inclusive_projection is not None: - bound_inclusive = inclusive_projection.bind(struct_to_schema(self.spec.partition_type(self.schema))) - if isinstance(bound_inclusive, BoundPredicate): - # using predicate method specific to inclusive - inclusive_result = super().visit_bound_predicate(bound_inclusive) - else: - # if the result is not a predicate, then it must be a constant like alwaysTrue or - # alwaysFalse - inclusive_result = bound_inclusive - if inclusive_result is not None and isinstance(inclusive_result, AlwaysFalse): - return AlwaysFalse() - - return predicate - - def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: - bound = predicate.bind(self.schema, case_sensitive=True) - - if isinstance(bound, BoundPredicate): - bound_residual = self.visit_bound_predicate(predicate=bound) - # if isinstance(bound_residual, BooleanExpression): - if bound_residual not in (AlwaysFalse(), AlwaysTrue()): - # replace inclusive original unbound predicate - return predicate - - # use the non-predicate residual (e.g. 
alwaysTrue) - return bound_residual - - # if binding didn't result in a Predicate, return the expression - return bound - - -class ResidualEvaluator(ResidualVisitor): - def residual_for(self, partition_data: Record) -> BooleanExpression: - return self.eval(partition_data) - - -class UnpartitionedResidualEvaluator(ResidualEvaluator): - # Finds the residuals for an Expression the partitions in the given PartitionSpec - def __init__(self, schema: Schema, expr: BooleanExpression): - from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC - - super().__init__(schema=schema, spec=UNPARTITIONED_PARTITION_SPEC, expr=expr, case_sensitive=False) - self.expr = expr - - def residual_for(self, partition_data: Record) -> BooleanExpression: - return self.expr - - -def residual_evaluator_of( - spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema -) -> ResidualEvaluator: - if len(spec.fields) != 0: - return ResidualEvaluator(spec=spec, expr=expr, schema=schema, case_sensitive=case_sensitive) - else: - return UnpartitionedResidualEvaluator(schema=schema, expr=expr) diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py index 768878b068..a5e931f294 100644 --- a/pyiceberg/expressions/visitors.py +++ b/pyiceberg/expressions/visitors.py @@ -62,7 +62,7 @@ from pyiceberg.manifest import DataFile, ManifestFile, PartitionFieldSummary from pyiceberg.partitioning import PartitionSpec from pyiceberg.schema import Schema -from pyiceberg.typedef import EMPTY_DICT, L, StructProtocol +from pyiceberg.typedef import EMPTY_DICT, L, Record, StructProtocol from pyiceberg.types import ( DoubleType, FloatType, @@ -1731,3 +1731,214 @@ def _can_contain_nulls(self, field_id: int) -> bool: def _can_contain_nans(self, field_id: int) -> bool: return (nan_count := self.nan_counts.get(field_id)) is not None and nan_count > 0 + + +class ResidualVisitor(BoundBooleanExpressionVisitor[BooleanExpression], ABC): + schema: Schema + spec: PartitionSpec + case_sensitive: bool + + def __init__(self, schema: Schema, spec: PartitionSpec, case_sensitive: bool, expr: BooleanExpression): + self.schema = schema + self.spec = spec + self.case_sensitive = case_sensitive + self.expr = expr + + def eval(self, partition_data: Record) -> BooleanExpression: + self.struct = partition_data + return visit(self.expr, visitor=self) + + def visit_true(self) -> BooleanExpression: + return AlwaysTrue() + + def visit_false(self) -> BooleanExpression: + return AlwaysFalse() + + def visit_not(self, child_result: BooleanExpression) -> BooleanExpression: + return Not(child_result) + + def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: + return And(left_result, right_result) + + def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression: + return Or(left_result, right_result) + + def visit_is_null(self, term: BoundTerm[L]) -> BooleanExpression: + if term.eval(self.struct) is None: + return AlwaysTrue() + else: + return AlwaysFalse() + + def visit_not_null(self, term: BoundTerm[L]) -> BooleanExpression: + if term.eval(self.struct) is not None: + return AlwaysTrue() + else: + return AlwaysFalse() + + def visit_is_nan(self, term: BoundTerm[L]) -> BooleanExpression: + val = term.eval(self.struct) + if val is None: + return self.visit_true() + else: + return self.visit_false() + + def visit_not_nan(self, term: BoundTerm[L]) -> BooleanExpression: + val = term.eval(self.struct) + if val is not None: + return 
self.visit_true() + else: + return self.visit_false() + + def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if term.eval(self.struct) < literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if term.eval(self.struct) <= literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if term.eval(self.struct) > literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if term.eval(self.struct) >= literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if term.eval(self.struct) == literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if term.eval(self.struct) != literal.value: + return self.visit_true() + else: + return self.visit_false() + + def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> BooleanExpression: + if term.eval(self.struct) in literals: + return self.visit_true() + else: + return self.visit_false() + + def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> BooleanExpression: + if term.eval(self.struct) not in literals: + return self.visit_true() + else: + return self.visit_false() + + def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + eval_res = term.eval(self.struct) + if eval_res is not None and str(eval_res).startswith(str(literal.value)): + return AlwaysTrue() + else: + return AlwaysFalse() + + def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression: + if not self.visit_starts_with(term, literal): + return AlwaysTrue() + else: + return AlwaysFalse() + + def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression: + """ + If there is no strict projection or if it evaluates to false, then return the predicate. + + Get the strict projection and inclusive projection of this predicate in partition data, + then use them to determine whether to return the original predicate. The strict projection + returns true iff the original predicate would have returned true, so the predicate can be + eliminated if the strict projection evaluates to true. Similarly the inclusive projection + returns false iff the original predicate would have returned false, so the predicate can + also be eliminated if the inclusive projection evaluates to false. 
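+
+        For example, with a day(ts) partition and the filter ts > '2019-01-01T00:00:00', a file
+        in partition 2019-01-02 satisfies the strict projection, so its residual is AlwaysTrue
+        and its rows need no further filtering; for the 2019-01-01 partition the residual is the
+        original predicate, which must still be applied row by row.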
+ + """ + parts = self.spec.fields_by_source_id(predicate.term.ref().field.field_id) + if parts == []: + return predicate + + from pyiceberg.types import StructType + + def struct_to_schema(struct: StructType) -> Schema: + return Schema(*list(struct.fields)) + + for part in parts: + strict_projection = part.transform.strict_project(part.name, predicate) + strict_result = None + + if strict_projection is not None: + bound = strict_projection.bind(struct_to_schema(self.spec.partition_type(self.schema))) + if isinstance(bound, BoundPredicate): + strict_result = super().visit_bound_predicate(bound) + else: + strict_result = bound + + if strict_result is not None and isinstance(strict_result, AlwaysTrue): + return AlwaysTrue() + + inclusive_projection = part.transform.project(part.name, predicate) + inclusive_result = None + if inclusive_projection is not None: + bound_inclusive = inclusive_projection.bind(struct_to_schema(self.spec.partition_type(self.schema))) + if isinstance(bound_inclusive, BoundPredicate): + # using predicate method specific to inclusive + inclusive_result = super().visit_bound_predicate(bound_inclusive) + else: + # if the result is not a predicate, then it must be a constant like alwaysTrue or + # alwaysFalse + inclusive_result = bound_inclusive + if inclusive_result is not None and isinstance(inclusive_result, AlwaysFalse): + return AlwaysFalse() + + return predicate + + def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression: + bound = predicate.bind(self.schema, case_sensitive=True) + + if isinstance(bound, BoundPredicate): + bound_residual = self.visit_bound_predicate(predicate=bound) + # if isinstance(bound_residual, BooleanExpression): + if bound_residual not in (AlwaysFalse(), AlwaysTrue()): + # replace inclusive original unbound predicate + return predicate + + # use the non-predicate residual (e.g. 
alwaysTrue)
+            return bound_residual
+
+        # if binding didn't result in a Predicate, return the expression
+        return bound
+
+
+class ResidualEvaluator(ResidualVisitor):
+    def residual_for(self, partition_data: Record) -> BooleanExpression:
+        return self.eval(partition_data)
+
+
+class UnpartitionedResidualEvaluator(ResidualEvaluator):
+    # Finds the residuals for an Expression over the partitions in the given PartitionSpec
+    def __init__(self, schema: Schema, expr: BooleanExpression):
+        from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
+
+        super().__init__(schema=schema, spec=UNPARTITIONED_PARTITION_SPEC, expr=expr, case_sensitive=False)
+        self.expr = expr
+
+    def residual_for(self, partition_data: Record) -> BooleanExpression:
+        return self.expr
+
+
+def residual_evaluator_of(
+    spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema
+) -> ResidualEvaluator:
+    if len(spec.fields) != 0:
+        return ResidualEvaluator(spec=spec, expr=expr, schema=schema, case_sensitive=case_sensitive)
+    else:
+        return UnpartitionedResidualEvaluator(schema=schema, expr=expr)
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 5e4ffb0d0d..2a51d7d5cb 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -1354,27 +1354,29 @@ def __init__(
         delete_files: Optional[Set[DataFile]] = None,
         start: Optional[int] = None,
         length: Optional[int] = None,
-        residual: BooleanExpression = None
+        residual: Optional[BooleanExpression] = None,
     ) -> None:
         self.file = data_file
         self.delete_files = delete_files or set()
         self.start = start or 0
         self.length = length or data_file.file_size_in_bytes
-        self.residual = residual
+        self.residual = residual  # type: ignore
+
 
 def _open_manifest(
     io: FileIO,
     manifest: ManifestFile,
     partition_filter: Callable[[DataFile], bool],
+    residual_evaluator: Callable[[Record], BooleanExpression],
     metrics_evaluator: Callable[[DataFile], bool],
-) -> List[ManifestEntry]:
+) -> List[tuple[ManifestEntry, BooleanExpression]]:
     """Open a manifest file and return matching manifest entries.
 
     Returns:
         A list of ManifestEntry that matches the provided filters.
     """
     return [
-        manifest_entry
+        (manifest_entry, residual_evaluator(manifest_entry.data_file.partition))
         for manifest_entry in manifest.fetch_manifest_entry(io, discard_deleted=True)
         if partition_filter(manifest_entry.data_file) and metrics_evaluator(manifest_entry.data_file)
     ]
@@ -1441,6 +1443,27 @@ def _build_partition_evaluator(self, spec_id: int) -> Callable[[DataFile], bool]
         # shared instance across multiple threads.
         return lambda data_file: expression_evaluator(partition_schema, partition_expr, self.case_sensitive)(data_file.partition)
 
+    from pyiceberg.expressions.visitors import ResidualEvaluator
+
+    def _build_residual_evaluator(self, spec_id: int) -> Callable[[DataFile], ResidualEvaluator]:
+        spec = self.table_metadata.specs()[spec_id]
+
+        # The lambda created here is run in multiple threads.
+        # So we avoid creating _EvaluatorExpression methods bound to a single
+        # shared instance across multiple threads.
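+        # Each evaluator produced here maps a data file's partition Record to the residual of
+        # row_filter: AlwaysTrue when the partition alone satisfies the filter, AlwaysFalse when
+        # it rules the file out, or the predicate that must still be applied to individual rows.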
+ # return lambda data_file: (partition_schema, partition_expr, self.case_sensitive)(data_file.partition) + from pyiceberg.expressions.visitors import residual_evaluator_of + + # assert self.row_filter == False + return lambda datafile: ( + residual_evaluator_of( + spec=spec, + expr=self.row_filter, + case_sensitive=self.case_sensitive, + schema=self.table_metadata.schema(), + ) + ) + def _check_sequence_number(self, min_sequence_number: int, manifest: ManifestFile) -> bool: """Ensure that no manifests are loaded that contain deletes that are older than the data. @@ -1471,6 +1494,9 @@ def plan_files(self) -> Iterable[FileScanTask]: # the filter depends on the partition spec used to write the manifest file, so create a cache of filters for each spec id manifest_evaluators: Dict[int, Callable[[ManifestFile], bool]] = KeyDefaultDict(self._build_manifest_evaluator) + from pyiceberg.expressions.visitors import ResidualEvaluator + + residual_evaluators: Dict[int, Callable[[DataFile], ResidualEvaluator]] = KeyDefaultDict(self._build_residual_evaluator) manifests = [ manifest_file @@ -1482,6 +1508,7 @@ def plan_files(self) -> Iterable[FileScanTask]: # this filter depends on the partition spec used to write the manifest file partition_evaluators: Dict[int, Callable[[DataFile], bool]] = KeyDefaultDict(self._build_partition_evaluator) + metrics_evaluator = _InclusiveMetricsEvaluator( self.table_metadata.schema(), self.row_filter, @@ -1491,11 +1518,11 @@ def plan_files(self) -> Iterable[FileScanTask]: min_sequence_number = _min_sequence_number(manifests) - data_entries: List[ManifestEntry] = [] + data_entries: List[tuple[ManifestEntry, BooleanExpression]] = [] positional_delete_entries = SortedList(key=lambda entry: entry.sequence_number or INITIAL_SEQUENCE_NUMBER) executor = ExecutorFactory.get_or_create() - for manifest_entry in chain( + for manifest_entry, residual in chain( *executor.map( lambda args: _open_manifest(*args), [ @@ -1503,6 +1530,7 @@ def plan_files(self) -> Iterable[FileScanTask]: self.io, manifest, partition_evaluators[manifest.partition_spec_id], + residual_evaluators[manifest.partition_spec_id], metrics_evaluator, ) for manifest in manifests @@ -1512,7 +1540,7 @@ def plan_files(self) -> Iterable[FileScanTask]: ): data_file = manifest_entry.data_file if data_file.content == DataFileContent.DATA: - data_entries.append(manifest_entry) + data_entries.append((manifest_entry, residual)) elif data_file.content == DataFileContent.POSITION_DELETES: positional_delete_entries.add(manifest_entry) elif data_file.content == DataFileContent.EQUALITY_DELETES: @@ -1520,25 +1548,16 @@ def plan_files(self) -> Iterable[FileScanTask]: else: raise ValueError(f"Unknown DataFileContent ({data_file.content}): {manifest_entry}") - - - from pyiceberg.expressions.residual_evaluator import residual_evaluator_of - residual_evaluator = residual_evaluator_of( - spec=self.table_metadata.spec(), - expr=self.row_filter, - case_sensitive=self.case_sensitive, - schema=self.table_metadata.schema() - ) return [ FileScanTask( - data_file=data_entry.data_file, + data_entry.data_file, delete_files=_match_deletes_to_data_file( data_entry, positional_delete_entries, ), - residual=residual_evaluator.residual_for(data_entry.data_file.partition) + residual=residual, ) - for data_entry in data_entries + for data_entry, residual in data_entries ] def to_arrow(self) -> pa.Table: @@ -1612,26 +1631,40 @@ def to_ray(self) -> ray.data.dataset.Dataset: return ray.data.from_arrow(self.to_arrow()) def count(self) -> int: - """ - Usage: 
calutates the total number of records in a Scan that haven't had positional deletes - """ + # Usage: Calculates the total number of records in a Scan that haven't had positional deletes. res = 0 # every task is a FileScanTask tasks = self.plan_files() for task in tasks: - # task.residual is a Boolean Expression if the fiter condition is fully satisfied by the + # task.residual is a Boolean Expression if the filter condition is fully satisfied by the # partition value and task.delete_files represents that positional delete haven't been merged yet # hence those files have to read as a pyarrow table applying the filter and deletes if task.residual == AlwaysTrue() and not len(task.delete_files): # Every File has a metadata stat that stores the file record count res += task.file.record_count else: - from pyiceberg.io.pyarrow import ArrowScan - tbl = ArrowScan( - self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit - ).to_table([task]) - res += len(tbl) + from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow + + arrow_scan = ArrowScan( + table_metadata=self.table_metadata, + io=self.io, + projected_schema=self.projection(), + row_filter=self.row_filter, + case_sensitive=self.case_sensitive, + limit=self.limit, + ) + if task.file.file_size_in_bytes > 512 * 1024 * 1024: + target_schema = schema_to_pyarrow(self.projection()) + batches = arrow_scan.to_record_batches([task]) + from pyarrow import RecordBatchReader + + reader = RecordBatchReader.from_batches(target_schema, batches) + for batch in reader: + res += batch.num_rows + else: + tbl = arrow_scan.to_table([task]) + res += len(tbl) return res diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py index 45e18a23f2..cffc14d9d7 100644 --- a/tests/catalog/test_sql.py +++ b/tests/catalog/test_sql.py @@ -1362,58 +1362,6 @@ def test_append_table(catalog: SqlCatalog, table_schema_simple: Schema, table_id assert df == table.scan().to_arrow() -@pytest.mark.parametrize( - "catalog", - [ - lazy_fixture("catalog_memory"), - lazy_fixture("catalog_sqlite"), - lazy_fixture("catalog_sqlite_without_rowcount"), - lazy_fixture("catalog_sqlite_fsspec"), - ], -) -@pytest.mark.parametrize( - "table_identifier", - [ - lazy_fixture("random_table_identifier"), - lazy_fixture("random_hierarchical_identifier"), - lazy_fixture("random_table_identifier_with_catalog"), - ], -) -def test_count_table(catalog: SqlCatalog, table_schema_simple: Schema, table_identifier: Identifier) -> None: - table_identifier_nocatalog = catalog._identifier_to_tuple_without_catalog(table_identifier) - namespace = Catalog.namespace_from(table_identifier_nocatalog) - catalog.create_namespace(namespace) - table = catalog.create_table(table_identifier, table_schema_simple) - - df = pa.Table.from_pydict( - { - "foo": ["a"], - "bar": [1], - "baz": [True], - }, - schema=schema_to_pyarrow(table_schema_simple), - ) - - table.append(df) - - # new snapshot is written in APPEND mode - assert len(table.metadata.snapshots) == 1 - assert table.metadata.snapshots[0].snapshot_id == table.metadata.current_snapshot_id - assert table.metadata.snapshots[0].parent_snapshot_id is None - assert table.metadata.snapshots[0].sequence_number == 1 - assert table.metadata.snapshots[0].summary is not None - assert table.metadata.snapshots[0].summary.operation == Operation.APPEND - assert table.metadata.snapshots[0].summary["added-data-files"] == "1" - assert table.metadata.snapshots[0].summary["added-records"] == "1" - assert 
table.metadata.snapshots[0].summary["total-data-files"] == "1" - assert table.metadata.snapshots[0].summary["total-records"] == "1" - assert len(table.metadata.metadata_log) == 1 - - # read back the data - assert df == table.scan().to_arrow() - assert len(table.scan().to_arrow()) == table.scan().count() - - @pytest.mark.parametrize( "catalog", [ diff --git a/tests/expressions/test_residual_evaluator.py b/tests/expressions/test_residual_evaluator.py index c7210eaf01..60588d8d9c 100644 --- a/tests/expressions/test_residual_evaluator.py +++ b/tests/expressions/test_residual_evaluator.py @@ -34,10 +34,9 @@ NotStartsWith, Or, StartsWith, - UnboundPredicate, ) from pyiceberg.expressions.literals import literal -from pyiceberg.expressions.residual_evaluator import residual_evaluator_of +from pyiceberg.expressions.visitors import residual_evaluator_of from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.transforms import DayTransform, IdentityTransform @@ -62,18 +61,23 @@ def test_identity_transform_residual() -> None: residual = res_eval.residual_for(Record(dateint=20170815)) # assert residual == True - assert isinstance(residual, UnboundPredicate) - assert residual.term.name == "hour" + assert isinstance(residual, LessThan) + assert residual.term.name == "hour" # type: ignore # assert residual.term.field.name == 'hour' assert residual.literal.value == 12 - assert type(residual) == LessThan + assert type(residual) is LessThan residual = res_eval.residual_for(Record(dateint=20170801)) - assert isinstance(residual, UnboundPredicate) - assert residual.term.name == "hour" - assert residual.literal.value == 11 - assert type(residual) == GreaterThan + # assert isinstance(residual, UnboundPredicate) + from pyiceberg.expressions import LiteralPredicate + + assert isinstance(residual, LiteralPredicate) + # assert isinstance(residual, GreaterThan) + assert residual.term.name == "hour" # type: ignore + # assert residual.term. 
+    assert residual.literal.value == 11  # type: ignore
+    # assert type(residual) == BoundGreaterThan
 
     residual = res_eval.residual_for(Record(dateint=20170812))
 
@@ -99,7 +103,7 @@ def test_case_insensitive_identity_transform_residuals() -> None:
     res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema)
 
     with pytest.raises(ValueError) as e:
-        residual = res_eval.residual_for(Record(dateint=20170815))
+        res_eval.residual_for(Record(dateint=20170815))
 
     assert "Could not find field with name DATEINT, case_sensitive=True" in str(e.value)
 
@@ -152,10 +156,7 @@ def test_in_timestamp() -> None:
     date_20191202 = literal("2019-12-02T00:00:00").to(TimestampType()).value
     day = DayTransform().transform(TimestampType())
 
-    # assert date_20191201 == True
-    ts_day = day(date_20191201)
-
-    # assert ts_day == True
+    ts_day = day(date_20191201)  # type: ignore
 
     pred = In("ts", [date_20191202, date_20191201])
 
@@ -164,7 +165,7 @@ def test_in_timestamp() -> None:
     residual = res_eval.residual_for(Record(ts_day))
     assert residual == pred
 
-    residual = res_eval.residual_for(Record(ts_day + 3))
+    residual = res_eval.residual_for(Record(ts_day + 3))  # type: ignore
     assert residual == AlwaysFalse()
 
 
@@ -237,10 +238,7 @@ def test_not_in_timestamp() -> None:
     date_20191202 = literal("2019-12-02T00:00:00").to(TimestampType()).value
     day = DayTransform().transform(TimestampType())
 
-    # assert date_20191201 == True
-    ts_day = day(date_20191201)
-
-    # assert ts_day == True
+    ts_day = day(date_20191201)  # type: ignore
 
     pred = NotIn("ts", [date_20191202, date_20191201])
 
@@ -248,6 +246,6 @@ def test_not_in_timestamp() -> None:
     residual = res_eval.residual_for(Record(ts_day))
     assert residual == pred
-
-    residual = res_eval.residual_for(Record(ts_day + 3))
+    ts_day += 3  # type: ignore
+    residual = res_eval.residual_for(Record(ts_day))
     assert residual == AlwaysTrue()
diff --git a/tests/integration/test_delete_count.py b/tests/integration/test_delete_count.py
index 781f0513c6..0ba9d2d6da 100644
--- a/tests/integration/test_delete_count.py
+++ b/tests/integration/test_delete_count.py
@@ -15,23 +15,23 @@
 # specific language governing permissions and limitations
 # under the License.
# pylint:disable=redefined-outer-name -from datetime import datetime +import random +from datetime import datetime, timedelta from typing import Generator, List import pyarrow as pa import pytest +from pyarrow import compute as pc from pyspark.sql import SparkSession +from pyiceberg.catalog import Catalog from pyiceberg.catalog.rest import RestCatalog from pyiceberg.exceptions import NoSuchTableError -from pyiceberg.expressions import AlwaysTrue, EqualTo -from pyiceberg.manifest import ManifestEntryStatus -from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.expressions import And, EqualTo, GreaterThanOrEqual, LessThan from pyiceberg.schema import Schema from pyiceberg.table import Table -from pyiceberg.table.snapshots import Operation, Summary -from pyiceberg.transforms import IdentityTransform -from pyiceberg.types import FloatType, IntegerType, LongType, NestedField, StringType, TimestampType +from pyiceberg.transforms import HourTransform, IdentityTransform +from pyiceberg.types import LongType, NestedField, StringType def run_spark_commands(spark: SparkSession, sqls: List[str]) -> None: @@ -67,8 +67,8 @@ def test_partitioned_table_delete_full_file(spark: SparkSession, session_catalog f"DROP TABLE IF EXISTS {identifier}", f""" CREATE TABLE {identifier} ( - number_partitioned int, - number int + number_partitioned long, + number long ) USING iceberg PARTITIONED BY (number_partitioned) @@ -91,99 +91,69 @@ def test_partitioned_table_delete_full_file(spark: SparkSession, session_catalog assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [11, 11], "number": [20, 30]} assert tbl.scan().count() == len(tbl.scan().to_arrow()) + filter = And(EqualTo("number_partitioned", 11), GreaterThanOrEqual("number", 5)) + assert tbl.scan(filter).count() == len(tbl.scan(filter).to_arrow()) + N = 10 + d = { + "number_partitioned": pa.array([i * 10 for i in range(N)]), + "number": pa.array([random.choice([10, 20, 40]) for _ in range(N)]), + } + with tbl.update_spec() as update: + update.add_field("number", transform=IdentityTransform()) + data = pa.Table.from_pydict(d) -@pytest.mark.integration -@pytest.mark.filterwarnings("ignore:Merge on read is not yet supported, falling back to copy-on-write") -def test_delete_partitioned_table_positional_deletes(spark: SparkSession, session_catalog: RestCatalog) -> None: - identifier = "default.table_partitioned_delete" + tbl.overwrite(df=data, overwrite_filter=filter) - run_spark_commands( - spark, - [ - f"DROP TABLE IF EXISTS {identifier}", - f""" - CREATE TABLE {identifier} ( - number_partitioned int, - number int - ) - USING iceberg - PARTITIONED BY (number_partitioned) - TBLPROPERTIES( - 'format-version' = 2, - 'write.delete.mode'='merge-on-read', - 'write.update.mode'='merge-on-read', - 'write.merge.mode'='merge-on-read' - ) - """, - f""" - INSERT INTO {identifier} VALUES (10, 20), (10, 30), (10, 40) - """, - # Generate a positional delete - f""" - DELETE FROM {identifier} WHERE number = 30 - """, - ], - ) - - tbl = session_catalog.load_table(identifier) - - assert tbl.scan().count() == len(tbl.scan().to_arrow()) - - # Will rewrite a data file without the positional delete - tbl.delete(EqualTo("number", 40)) - assert [snapshot.summary.operation.value for snapshot in tbl.snapshots()] == ["append", "overwrite", "overwrite"] - assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [10], "number": [20]} - - assert tbl.scan().count() == len(tbl.scan().to_arrow()) - - run_spark_commands( - spark, - [ - f""" - INSERT 
INTO {identifier} VALUES (10, 20), (10, 30), (10, 40), (20, 30) - """, - # Generate a positional delete - f""" - DELETE FROM {identifier} WHERE number = 30 - """, - ], +@pytest.mark.integration +def test_rewrite_manifest_after_partition_evolution(session_catalog: Catalog) -> None: + random.seed(876) + N = 1440 + d = { + "timestamp": pa.array([datetime(2023, 1, 1, 0, 0, 0) + timedelta(minutes=i) for i in range(N)]), + "category": pa.array([random.choice(["A", "B", "C"]) for _ in range(N)]), + "value": pa.array([random.gauss(0, 1) for _ in range(N)]), + } + data = pa.Table.from_pydict(d) + + try: + session_catalog.drop_table( + identifier="default.test_error_table", + ) + except NoSuchTableError: + pass + + table = session_catalog.create_table( + "default.test_error_table", + schema=data.schema, ) + with table.update_spec() as update: + update.add_field("timestamp", transform=HourTransform()) - tbl = session_catalog.load_table(identifier) - - assert tbl.scan().count() == len(tbl.scan().to_arrow()) + table.append(data) + assert table.scan().count() == len(table.scan().to_arrow()) + with table.update_spec() as update: + update.add_field("category", transform=IdentityTransform()) - run_spark_commands( - spark, - [ - # Generate a positional delete - f""" - DELETE FROM {identifier} WHERE number = 30 - """, - f""" - INSERT INTO {identifier} VALUES (10, 20), (10, 30), (10, 40), (20, 30) - """, - # Generate a positional delete - f""" - DELETE FROM {identifier} WHERE number = 20 - """, - ], + data_ = data.filter( + (pc.field("category") == "A") + & (pc.field("timestamp") >= datetime(2023, 1, 1, 0)) + & (pc.field("timestamp") < datetime(2023, 1, 1, 1)) ) - tbl = session_catalog.load_table(identifier) - - assert tbl.scan().count() == len(tbl.scan().to_arrow()) - - - filter_on_partition = "number_partitioned = 10" - scan_on_partition = tbl.scan(row_filter=filter_on_partition) - assert scan_on_partition.count() == len(scan_on_partition.to_arrow()) - - - filter = "number = 10" - scan = tbl.scan(row_filter=filter) - assert scan.count() == len(scan.to_arrow()) - + filter = And( + And( + GreaterThanOrEqual("timestamp", datetime(2023, 1, 1, 0).isoformat()), + LessThan("timestamp", datetime(2023, 1, 1, 1).isoformat()), + ), + EqualTo("category", "A"), + ) + # filter = GreaterThanOrEqual("timestamp", datetime(2023, 1, 1, 0).isoformat()) + # filter = LessThan("timestamp", datetime(2023, 1, 1, 1).isoformat()) + # filter = EqualTo("category", "A") + # assert table.scan().plan_files()[0].file.partition == {"category": "A"} + assert table.scan().count() == len(table.scan().to_arrow()) + assert table.scan(filter).count() == len(table.scan(filter).to_arrow()) + table.overwrite(df=data_, overwrite_filter=filter)
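
The behaviour the series converges on can be summarized from the caller's side. A minimal usage
sketch, assuming a configured catalog named "default" and an existing table "default.events"
(both names are illustrative, not taken from the patches):

    from pyiceberg.catalog import load_catalog

    # Illustrative catalog and table names; any existing Iceberg table works the same way.
    catalog = load_catalog("default")
    tbl = catalog.load_table("default.events")

    # Unfiltered scan of an append-only table: every FileScanTask carries
    # residual == AlwaysTrue() and no delete files, so count() sums
    # task.file.record_count from the manifests and reads no data files.
    assert tbl.scan().count() == len(tbl.scan().to_arrow())

    # Filtered scan: files whose partition fully satisfies the predicate are still
    # counted from metadata; only files with a non-trivial residual or pending
    # positional deletes are materialized through ArrowScan and counted row by row.
    scan = tbl.scan(row_filter="number_partitioned = 10")
    assert scan.count() == len(scan.to_arrow())

The trade-off is that count() never does more work than to_arrow() would: a task either
contributes its manifest record_count for free, or falls back to the same ArrowScan read used by
a normal scan, streamed in record batches for files larger than 512 MB.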