diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index d288e4f2f..e4d989609 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -27,8 +27,10 @@
 import concurrent.futures
 import fnmatch
+import functools
 import itertools
 import logging
+import operator
 import os
 import re
 import uuid
@@ -2542,36 +2544,6 @@ class _TablePartition:
     arrow_table_partition: pa.Table
 
 
-def _get_table_partitions(
-    arrow_table: pa.Table,
-    partition_spec: PartitionSpec,
-    schema: Schema,
-    slice_instructions: list[dict[str, Any]],
-) -> list[_TablePartition]:
-    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])
-
-    partition_fields = partition_spec.fields
-
-    offsets = [inst["offset"] for inst in sorted_slice_instructions]
-    projected_and_filtered = {
-        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
-        .take(offsets)
-        .to_pylist()
-        for partition_field in partition_fields
-    }
-
-    table_partitions = []
-    for idx, inst in enumerate(sorted_slice_instructions):
-        partition_slice = arrow_table.slice(**inst)
-        fieldvalues = [
-            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
-            for partition_field in partition_fields
-        ]
-        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
-        table_partitions.append(_TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
-    return table_partitions
-
-
 def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
     """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
 
@@ -2594,42 +2566,46 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
     We then retrieve the partition keys by offsets.
     And slice the arrow table by offsets and lengths of each partition.
""" - partition_columns: List[Tuple[PartitionField, NestedField]] = [ - (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields - ] - partition_values_table = pa.table( - { - str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name]) - for partition, field in partition_columns - } - ) + # Assign unique names to columns where the partition transform has been applied + # to avoid conflicts + partition_fields = [f"_partition_{field.name}" for field in spec.fields] + + for partition, name in zip(spec.fields, partition_fields): + source_field = schema.find_field(partition.source_id) + arrow_table = arrow_table.append_column( + name, partition.transform.pyarrow_transform(source_field.field_type)(arrow_table[source_field.name]) + ) + + unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([]) + + table_partitions = [] + # TODO: As a next step, we could also play around with yielding instead of materializing the full list + for unique_partition in unique_partition_fields.to_pylist(): + partition_key = PartitionKey( + raw_partition_field_values=[ + PartitionFieldValue(field=field, value=unique_partition[name]) + for field, name in zip(spec.fields, partition_fields) + ], + partition_spec=spec, + schema=schema, + ) + filtered_table = arrow_table.filter( + functools.reduce( + operator.and_, + [ + pc.field(partition_field_name) == unique_partition[partition_field_name] + if unique_partition[partition_field_name] is not None + else pc.field(partition_field_name).is_null() + for field, partition_field_name in zip(spec.fields, partition_fields) + ], + ) + ) + filtered_table = filtered_table.drop_columns(partition_fields) - # Sort by partitions - sort_indices = pa.compute.sort_indices( - partition_values_table, - sort_keys=[(col, "ascending") for col in partition_values_table.column_names], - null_placement="at_end", - ).to_pylist() - arrow_table = arrow_table.take(sort_indices) - - # Get slice_instructions to group by partitions - partition_values_table = partition_values_table.take(sort_indices) - reversed_indices = pa.compute.sort_indices( - partition_values_table, - sort_keys=[(col, "descending") for col in partition_values_table.column_names], - null_placement="at_start", - ).to_pylist() - slice_instructions: List[Dict[str, Any]] = [] - last = len(reversed_indices) - reversed_indices_size = len(reversed_indices) - ptr = 0 - while ptr < reversed_indices_size: - group_size = last - reversed_indices[ptr] - offset = reversed_indices[ptr] - slice_instructions.append({"offset": offset, "length": group_size}) - last = reversed_indices[ptr] - ptr = ptr + group_size - - table_partitions: List[_TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions) + # The combine_chunks seems to be counter-intuitive to do, but it actually returns + # fresh buffers that don't interfere with each other when it is written out to file + table_partitions.append( + _TablePartition(partition_key=partition_key, arrow_table_partition=filtered_table.combine_chunks()) + ) return table_partitions diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py index 181377221..b3ab763bb 100644 --- a/pyiceberg/partitioning.py +++ b/pyiceberg/partitioning.py @@ -29,6 +29,7 @@ Optional, Tuple, TypeVar, + Union, ) from urllib.parse import quote_plus @@ -425,8 +426,13 @@ def _to_partition_representation(type: IcebergType, value: Any) -> Any: 
 @_to_partition_representation.register(TimestampType)
 @_to_partition_representation.register(TimestamptzType)
-def _(type: IcebergType, value: Optional[datetime]) -> Optional[int]:
-    return datetime_to_micros(value) if value is not None else None
+def _(type: IcebergType, value: Optional[Union[datetime, int]]) -> Optional[int]:
+    if value is None:
+        return None
+    elif isinstance(value, int):
+        return value
+    else:
+        return datetime_to_micros(value)
 
 
 @_to_partition_representation.register(DateType)
diff --git a/tests/benchmark/test_benchmark.py b/tests/benchmark/test_benchmark.py
new file mode 100644
index 000000000..290fb6242
--- /dev/null
+++ b/tests/benchmark/test_benchmark.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import statistics
+import timeit
+import urllib.request
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+import pytest
+
+from pyiceberg.transforms import DayTransform
+
+
+@pytest.fixture(scope="session")
+def taxi_dataset(tmp_path_factory: pytest.TempPathFactory) -> pa.Table:
+    """Downloads the Taxi dataset and reads it into an Arrow table"""
+    taxi_dataset = "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2022-01.parquet"
+    taxi_dataset_dest = tmp_path_factory.mktemp("taxi_dataset") / "yellow_tripdata_2022-01.parquet"
+    urllib.request.urlretrieve(taxi_dataset, taxi_dataset_dest)
+
+    return pq.read_table(taxi_dataset_dest)
+
+
+def test_partitioned_write(tmp_path_factory: pytest.TempPathFactory, taxi_dataset: pa.Table) -> None:
+    """Tests writing to a partitioned table with something that would be close to a production-like situation"""
+    from pyiceberg.catalog.sql import SqlCatalog
+
+    warehouse_path = str(tmp_path_factory.mktemp("warehouse"))
+    catalog = SqlCatalog(
+        "default",
+        uri=f"sqlite:///{warehouse_path}/pyiceberg_catalog.db",
+        warehouse=f"file://{warehouse_path}",
+    )
+
+    catalog.create_namespace("default")
+
+    tbl = catalog.create_table("default.taxi_partitioned", schema=taxi_dataset.schema)
+
+    with tbl.update_spec() as spec:
+        spec.add_field("tpep_pickup_datetime", DayTransform())
+
+    # Profiling can sometimes be handy as well
+    # with cProfile.Profile() as pr:
+    #     tbl.append(taxi_dataset)
+    #
+    # pr.print_stats(sort=True)
+
+    runs = []
+    for run in range(5):
+        start_time = timeit.default_timer()
+        tbl.append(taxi_dataset)
+        elapsed = timeit.default_timer() - start_time
+
+        print(f"Run {run} took: {elapsed}")
+        runs.append(elapsed)
+
+    print(f"Average runtime of {round(statistics.mean(runs), 2)} seconds")
diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py
index 1e6ea1b79..807a504af 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -17,6 +17,7 @@
 # pylint:disable=redefined-outer-name
+import random
 from datetime import date
 from typing import Any, Set
@@ -1126,3 +1127,25 @@ def test_append_multiple_partitions(
         """
     )
     assert files_df.count() == 6
+
+
+@pytest.mark.integration
+def test_pyarrow_overflow(session_catalog: Catalog) -> None:
+    """Test what happens when the offset is beyond 32 bits"""
+    identifier = "default.arrow_table_overflow"
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    x = pa.array([random.randint(0, 999) for _ in range(30_000)])
+    ta = pa.chunked_array([x] * 10_000)
+    y = ["fixed_string"] * 30_000
+    tb = pa.chunked_array([y] * 10_000)
+    # Create pa.table
+    arrow_table = pa.table({"a": ta, "b": tb})
+
+    table = session_catalog.create_table(identifier, arrow_table.schema)
+    with table.update_spec() as update_spec:
+        update_spec.add_field("b", IdentityTransform(), "pb")
+    table.append(arrow_table)
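
Note: the core idea of the new _determine_partitions above is to append the transformed partition values as extra columns, take the distinct combinations with PyArrow's group_by(...).aggregate([]), and then filter() out each partition, instead of sorting the whole table and slicing it by offsets, which is where the 32-bit offset overflow exercised by test_pyarrow_overflow came from. The sketch below reproduces that pattern on a plain PyArrow table, outside of pyiceberg; the example table, the city column, and the _partition_city name are made up for illustration and are not part of the change itself.

# Standalone sketch of the group_by/filter partitioning pattern (illustrative only).
import functools
import operator

import pyarrow as pa
import pyarrow.compute as pc

arrow_table = pa.table({
    "city": ["Amsterdam", "Drachten", "Amsterdam", None],
    "amount": [1, 2, 3, 4],
})

# 1. Append the partition value under a unique column name (identity transform here).
arrow_table = arrow_table.append_column("_partition_city", arrow_table["city"])

# 2. Distinct partition values, computed without sorting or taking the full table.
unique_partitions = arrow_table.select(["_partition_city"]).group_by(["_partition_city"]).aggregate([])

# 3. One filtered slice per partition value; NULL partitions need an explicit is_null()
#    predicate because an equality comparison against None would not match them.
for row in unique_partitions.to_pylist():
    value = row["_partition_city"]
    predicate = functools.reduce(
        operator.and_,
        [pc.field("_partition_city") == value if value is not None else pc.field("_partition_city").is_null()],
    )
    partition = arrow_table.filter(predicate).drop_columns(["_partition_city"])
    print(value, partition.num_rows)

With more than one partition field, the reduce combines one predicate per field, which is what the diff does with zip(spec.fields, partition_fields); that is also why the functools and operator imports were added to pyiceberg/io/pyarrow.py.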