diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py
index 768878b068..a5e931f294 100644
--- a/pyiceberg/expressions/visitors.py
+++ b/pyiceberg/expressions/visitors.py
@@ -62,7 +62,7 @@
 from pyiceberg.manifest import DataFile, ManifestFile, PartitionFieldSummary
 from pyiceberg.partitioning import PartitionSpec
 from pyiceberg.schema import Schema
-from pyiceberg.typedef import EMPTY_DICT, L, StructProtocol
+from pyiceberg.typedef import EMPTY_DICT, L, Record, StructProtocol
 from pyiceberg.types import (
     DoubleType,
     FloatType,
@@ -1731,3 +1731,214 @@ def _can_contain_nulls(self, field_id: int) -> bool:
 
     def _can_contain_nans(self, field_id: int) -> bool:
         return (nan_count := self.nan_counts.get(field_id)) is not None and nan_count > 0
+
+
+class ResidualVisitor(BoundBooleanExpressionVisitor[BooleanExpression], ABC):
+    """Evaluates an expression against partition values and returns the residual.
+
+    The residual is the part of the expression that still has to be evaluated
+    against individual rows once the partition values are known: AlwaysTrue()
+    if the partition guarantees a match, AlwaysFalse() if it rules one out,
+    or a smaller predicate otherwise.
+    """
+
+    schema: Schema
+    spec: PartitionSpec
+    case_sensitive: bool
+
+    def __init__(self, schema: Schema, spec: PartitionSpec, case_sensitive: bool, expr: BooleanExpression):
+        self.schema = schema
+        self.spec = spec
+        self.case_sensitive = case_sensitive
+        self.expr = expr
+
+    def eval(self, partition_data: Record) -> BooleanExpression:
+        self.struct = partition_data
+        return visit(self.expr, visitor=self)
+
+    def visit_true(self) -> BooleanExpression:
+        return AlwaysTrue()
+
+    def visit_false(self) -> BooleanExpression:
+        return AlwaysFalse()
+
+    def visit_not(self, child_result: BooleanExpression) -> BooleanExpression:
+        return Not(child_result)
+
+    def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression:
+        return And(left_result, right_result)
+
+    def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression:
+        return Or(left_result, right_result)
+
+    def visit_is_null(self, term: BoundTerm[L]) -> BooleanExpression:
+        if term.eval(self.struct) is None:
+            return AlwaysTrue()
+        else:
+            return AlwaysFalse()
+
+    def visit_not_null(self, term: BoundTerm[L]) -> BooleanExpression:
+        if term.eval(self.struct) is not None:
+            return AlwaysTrue()
+        else:
+            return AlwaysFalse()
+
+    def visit_is_nan(self, term: BoundTerm[L]) -> BooleanExpression:
+        val = term.eval(self.struct)
+        if val is None:
+            return self.visit_true()
+        else:
+            return self.visit_false()
+
+    def visit_not_nan(self, term: BoundTerm[L]) -> BooleanExpression:
+        val = term.eval(self.struct)
+        if val is not None:
+            return self.visit_true()
+        else:
+            return self.visit_false()
+
+    def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
+        if term.eval(self.struct) < literal.value:
+            return self.visit_true()
+        else:
+            return self.visit_false()
+
+    def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
+        if term.eval(self.struct) <= literal.value:
+            return self.visit_true()
+        else:
+            return self.visit_false()
+
+    def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
+        if term.eval(self.struct) > literal.value:
+            return self.visit_true()
+        else:
+            return self.visit_false()
+
+    def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
+        if term.eval(self.struct) >= literal.value:
+            return self.visit_true()
+        else:
+            return self.visit_false()
+
+    def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
+        if term.eval(self.struct) == literal.value:
+            return self.visit_true()
+        else:
+            return self.visit_false()
+
+    def visit_not_equal(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
+        if term.eval(self.struct) != literal.value:
+            return self.visit_true()
+        else:
+            return self.visit_false()
+
+    def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> BooleanExpression:
+        if term.eval(self.struct) in literals:
+            return self.visit_true()
+        else:
+            return self.visit_false()
+
+    def visit_not_in(self, term: BoundTerm[L], literals: Set[L]) -> BooleanExpression:
+        if term.eval(self.struct) not in literals:
+            return self.visit_true()
+        else:
+            return self.visit_false()
+
+    def visit_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
+        eval_res = term.eval(self.struct)
+        if eval_res is not None and str(eval_res).startswith(str(literal.value)):
+            return AlwaysTrue()
+        else:
+            return AlwaysFalse()
+
+    def visit_not_starts_with(self, term: BoundTerm[L], literal: Literal[L]) -> BooleanExpression:
+        # visit_starts_with returns an expression object, which is always truthy,
+        # so compare against AlwaysFalse() explicitly instead of using `not`
+        if self.visit_starts_with(term, literal) == AlwaysFalse():
+            return AlwaysTrue()
+        else:
+            return AlwaysFalse()
+
+    def visit_bound_predicate(self, predicate: BoundPredicate[Any]) -> BooleanExpression:
+        """
+        If there is no strict projection or if it evaluates to false, then return the predicate.
+
+        Get the strict projection and inclusive projection of this predicate in partition data,
+        then use them to determine whether to return the original predicate. The strict projection
+        returns true iff the original predicate would have returned true, so the predicate can be
+        eliminated if the strict projection evaluates to true. Similarly the inclusive projection
+        returns false iff the original predicate would have returned false, so the predicate can
+        also be eliminated if the inclusive projection evaluates to false.
+        """
+        parts = self.spec.fields_by_source_id(predicate.term.ref().field.field_id)
+        if not parts:
+            return predicate
+
+        from pyiceberg.types import StructType
+
+        def struct_to_schema(struct: StructType) -> Schema:
+            return Schema(*list(struct.fields))
+
+        for part in parts:
+            strict_projection = part.transform.strict_project(part.name, predicate)
+            strict_result = None
+
+            if strict_projection is not None:
+                bound = strict_projection.bind(struct_to_schema(self.spec.partition_type(self.schema)))
+                if isinstance(bound, BoundPredicate):
+                    strict_result = super().visit_bound_predicate(bound)
+                else:
+                    strict_result = bound
+
+            if strict_result is not None and isinstance(strict_result, AlwaysTrue):
+                return AlwaysTrue()
+
+            inclusive_projection = part.transform.project(part.name, predicate)
+            inclusive_result = None
+            if inclusive_projection is not None:
+                bound_inclusive = inclusive_projection.bind(struct_to_schema(self.spec.partition_type(self.schema)))
+                if isinstance(bound_inclusive, BoundPredicate):
+                    # evaluate the inclusive projection against the partition data
+                    inclusive_result = super().visit_bound_predicate(bound_inclusive)
+                else:
+                    # if the result is not a predicate, then it must be a constant
+                    # like AlwaysTrue or AlwaysFalse
+                    inclusive_result = bound_inclusive
+            if inclusive_result is not None and isinstance(inclusive_result, AlwaysFalse):
+                return AlwaysFalse()
+
+        return predicate
+
+    def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression:
+        bound = predicate.bind(self.schema, case_sensitive=self.case_sensitive)
+
+        if isinstance(bound, BoundPredicate):
+            bound_residual = self.visit_bound_predicate(predicate=bound)
+            if bound_residual not in (AlwaysFalse(), AlwaysTrue()):
+                # the predicate could not be eliminated, so return the original unbound predicate
+                return predicate
+
+            # use the non-predicate residual (e.g. AlwaysTrue)
+            return bound_residual
+
+        # if binding didn't result in a Predicate, return the expression
+        return bound
+
+
+class ResidualEvaluator(ResidualVisitor):
+    """Finds the residual of an expression for the partitions in the given PartitionSpec."""
+
+    def residual_for(self, partition_data: Record) -> BooleanExpression:
+        return self.eval(partition_data)
+
+
+class UnpartitionedResidualEvaluator(ResidualEvaluator):
+    """For an unpartitioned table the residual is always the original expression."""
+
+    def __init__(self, schema: Schema, expr: BooleanExpression):
+        from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
+
+        super().__init__(schema=schema, spec=UNPARTITIONED_PARTITION_SPEC, expr=expr, case_sensitive=False)
+
+    def residual_for(self, partition_data: Record) -> BooleanExpression:
+        return self.expr
+
+
+def residual_evaluator_of(
+    spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema
+) -> ResidualEvaluator:
+    if len(spec.fields) != 0:
+        return ResidualEvaluator(spec=spec, expr=expr, schema=schema, case_sensitive=case_sensitive)
+    else:
+        return UnpartitionedResidualEvaluator(schema=schema, expr=expr)
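+
+
+# A minimal usage sketch (illustrative only; the spec and filter values below
+# are hypothetical, not part of this change):
+#
+#     spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=DayTransform(), name="ts_day"))
+#     evaluator = residual_evaluator_of(spec=spec, expr=row_filter, case_sensitive=True, schema=schema)
+#     residual = evaluator.residual_for(partition_record)
+#
+# AlwaysTrue() means every row in the partition matches the filter,
+# AlwaysFalse() means none do, and any other expression must still be
+# evaluated against individual rows.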
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 057c02f260..7eda334383 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -1356,6 +1356,9 @@ def filter(self: S, expr: Union[str, BooleanExpression]) -> S:
     def with_case_sensitive(self: S, case_sensitive: bool = True) -> S:
         return self.update(case_sensitive=case_sensitive)
 
+    @abstractmethod
+    def count(self) -> int: ...
+
 
 class ScanTask(ABC):
     pass
@@ -1369,6 +1372,7 @@ class FileScanTask(ScanTask):
     delete_files: Set[DataFile]
     start: int
     length: int
+    residual: BooleanExpression
 
     def __init__(
         self,
@@ -1376,26 +1380,29 @@ def __init__(
         delete_files: Optional[Set[DataFile]] = None,
         start: Optional[int] = None,
         length: Optional[int] = None,
+        residual: Optional[BooleanExpression] = None,
     ) -> None:
         self.file = data_file
         self.delete_files = delete_files or set()
         self.start = start or 0
         self.length = length or data_file.file_size_in_bytes
+        self.residual = residual  # type: ignore
 
 
 def _open_manifest(
     io: FileIO,
     manifest: ManifestFile,
     partition_filter: Callable[[DataFile], bool],
+    residual_evaluator: Callable[[Record], BooleanExpression],
     metrics_evaluator: Callable[[DataFile], bool],
-) -> List[ManifestEntry]:
+) -> List[Tuple[ManifestEntry, BooleanExpression]]:
     """Open a manifest file and return matching manifest entries.
 
     Returns:
-        A list of ManifestEntry that matches the provided filters.
+        A list of (ManifestEntry, residual expression) tuples for entries that match the provided filters.
     """
     return [
-        manifest_entry
+        (manifest_entry, residual_evaluator(manifest_entry.data_file.partition))
         for manifest_entry in manifest.fetch_manifest_entry(io, discard_deleted=True)
         if partition_filter(manifest_entry.data_file) and metrics_evaluator(manifest_entry.data_file)
     ]
@@ -1462,6 +1469,27 @@ def _build_partition_evaluator(self, spec_id: int) -> Callable[[DataFile], bool]
         # shared instance across multiple threads.
         return lambda data_file: expression_evaluator(partition_schema, partition_expr, self.case_sensitive)(data_file.partition)
 
+    def _build_residual_evaluator(self, spec_id: int) -> Callable[[Record], BooleanExpression]:
+        spec = self.table_metadata.specs()[spec_id]
+        from pyiceberg.expressions.visitors import residual_evaluator_of
+
+        # The lambda created here is run in multiple threads, so build a fresh
+        # evaluator per call instead of sharing one bound instance across threads.
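+        # Illustration (hypothetical values): for a spec partitioned by day(ts)
+        # and a row filter `ts >= '2023-01-01T00:00:00'`, the returned callable
+        # maps a partition Record to the expression that still has to be checked
+        # per row — AlwaysTrue() for partitions the filter fully covers.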
+        return lambda partition: residual_evaluator_of(
+            spec=spec,
+            expr=self.row_filter,
+            case_sensitive=self.case_sensitive,
+            schema=self.table_metadata.schema(),
+        ).residual_for(partition)
+
     def _check_sequence_number(self, min_sequence_number: int, manifest: ManifestFile) -> bool:
         """Ensure that no manifests are loaded that contain deletes that are older than the data.
 
@@ -1492,6 +1520,9 @@ def plan_files(self) -> Iterable[FileScanTask]:
         # the filter depends on the partition spec used to write the manifest file, so create a cache of filters for each spec id
         manifest_evaluators: Dict[int, Callable[[ManifestFile], bool]] = KeyDefaultDict(self._build_manifest_evaluator)
 
+        # likewise, cache one residual evaluator per partition spec id
+        residual_evaluators: Dict[int, Callable[[Record], BooleanExpression]] = KeyDefaultDict(self._build_residual_evaluator)
+
         manifests = [
             manifest_file
@@ -1503,6 +1534,7 @@ def plan_files(self) -> Iterable[FileScanTask]:
 
         # this filter depends on the partition spec used to write the manifest file
         partition_evaluators: Dict[int, Callable[[DataFile], bool]] = KeyDefaultDict(self._build_partition_evaluator)
+
         metrics_evaluator = _InclusiveMetricsEvaluator(
             self.table_metadata.schema(),
             self.row_filter,
@@ -1512,11 +1544,11 @@ def plan_files(self) -> Iterable[FileScanTask]:
 
         min_sequence_number = _min_sequence_number(manifests)
 
-        data_entries: List[ManifestEntry] = []
+        data_entries: List[Tuple[ManifestEntry, BooleanExpression]] = []
         positional_delete_entries = SortedList(key=lambda entry: entry.sequence_number or INITIAL_SEQUENCE_NUMBER)
 
         executor = ExecutorFactory.get_or_create()
-        for manifest_entry in chain(
+        for manifest_entry, residual in chain(
             *executor.map(
                 lambda args: _open_manifest(*args),
                 [
@@ -1524,6 +1556,7 @@ def plan_files(self) -> Iterable[FileScanTask]:
                         self.io,
                         manifest,
                         partition_evaluators[manifest.partition_spec_id],
+                        residual_evaluators[manifest.partition_spec_id],
                         metrics_evaluator,
                     )
                     for manifest in manifests
@@ -1533,7 +1566,7 @@ def plan_files(self) -> Iterable[FileScanTask]:
         ):
             data_file = manifest_entry.data_file
             if data_file.content == DataFileContent.DATA:
-                data_entries.append(manifest_entry)
+                data_entries.append((manifest_entry, residual))
             elif data_file.content == DataFileContent.POSITION_DELETES:
                 positional_delete_entries.add(manifest_entry)
             elif data_file.content == DataFileContent.EQUALITY_DELETES:
@@ -1548,8 +1581,9 @@ def plan_files(self) -> Iterable[FileScanTask]:
                     data_entry,
                     positional_delete_entries,
                 ),
+                residual=residual,
             )
-            for data_entry in data_entries
+            for data_entry, residual in data_entries
         ]
 
     def to_arrow(self) -> pa.Table:
@@ -1622,6 +1656,43 @@ def to_ray(self) -> ray.data.dataset.Dataset:
 
         return ray.data.from_arrow(self.to_arrow())
 
+    def count(self) -> int:
+        """Return the number of records in this scan that match the row filter.
+
+        Files whose residual is AlwaysTrue() and that carry no delete files are
+        counted from file metadata alone; all other files are read to count the
+        rows that survive the filter and the deletes.
+        """
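+        # Illustrative usage (the table name and filter are hypothetical):
+        #
+        #     tbl = catalog.load_table("db.events")
+        #     tbl.scan(row_filter=EqualTo("day", "2023-01-01")).count()
+        #
+        # For filters fully answered by partition values, no data files are read.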
+        res = 0
+        tasks = self.plan_files()
+
+        for task in tasks:
+            # If the residual is AlwaysTrue(), the partition values fully satisfy
+            # the filter; with no delete files attached, the record count can be
+            # taken from the file metadata. Otherwise the file has to be read as
+            # a pyarrow table, applying the filter and the deletes.
+            if task.residual == AlwaysTrue() and not task.delete_files:
+                # every data file's metadata stores its record count
+                res += task.file.record_count
+            else:
+                from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow
+
+                arrow_scan = ArrowScan(
+                    table_metadata=self.table_metadata,
+                    io=self.io,
+                    projected_schema=self.projection(),
+                    row_filter=self.row_filter,
+                    case_sensitive=self.case_sensitive,
+                    limit=self.limit,
+                )
+                if task.file.file_size_in_bytes > 512 * 1024 * 1024:
+                    # stream large files batch by batch instead of materializing
+                    # the whole table in memory
+                    target_schema = schema_to_pyarrow(self.projection())
+                    batches = arrow_scan.to_record_batches([task])
+                    from pyarrow import RecordBatchReader
+
+                    reader = RecordBatchReader.from_batches(target_schema, batches)
+                    for batch in reader:
+                        res += batch.num_rows
+                else:
+                    tbl = arrow_scan.to_table([task])
+                    res += len(tbl)
+        return res
 
 
 @dataclass(frozen=True)
 class WriteTask:
diff --git a/tests/expressions/test_residual_evaluator.py b/tests/expressions/test_residual_evaluator.py
new file mode 100644
index 0000000000..60588d8d9c
--- /dev/null
+++ b/tests/expressions/test_residual_evaluator.py
@@ -0,0 +1,251 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint:disable=redefined-outer-name
+import pytest
+
+from pyiceberg.expressions import (
+    AlwaysFalse,
+    AlwaysTrue,
+    And,
+    EqualTo,
+    GreaterThan,
+    GreaterThanOrEqual,
+    In,
+    IsNaN,
+    IsNull,
+    LessThan,
+    LiteralPredicate,
+    NotIn,
+    NotNaN,
+    NotNull,
+    NotStartsWith,
+    Or,
+    StartsWith,
+)
+from pyiceberg.expressions.literals import literal
+from pyiceberg.expressions.visitors import residual_evaluator_of
+from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
+from pyiceberg.schema import Schema
+from pyiceberg.transforms import DayTransform, IdentityTransform
+from pyiceberg.typedef import Record
+from pyiceberg.types import DoubleType, FloatType, IntegerType, NestedField, TimestampType
+
+
+def test_identity_transform_residual() -> None:
+    schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType()))
+
+    spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part"))
+
+    predicate = Or(
+        Or(
+            And(EqualTo("dateint", 20170815), LessThan("hour", 12)),
+            And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)),
+        ),
+        And(EqualTo("dateint", 20170801), GreaterThan("hour", 11)),
+    )
+    res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema)
+
+    residual = res_eval.residual_for(Record(dateint=20170815))
+
+    # the dateint partition matches exactly, so only the hour predicate remains
+    assert isinstance(residual, LessThan)
+    assert residual.term.name == "hour"  # type: ignore
+    assert residual.literal.value == 12
+
+    residual = res_eval.residual_for(Record(dateint=20170801))
+
+    assert isinstance(residual, LiteralPredicate)
+    assert residual.term.name == "hour"  # type: ignore
+    assert residual.literal.value == 11  # type: ignore
+
+    residual = res_eval.residual_for(Record(dateint=20170812))
+
+    assert residual == AlwaysTrue()
+
+    residual = res_eval.residual_for(Record(dateint=20170817))
+
+    assert residual == AlwaysFalse()
+
+
+def test_case_insensitive_identity_transform_residuals() -> None:
+    schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType()))
+
+    spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part"))
+
+    predicate = Or(
+        Or(
+            And(EqualTo("DATEINT", 20170815), LessThan("HOUR", 12)),
+            And(LessThan("dateint", 20170815), GreaterThan("dateint", 20170801)),
+        ),
+        And(EqualTo("Dateint", 20170801), GreaterThan("hOUr", 11)),
+    )
+    res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema)
+
+    with pytest.raises(ValueError) as e:
+        res_eval.residual_for(Record(dateint=20170815))
+    assert "Could not find field with name DATEINT, case_sensitive=True" in str(e.value)
+
+
+def test_unpartitioned_residuals() -> None:
+    expressions = [
+        AlwaysTrue(),
+        AlwaysFalse(),
+        LessThan("a", 5),
+        GreaterThanOrEqual("b", 16),
+        NotNull("c"),
+        IsNull("d"),
+        In("e", [1, 2, 3]),
+        NotIn("f", [1, 2, 3]),
+        NotNaN("g"),
+        IsNaN("h"),
+        StartsWith("data", "abcd"),
+        NotStartsWith("data", "abcd"),
+    ]
+
+    schema = Schema(
+        NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType()), NestedField(52, "a", IntegerType())
+    )
+    for expr in expressions:
+        residual_evaluator = residual_evaluator_of(UNPARTITIONED_PARTITION_SPEC, expr, True, schema=schema)
+        assert residual_evaluator.residual_for(Record()) == expr
+
+
+def test_in() -> None:
+    schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType()))
+
+    spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part"))
+
+    predicate = In("dateint", [20170815, 20170816, 20170817])
+
+    res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema)
+
+    residual = res_eval.residual_for(Record(dateint=20170815))
+
+    assert residual == AlwaysTrue()
+
+
+def test_in_timestamp() -> None:
+    schema = Schema(NestedField(50, "ts", TimestampType()), NestedField(51, "hour", IntegerType()))
+
+    spec = PartitionSpec(PartitionField(50, 1000, DayTransform(), "ts_part"))
+
+    date_20191201 = literal("2019-12-01T00:00:00").to(TimestampType()).value
+    date_20191202 = literal("2019-12-02T00:00:00").to(TimestampType()).value
+
+    day = DayTransform().transform(TimestampType())
+    ts_day = day(date_20191201)  # type: ignore
+
+    pred = In("ts", [date_20191202, date_20191201])
+
+    res_eval = residual_evaluator_of(spec=spec, expr=pred, case_sensitive=True, schema=schema)
+
+    residual = res_eval.residual_for(Record(ts_day))
+    assert residual == pred
+
+    residual = res_eval.residual_for(Record(ts_day + 3))  # type: ignore
+    assert residual == AlwaysFalse()
+
+
+def test_not_in() -> None:
+    schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType()))
+
+    spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part"))
+
+    predicate = NotIn("dateint", [20170815, 20170816, 20170817])
+
+    res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema)
+
+    residual = res_eval.residual_for(Record(dateint=20180815))
+    assert residual == AlwaysTrue()
+
+    residual = res_eval.residual_for(Record(dateint=20170815))
+    assert residual == AlwaysFalse()
+
+
+def test_is_nan() -> None:
+    schema = Schema(NestedField(50, "double", DoubleType()), NestedField(51, "hour", IntegerType()))
+
+    spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "double_part"))
+
+    predicate = IsNaN("double")
+
+    res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema)
+
+    residual = res_eval.residual_for(Record(double=None))
+    assert residual == AlwaysTrue()
+
+    residual = res_eval.residual_for(Record(double=2))
+    assert residual == AlwaysFalse()
+
+
+def test_is_not_nan() -> None:
+    schema = Schema(NestedField(50, "double", DoubleType()), NestedField(51, "float", FloatType()))
+
+    spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "double_part"))
+
+    predicate = NotNaN("double")
+
+    res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema)
+
+    residual = res_eval.residual_for(Record(double=None))
+    assert residual == AlwaysFalse()
+
+    residual = res_eval.residual_for(Record(double=2))
+    assert residual == AlwaysTrue()
+
+    spec = PartitionSpec(PartitionField(51, 1051, IdentityTransform(), "float_part"))
+
+    predicate = NotNaN("float")
+
+    res_eval = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema)
+
+    residual = res_eval.residual_for(Record(float=None))
+    assert residual == AlwaysFalse()
+
+    residual = res_eval.residual_for(Record(float=2))
+    assert residual == AlwaysTrue()
+
+
+def test_not_in_timestamp() -> None:
+    schema = Schema(NestedField(50, "ts", TimestampType()), NestedField(51, "dateint", IntegerType()))
+
+    spec = PartitionSpec(PartitionField(50, 1000, DayTransform(), "ts_part"))
+
+    date_20191201 = literal("2019-12-01T00:00:00").to(TimestampType()).value
+    date_20191202 = literal("2019-12-02T00:00:00").to(TimestampType()).value
+
+    day = DayTransform().transform(TimestampType())
+    ts_day = day(date_20191201)  # type: ignore
+
+    pred = NotIn("ts", [date_20191202, date_20191201])
+
+    res_eval = residual_evaluator_of(spec=spec, expr=pred, case_sensitive=True, schema=schema)
+
+    residual = res_eval.residual_for(Record(ts_day))
+    assert residual == pred
+
+    ts_day += 3  # type: ignore
+    residual = res_eval.residual_for(Record(ts_day))
+    assert residual == AlwaysTrue()
diff --git a/tests/integration/test_delete_count.py b/tests/integration/test_delete_count.py
new file mode 100644
index 0000000000..0ba9d2d6da
--- /dev/null
+++ b/tests/integration/test_delete_count.py
@@ -0,0 +1,159 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint:disable=redefined-outer-name
+import random
+from datetime import datetime, timedelta
+from typing import Generator, List
+
+import pyarrow as pa
+import pytest
+from pyarrow import compute as pc
+from pyspark.sql import SparkSession
+
+from pyiceberg.catalog import Catalog
+from pyiceberg.catalog.rest import RestCatalog
+from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.expressions import And, EqualTo, GreaterThanOrEqual, LessThan
+from pyiceberg.schema import Schema
+from pyiceberg.table import Table
+from pyiceberg.transforms import HourTransform, IdentityTransform
+from pyiceberg.types import LongType, NestedField, StringType
+
+
+def run_spark_commands(spark: SparkSession, sqls: List[str]) -> None:
+    for sql in sqls:
+        spark.sql(sql)
+
+
+@pytest.fixture()
+def test_table(session_catalog: RestCatalog) -> Generator[Table, None, None]:
+    identifier = "default.__test_table"
+    arrow_table = pa.Table.from_arrays([pa.array([1, 2, 3, 4, 5]), pa.array(["a", "b", "c", "d", "e"])], names=["idx", "value"])
+    test_table = session_catalog.create_table(
+        identifier,
+        schema=Schema(
+            NestedField(1, "idx", LongType()),
+            NestedField(2, "value", StringType()),
+        ),
+    )
+    test_table.append(arrow_table)
+
+    yield test_table
+
+    session_catalog.drop_table(identifier)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_partitioned_table_delete_full_file(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None:
+    identifier = "default.table_partitioned_delete"
+    run_spark_commands(
+        spark,
+        [
+            f"DROP TABLE IF EXISTS {identifier}",
+            f"""
+            CREATE TABLE {identifier} (
+                number_partitioned long,
+                number long
+            )
+            USING iceberg
+            PARTITIONED BY (number_partitioned)
+            TBLPROPERTIES('format-version' = {format_version})
+            """,
+            f"""
+            INSERT INTO {identifier} VALUES (10, 20), (10, 30)
+            """,
+            f"""
+            INSERT INTO {identifier} VALUES (11, 20), (11, 30)
+            """,
+        ],
+    )
+
+    tbl = session_catalog.load_table(identifier)
+    tbl.delete(EqualTo("number_partitioned", 10))
+
+    # a full-partition delete should commit a delete snapshot, not an overwrite
+    assert [snapshot.summary.operation.value for snapshot in tbl.snapshots()] == ["append", "append", "delete"]
+    assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [11, 11], "number": [20, 30]}
+
+    # count() should agree with the number of rows actually returned
+    assert tbl.scan().count() == len(tbl.scan().to_arrow())
+    filter = And(EqualTo("number_partitioned", 11), GreaterThanOrEqual("number", 5))
+    assert tbl.scan(filter).count() == len(tbl.scan(filter).to_arrow())
+
+    N = 10
+    d = {
+        "number_partitioned": pa.array([i * 10 for i in range(N)]),
+        "number": pa.array([random.choice([10, 20, 40]) for _ in range(N)]),
+    }
+    with tbl.update_spec() as update:
+        update.add_field("number", transform=IdentityTransform())
+
+    data = pa.Table.from_pydict(d)
+
+    tbl.overwrite(df=data, overwrite_filter=filter)
+
+
+@pytest.mark.integration
+def test_rewrite_manifest_after_partition_evolution(session_catalog: Catalog) -> None:
+    random.seed(876)
+    N = 1440
+    d = {
+        "timestamp": pa.array([datetime(2023, 1, 1, 0, 0, 0) + timedelta(minutes=i) for i in range(N)]),
+        "category": pa.array([random.choice(["A", "B", "C"]) for _ in range(N)]),
+        "value": pa.array([random.gauss(0, 1) for _ in range(N)]),
+    }
+    data = pa.Table.from_pydict(d)
+
+    try:
+        session_catalog.drop_table(
+            identifier="default.test_error_table",
+        )
+    except NoSuchTableError:
+        pass
+
+    table = session_catalog.create_table(
+        "default.test_error_table",
+        schema=data.schema,
+    )
+    with table.update_spec() as update:
+        update.add_field("timestamp", transform=HourTransform())
+
+    table.append(data)
+    assert table.scan().count() == len(table.scan().to_arrow())
+
+    with table.update_spec() as update:
+        update.add_field("category", transform=IdentityTransform())
+
+    filtered_data = data.filter(
+        (pc.field("category") == "A")
+        & (pc.field("timestamp") >= datetime(2023, 1, 1, 0))
+        & (pc.field("timestamp") < datetime(2023, 1, 1, 1))
+    )
+
+    filter = And(
+        And(
+            GreaterThanOrEqual("timestamp", datetime(2023, 1, 1, 0).isoformat()),
+            LessThan("timestamp", datetime(2023, 1, 1, 1).isoformat()),
+        ),
+        EqualTo("category", "A"),
+    )
+
+    # count() should agree with the materialized row counts after partition
+    # evolution, both with and without the filter
+    assert table.scan().count() == len(table.scan().to_arrow())
+    assert table.scan(filter).count() == len(table.scan(filter).to_arrow())
+    table.overwrite(df=filtered_data, overwrite_filter=filter)