From 912fb828341fa59066f2f6d37a9dc86f4d4af564 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 5 Apr 2024 00:58:54 +0200 Subject: [PATCH] closes writers on exceptions, passes metrics on exceptions, fixes some edge cases with empty arrow files --- dlt/common/data_writers/buffered.py | 35 +++++--- dlt/common/data_writers/writers.py | 29 +++++-- dlt/common/storages/data_item_storage.py | 8 +- dlt/destinations/impl/bigquery/bigquery.py | 11 +-- .../impl/bigquery/bigquery_adapter.py | 6 +- dlt/extract/extract.py | 85 +++++++++++-------- dlt/extract/extractors.py | 5 +- dlt/extract/storage.py | 4 +- dlt/normalize/exceptions.py | 6 +- dlt/normalize/items_normalizers.py | 7 +- dlt/normalize/normalize.py | 33 ++++--- dlt/pipeline/pipeline.py | 4 +- tests/cases.py | 6 +- tests/common/storages/utils.py | 1 + tests/extract/test_extract.py | 69 ++++++++++++++- tests/libs/test_arrow_csv_writer.py | 6 +- tests/libs/test_parquet_writer.py | 3 +- tests/load/pipeline/test_arrow_loading.py | 2 +- .../load/pipeline/test_filesystem_pipeline.py | 2 +- tests/load/pipeline/test_postgres.py | 3 +- tests/load/utils.py | 1 + tests/normalize/test_normalize.py | 16 +++- tests/pipeline/test_arrow_sources.py | 67 ++++++++------- 23 files changed, 275 insertions(+), 134 deletions(-) diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index 1db18b065e..e358919c7a 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -164,9 +164,10 @@ def import_file(self, file_path: str, metrics: DataWriterMetrics) -> DataWriterM self._rotate_file() return metrics - def close(self) -> None: + def close(self, skip_flush: bool = False) -> None: + """Flushes the data, writes footer (skip_flush is True), collects metrics and closes the underlying file.""" self._ensure_open() - self._flush_and_close_file() + self._flush_and_close_file(skip_flush=skip_flush) self._closed = True @property @@ -177,7 +178,8 @@ def __enter__(self) -> "BufferedDataWriter[TWriter]": return self def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: Any) -> None: - self.close() + # skip flush if we had exception + self.close(skip_flush=exc_val is not None) def _rotate_file(self, allow_empty_file: bool = False) -> DataWriterMetrics: metrics = self._flush_and_close_file(allow_empty_file) @@ -188,7 +190,7 @@ def _rotate_file(self, allow_empty_file: bool = False) -> DataWriterMetrics: return metrics def _flush_items(self, allow_empty_file: bool = False) -> None: - if self._buffered_items_count > 0 or allow_empty_file: + if self._buffered_items or allow_empty_file: # we only open a writer when there are any items in the buffer and first flush is requested if not self._writer: # create new writer and write header @@ -205,15 +207,22 @@ def _flush_items(self, allow_empty_file: bool = False) -> None: self._buffered_items.clear() self._buffered_items_count = 0 - def _flush_and_close_file(self, allow_empty_file: bool = False) -> DataWriterMetrics: - # if any buffered items exist, flush them - self._flush_items(allow_empty_file) - # if writer exists then close it - if not self._writer: - return None - # write the footer of a file - self._writer.write_footer() - self._file.flush() + def _flush_and_close_file( + self, allow_empty_file: bool = False, skip_flush: bool = False + ) -> DataWriterMetrics: + if not skip_flush: + # if any buffered items exist, flush them + self._flush_items(allow_empty_file) + # if writer exists then close it + if not self._writer: + return None + # 
write the footer of a file + self._writer.write_footer() + self._file.flush() + else: + if not self._writer: + return None + self._writer.close() # add file written to the list so we can commit all the files later metrics = DataWriterMetrics( self._file_name, diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 67e5466d39..9936a6844d 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -84,6 +84,9 @@ def write_data(self, rows: Sequence[Any]) -> None: def write_footer(self) -> None: # noqa pass + def close(self) -> None: # noqa + pass + def write_all(self, columns_schema: TTableSchemaColumns, rows: Sequence[Any]) -> None: self.write_header(columns_schema) self.write_data(rows) @@ -321,9 +324,10 @@ def write_data(self, rows: Sequence[Any]) -> None: # Write self.writer.write_table(table, row_group_size=self.parquet_row_group_size) - def write_footer(self) -> None: - self.writer.close() - self.writer = None + def close(self) -> None: # noqa + if self.writer: + self.writer.close() + self.writer = None @classmethod def writer_spec(cls) -> FileWriterSpec: @@ -362,10 +366,9 @@ def write_data(self, rows: Sequence[Any]) -> None: # count rows that got written self.items_count += sum(len(row) for row in rows) - def write_footer(self) -> None: - if self.writer is None: - self.writer = None - self._first_schema = None + def close(self) -> None: + self.writer = None + self._first_schema = None @classmethod def writer_spec(cls) -> FileWriterSpec: @@ -405,6 +408,9 @@ def write_footer(self) -> None: raise NotImplementedError("Arrow Writer does not support writing empty files") return super().write_footer() + def close(self) -> None: + return super().close() + @classmethod def writer_spec(cls) -> FileWriterSpec: return FileWriterSpec( @@ -488,10 +494,15 @@ def write_footer(self) -> None: # write empty file self._f.write( self.delimiter.join( - [col["name"].encode("utf-8") for col in self._columns_schema.values()] + [ + b'"' + col["name"].encode("utf-8") + b'"' + for col in self._columns_schema.values() + ] ) ) - else: + + def close(self) -> None: + if self.writer: self.writer.close() self.writer = None self._first_schema = None diff --git a/dlt/common/storages/data_item_storage.py b/dlt/common/storages/data_item_storage.py index 5b1e360789..ab15c3ad5b 100644 --- a/dlt/common/storages/data_item_storage.py +++ b/dlt/common/storages/data_item_storage.py @@ -72,15 +72,17 @@ def import_items_file( writer = self._get_writer(load_id, schema_name, table_name) return writer.import_file(file_path, metrics) - def close_writers(self, load_id: str) -> None: - # flush and close all files + def close_writers(self, load_id: str, skip_flush: bool = False) -> None: + """Flush, write footers (skip_flush), write metrics and close files in all + writers belonging to `load_id` package + """ for name, writer in self.buffered_writers.items(): if name.startswith(load_id) and not writer.closed: logger.debug( f"Closing writer for {name} with file {writer._file} and actual name" f" {writer._file_name}" ) - writer.close() + writer.close(skip_flush=skip_flush) def closed_files(self, load_id: str) -> List[DataWriterMetrics]: """Return metrics for all fully processed (closed) files""" diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 279917d3a0..b2e53f9734 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -232,10 +232,9 @@ def start_file_load(self, 
table: TTableSchema, file_path: str, load_id: str) -> if insert_api == "streaming": if table["write_disposition"] != "append": raise DestinationTerminalException( - ( - "BigQuery streaming insert can only be used with `append` write_disposition, while " - f'the given resource has `{table["write_disposition"]}`.' - ) + "BigQuery streaming insert can only be used with `append`" + " write_disposition, while the given resource has" + f" `{table['write_disposition']}`." ) if file_path.endswith(".jsonl"): job_cls = DestinationJsonlLoadJob @@ -364,7 +363,9 @@ def prepare_load_table( def _get_column_def_sql(self, column: TColumnSchema, table_format: TTableFormat = None) -> str: name = self.capabilities.escape_identifier(column["name"]) - column_def_sql = f"{name} {self.type_mapper.to_db_type(column, table_format)} {self._gen_not_null(column.get('nullable', True))}" + column_def_sql = ( + f"{name} {self.type_mapper.to_db_type(column, table_format)} {self._gen_not_null(column.get('nullable', True))}" + ) if column.get(ROUND_HALF_EVEN_HINT, False): column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_EVEN')" if column.get(ROUND_HALF_AWAY_FROM_ZERO_HINT, False): diff --git a/dlt/destinations/impl/bigquery/bigquery_adapter.py b/dlt/destinations/impl/bigquery/bigquery_adapter.py index 8943b0da79..6b3ef32b0f 100644 --- a/dlt/destinations/impl/bigquery/bigquery_adapter.py +++ b/dlt/destinations/impl/bigquery/bigquery_adapter.py @@ -153,10 +153,8 @@ def bigquery_adapter( if insert_api is not None: if insert_api == "streaming" and data.write_disposition != "append": raise ValueError( - ( - "BigQuery streaming insert can only be used with `append` write_disposition, while " - f"the given resource has `{data.write_disposition}`." - ) + "BigQuery streaming insert can only be used with `append` write_disposition, while " + f"the given resource has `{data.write_disposition}`." 
) additional_table_hints |= {"x-insert-api": insert_api} # type: ignore[operator] diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index 75f22bb802..02dd06eaf3 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -2,7 +2,7 @@ from collections.abc import Sequence as C_Sequence from copy import copy import itertools -from typing import List, Dict, Any +from typing import Iterator, List, Dict, Any import yaml from dlt.common.configuration.container import Container @@ -304,41 +304,58 @@ def _extract_single_source( load_id, self.extract_storage.item_storages["arrow"], schema, collector=collector ), } - + # make sure we close storage on exception with collector(f"Extract {source.name}"): - self._step_info_start_load_id(load_id) - # yield from all selected pipes - with PipeIterator.from_pipes( - source.resources.selected_pipes, - max_parallel_items=max_parallel_items, - workers=workers, - futures_poll_interval=futures_poll_interval, - ) as pipes: - left_gens = total_gens = len(pipes._sources) - collector.update("Resources", 0, total_gens) - for pipe_item in pipes: - curr_gens = len(pipes._sources) - if left_gens > curr_gens: - delta = left_gens - curr_gens - left_gens -= delta - collector.update("Resources", delta) - signals.raise_if_signalled() - resource = source.resources[pipe_item.pipe.name] - item_format = get_data_item_format(pipe_item.item) - extractors[item_format].write_items(resource, pipe_item.item, pipe_item.meta) - - self._write_empty_files(source, extractors) - if left_gens > 0: - # go to 100% - collector.update("Resources", left_gens) - - # flush all buffered writers + with self.manage_writers(load_id, source): + # yield from all selected pipes + with PipeIterator.from_pipes( + source.resources.selected_pipes, + max_parallel_items=max_parallel_items, + workers=workers, + futures_poll_interval=futures_poll_interval, + ) as pipes: + left_gens = total_gens = len(pipes._sources) + collector.update("Resources", 0, total_gens) + for pipe_item in pipes: + curr_gens = len(pipes._sources) + if left_gens > curr_gens: + delta = left_gens - curr_gens + left_gens -= delta + collector.update("Resources", delta) + signals.raise_if_signalled() + resource = source.resources[pipe_item.pipe.name] + item_format = get_data_item_format(pipe_item.item) + extractors[item_format].write_items( + resource, pipe_item.item, pipe_item.meta + ) + + self._write_empty_files(source, extractors) + if left_gens > 0: + # go to 100% + collector.update("Resources", left_gens) + + @contextlib.contextmanager + def manage_writers(self, load_id: str, source: DltSource) -> Iterator[ExtractStorage]: + self._step_info_start_load_id(load_id) + # self.current_source = source + try: + yield self.extract_storage + except Exception: + # kill writers without flushing the content + self.extract_storage.close_writers(load_id, skip_flush=True) + raise + else: self.extract_storage.close_writers(load_id) - # gather metrics - self._step_info_complete_load_id(load_id, self._compute_metrics(load_id, source)) - # remove the metrics of files processed in this extract run - # NOTE: there may be more than one extract run per load id: ie. 
the resource and then dlt state - self.extract_storage.remove_closed_files(load_id) + finally: + # gather metrics when storage is closed + self.gather_metrics(load_id, source) + + def gather_metrics(self, load_id: str, source: DltSource) -> None: + # gather metrics + self._step_info_complete_load_id(load_id, self._compute_metrics(load_id, source)) + # remove the metrics of files processed in this extract run + # NOTE: there may be more than one extract run per load id: ie. the resource and then dlt state + self.extract_storage.remove_closed_files(load_id) def extract( self, diff --git a/dlt/extract/extractors.py b/dlt/extract/extractors.py index c4b7653164..421250951e 100644 --- a/dlt/extract/extractors.py +++ b/dlt/extract/extractors.py @@ -122,10 +122,11 @@ def _write_item( self.load_id, self.schema.name, table_name, items, columns ) self.collector.update(table_name, inc=new_rows_count) - if new_rows_count > 0: + # if there were rows or item was empty arrow table + if new_rows_count > 0 or self.__class__ is ArrowExtractor: self.resources_with_items.add(resource_name) else: - if isinstance(items, MaterializedEmptyList) or self.__class__ is ArrowExtractor: + if isinstance(items, MaterializedEmptyList): self.resources_with_empty.add(resource_name) def _write_to_dynamic_table(self, resource: DltResource, items: TDataItems) -> None: diff --git a/dlt/extract/storage.py b/dlt/extract/storage.py index b76822a4f2..3e01a020ba 100644 --- a/dlt/extract/storage.py +++ b/dlt/extract/storage.py @@ -74,9 +74,9 @@ def create_load_package(self, schema: Schema, reuse_exiting_package: bool = True self.new_packages.save_schema(load_id, schema) return load_id - def close_writers(self, load_id: str) -> None: + def close_writers(self, load_id: str, skip_flush: bool = False) -> None: for storage in self.item_storages.values(): - storage.close_writers(load_id) + storage.close_writers(load_id, skip_flush=skip_flush) def closed_files(self, load_id: str) -> List[DataWriterMetrics]: files = [] diff --git a/dlt/normalize/exceptions.py b/dlt/normalize/exceptions.py index a172196899..7bc305fcbe 100644 --- a/dlt/normalize/exceptions.py +++ b/dlt/normalize/exceptions.py @@ -1,3 +1,4 @@ +from typing import Any, List from dlt.common.exceptions import DltException @@ -7,10 +8,13 @@ def __init__(self, msg: str) -> None: class NormalizeJobFailed(NormalizeException): - def __init__(self, load_id: str, job_id: str, failed_message: str) -> None: + def __init__( + self, load_id: str, job_id: str, failed_message: str, writer_metrics: List[Any] + ) -> None: self.load_id = load_id self.job_id = job_id self.failed_message = failed_message + self.writer_metrics = writer_metrics super().__init__( f"Job for {job_id} failed terminally in load {load_id} with message {failed_message}." 
) diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py index bf4073ddbf..1e4e55effd 100644 --- a/dlt/normalize/items_normalizers.py +++ b/dlt/normalize/items_normalizers.py @@ -283,7 +283,7 @@ def _write_with_dlt_columns( items_count = 0 columns_schema = schema.get_table_columns(root_table_name) # if we use adapter to convert arrow to dicts, then normalization is not necessary - may_normalize = not issubclass(self.item_storage.writer_cls, ArrowToObjectAdapter) + is_native_arrow_writer = not issubclass(self.item_storage.writer_cls, ArrowToObjectAdapter) should_normalize: bool = None with self.normalize_storage.extracted_packages.storage.open_file( extracted_items_file, "rb" @@ -293,7 +293,7 @@ def _write_with_dlt_columns( ): items_count += batch.num_rows # we may need to normalize - if may_normalize and should_normalize is None: + if is_native_arrow_writer and should_normalize is None: should_normalize, _, _, _ = pyarrow.should_normalize_arrow_schema( batch.schema, columns_schema, schema.naming ) @@ -315,7 +315,8 @@ def _write_with_dlt_columns( batch, columns_schema, ) - if items_count == 0: + # TODO: better to check if anything is in the buffer and skip writing file + if items_count == 0 and not is_native_arrow_writer: self.item_storage.write_empty_items_file( load_id, schema.name, diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 28c2c81571..47d0cd9898 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -142,6 +142,19 @@ def _get_items_normalizer(item_format: TDataItemFormat) -> ItemsNormalizer: ) return norm + def _gather_metrics_and_close(skip_flush: bool) -> List[DataWriterMetrics]: + for normalizer in item_normalizers.values(): + normalizer.item_storage.close_writers(load_id, skip_flush=skip_flush) + + writer_metrics: List[DataWriterMetrics] = [] + for normalizer in item_normalizers.values(): + norm_metrics = normalizer.item_storage.closed_files(load_id) + writer_metrics.extend(norm_metrics) + + for normalizer in item_normalizers.values(): + normalizer.item_storage.remove_closed_files(load_id) + return writer_metrics + parsed_file_name: ParsedLoadJobFileName = None try: root_tables: Set[str] = set() @@ -165,15 +178,11 @@ def _get_items_normalizer(item_format: TDataItemFormat) -> ItemsNormalizer: logger.debug(f"Processed file {extracted_items_file}") except Exception as exc: job_id = parsed_file_name.job_id() if parsed_file_name else "" - raise NormalizeJobFailed(load_id, job_id, str(exc)) from exc - finally: - for normalizer in item_normalizers.values(): - normalizer.item_storage.close_writers(load_id) + writer_metrics = _gather_metrics_and_close(skip_flush=True) + raise NormalizeJobFailed(load_id, job_id, str(exc), writer_metrics) from exc + else: + writer_metrics = _gather_metrics_and_close(skip_flush=False) - writer_metrics: List[DataWriterMetrics] = [] - for normalizer in item_normalizers.values(): - norm_metrics = normalizer.item_storage.closed_files(load_id) - writer_metrics.extend(norm_metrics) logger.info(f"Processed all items in {len(extracted_items_files)} files") return TWorkerRV(schema_updates, writer_metrics) @@ -233,9 +242,11 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW for task in list(tasks): pending, params = task if pending.done(): - result: TWorkerRV = ( - pending.result() - ) # Exception in task (if any) is raised here + # collect metrics from the exception (if any) + if isinstance(pending.exception(), NormalizeJobFailed): + 
summary.file_metrics.extend(pending.exception().writer_metrics) # type: ignore[attr-defined] + # Exception in task (if any) is raised here + result: TWorkerRV = pending.result() try: # gather schema from all manifests, validate consistency and combine self.update_table(schema, result[0]) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index b0d04dfbe8..683251c2a8 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -436,11 +436,13 @@ def extract( extract_step.commit_packages() return self._get_step_info(extract_step) except Exception as exc: + # emit step info step_info = self._get_step_info(extract_step) + current_load_id = step_info.loads_ids[-1] if len(step_info.loads_ids) > 0 else None raise PipelineStepFailed( self, "extract", - extract_step.current_load_id, + current_load_id, exc, step_info, ) from exc diff --git a/tests/cases.py b/tests/cases.py index 9a0213d837..b598f1169e 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -294,7 +294,7 @@ def arrow_table_all_data_types( include_name_clash: bool = False, num_rows: int = 3, tz="UTC", -) -> Tuple[Any, List[Dict[str, Any]]]: +) -> Tuple[Any, List[Dict[str, Any]], Dict[str, List[Any]]]: """Create an arrow object or pandas dataframe with all supported data types. Returns the table and its records in python format @@ -351,14 +351,14 @@ def arrow_table_all_data_types( .drop(columns=["null"]) .to_dict("records") ) - return arrow_format_from_pandas(df, object_format), rows + return arrow_format_from_pandas(df, object_format), rows, data def prepare_shuffled_tables() -> Tuple[Any, Any, Any]: from dlt.common.libs.pyarrow import remove_columns from dlt.common.libs.pyarrow import pyarrow as pa - table, _ = arrow_table_all_data_types( + table, _, _ = arrow_table_all_data_types( "table", include_json=False, include_not_normalized_name=False, diff --git a/tests/common/storages/utils.py b/tests/common/storages/utils.py index e500f149ed..13ec253e2f 100644 --- a/tests/common/storages/utils.py +++ b/tests/common/storages/utils.py @@ -123,6 +123,7 @@ def write_temp_job_file( item_storage.writer_spec.file_format, item_storage.writer_spec.data_item_format, f ) writer.write_all(table, rows) + writer.close() return Path(file_name).name diff --git a/tests/extract/test_extract.py b/tests/extract/test_extract.py index 1879eaa9eb..9620e7fdfb 100644 --- a/tests/extract/test_extract.py +++ b/tests/extract/test_extract.py @@ -11,12 +11,12 @@ from dlt.common.storages.schema_storage import SchemaStorage from dlt.extract import DltResource, DltSource -from dlt.extract.exceptions import DataItemRequiredForDynamicTableHints +from dlt.extract.exceptions import DataItemRequiredForDynamicTableHints, ResourceExtractionError from dlt.extract.extract import ExtractStorage, Extract from dlt.extract.hints import make_hints from dlt.extract.items import TableNameMeta -from tests.utils import clean_test_storage, TEST_STORAGE_ROOT +from tests.utils import MockPipeline, clean_test_storage, TEST_STORAGE_ROOT from tests.extract.utils import expect_extracted_file @@ -211,6 +211,71 @@ def with_table_hints(): extract_step.extract(source, 20, 1) +def test_extract_metrics_on_exception_no_flush(extract_step: Extract) -> None: + @dlt.resource + def letters(): + # extract 7 items + yield from "ABCDEFG" + # then fail + raise RuntimeError() + yield from "HI" + + source = DltSource(dlt.Schema("letters"), "module", [letters]) + with pytest.raises(ResourceExtractionError): + extract_step.extract(source, 20, 1) + step_info = 
extract_step.get_step_info(MockPipeline("buba", first_run=False)) # type: ignore[abstract] + # no jobs were created + assert len(step_info.load_packages[0].jobs["new_jobs"]) == 0 + # make sure all writers are closed but not yet removed + current_load_id = step_info.loads_ids[-1] if len(step_info.loads_ids) > 0 else None + # get buffered writers + writers = extract_step.extract_storage.item_storages["object"].buffered_writers + assert len(writers) == 1 + for name, writer in writers.items(): + assert name.startswith(current_load_id) + assert writer._file is None + + +def test_extract_metrics_on_exception_without_flush(extract_step: Extract) -> None: + @dlt.resource + def letters(): + # extract 7 items + yield from "ABCDEFG" + # then fail + raise RuntimeError() + yield from "HI" + + # flush buffer + os.environ["DATA_WRITER__BUFFER_MAX_ITEMS"] = "4" + source = DltSource(dlt.Schema("letters"), "module", [letters]) + with pytest.raises(ResourceExtractionError): + extract_step.extract(source, 20, 1) + step_info = extract_step.get_step_info(MockPipeline("buba", first_run=False)) # type: ignore[abstract] + # one job created because the file was flushed + jobs = step_info.load_packages[0].jobs["new_jobs"] + # print(jobs[0].job_file_info.job_id()) + assert len(jobs) == 1 + current_load_id = step_info.loads_ids[-1] if len(step_info.loads_ids) > 0 else None + # 7 items were extracted + assert ( + step_info.metrics[current_load_id][0]["job_metrics"][ + jobs[0].job_file_info.job_id() + ].items_count + == 4 + ) + # get buffered writers + writers = extract_step.extract_storage.item_storages["object"].buffered_writers + assert len(writers) == 1 + for name, writer in writers.items(): + assert name.startswith(current_load_id) + assert writer._file is None + + +def test_extract_empty_metrics(extract_step: Extract) -> None: + step_info = extract_step.get_step_info(MockPipeline("buba", first_run=False)) # type: ignore[abstract] + assert step_info.load_packages == step_info.loads_ids == [] + + # def test_extract_pipe_from_unknown_resource(): # pass diff --git a/tests/libs/test_arrow_csv_writer.py b/tests/libs/test_arrow_csv_writer.py index 91038f01c4..85a15cc169 100644 --- a/tests/libs/test_arrow_csv_writer.py +++ b/tests/libs/test_arrow_csv_writer.py @@ -18,7 +18,7 @@ def test_csv_writer_all_data_fields() -> None: - data = TABLE_ROW_ALL_DATA_TYPES_DATETIMES + data = copy(TABLE_ROW_ALL_DATA_TYPES_DATETIMES) # write parquet and read it with get_writer(ParquetDataWriter) as pq_writer: @@ -94,7 +94,7 @@ def test_csv_writer_all_data_fields() -> None: def test_non_utf8_binary() -> None: - data = TABLE_ROW_ALL_DATA_TYPES_DATETIMES + data = copy(TABLE_ROW_ALL_DATA_TYPES_DATETIMES) data["col7"] += b"\x8e" # type: ignore[operator] # write parquet and read it @@ -110,7 +110,7 @@ def test_non_utf8_binary() -> None: def test_arrow_struct() -> None: - item, _ = arrow_table_all_data_types("table", include_json=True, include_time=False) + item, _, _ = arrow_table_all_data_types("table", include_json=True, include_time=False) with pytest.raises(InvalidDataItem): with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: writer.write_data_item(item, TABLE_UPDATE_COLUMNS_SCHEMA) diff --git a/tests/libs/test_parquet_writer.py b/tests/libs/test_parquet_writer.py index 3b4239f2b0..786617ef55 100644 --- a/tests/libs/test_parquet_writer.py +++ b/tests/libs/test_parquet_writer.py @@ -128,8 +128,7 @@ def test_parquet_writer_all_data_fields() -> None: assert actual == value assert table.schema.field("col1_precision").type == 
pa.int16() - # flavor=spark only writes ns precision timestamp, so this is expected - assert table.schema.field("col4_precision").type == pa.timestamp("ns") + assert table.schema.field("col4_precision").type == pa.timestamp("ms", tz="UTC") assert table.schema.field("col5_precision").type == pa.string() assert table.schema.field("col6_precision").type == pa.decimal128(6, 2) assert table.schema.field("col7_precision").type == pa.binary(19) diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index 2c649c18de..98f44b1c8a 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -52,7 +52,7 @@ def test_load_arrow_item( destination_config.destination == "databricks" and destination_config.file_format == "jsonl" ) - item, records = arrow_table_all_data_types( + item, records, _ = arrow_table_all_data_types( item_type, include_json=False, include_time=include_time, diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 6d33b477fc..8401f9d3af 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -109,7 +109,7 @@ def test_pipeline_csv_filesystem_destination() -> None: dataset_name="parquet_test_" + uniq_id(), ) - item, _ = arrow_table_all_data_types("table", include_json=False, include_time=True) + item, _, _ = arrow_table_all_data_types("table", include_json=False, include_time=True) info = pipeline.run(item, table_name="table", loader_file_format="csv") info.raise_on_failed_jobs() job = info.load_packages[0].jobs["completed_jobs"][0].file_path diff --git a/tests/load/pipeline/test_postgres.py b/tests/load/pipeline/test_postgres.py index bf57bb0c4e..50c14e9cda 100644 --- a/tests/load/pipeline/test_postgres.py +++ b/tests/load/pipeline/test_postgres.py @@ -65,12 +65,13 @@ def test_postgres_empty_csv_from_arrow(destination_config: DestinationTestConfig os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" os.environ["RESTORE_FROM_DESTINATION"] = "False" pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), full_refresh=True) - table, _ = arrow_table_all_data_types("table", include_json=False) + table, _, _ = arrow_table_all_data_types("table", include_json=False) load_info = pipeline.run( table.schema.empty_table(), table_name="table", loader_file_format="csv" ) assert_load_info(load_info) + assert len(load_info.load_packages[0].jobs["completed_jobs"]) == 1 job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path assert job.endswith("csv") assert_data_table_counts(pipeline, {"table": 0}) diff --git a/tests/load/utils.py b/tests/load/utils.py index d8daf996e1..3972f3ad95 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -586,6 +586,7 @@ def write_dataset( for idx, row in enumerate(rows): rows[idx] = {k: v for k, v in row.items() if v is not None} writer.write_all(columns_schema, rows) + writer.close() def prepare_load_package( diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index ad31e6240e..91997a921e 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -472,7 +472,7 @@ def test_normalize_retry(raw_normalize: Normalize) -> None: schema = raw_normalize.normalize_storage.extracted_packages.load_schema(load_id) schema.set_schema_contract("freeze") raw_normalize.normalize_storage.extracted_packages.save_schema(load_id, schema) - # will fail on contract violatiom + # will fail 
on contract violation with pytest.raises(NormalizeJobFailed): raw_normalize.run(None) @@ -492,6 +492,20 @@ def test_normalize_retry(raw_normalize: Normalize) -> None: assert len(table_files["issues"]) == 1 +def test_collect_metrics_on_exception(raw_normalize: Normalize) -> None: + load_id = extract_cases(raw_normalize, ["github.issues.load_page_5_duck"]) + schema = raw_normalize.normalize_storage.extracted_packages.load_schema(load_id) + schema.set_schema_contract("freeze") + raw_normalize.normalize_storage.extracted_packages.save_schema(load_id, schema) + # will fail on contract violation + with pytest.raises(NormalizeJobFailed) as job_ex: + raw_normalize.run(None) + # we excepted on a first row so nothing was written + # TODO: improve this test to write some rows in buffered writer + assert len(job_ex.value.writer_metrics) == 0 + raw_normalize.get_step_info(MockPipeline("multiprocessing_pipeline", True)) # type: ignore[abstract] + + def test_group_worker_files() -> None: files = ["f%03d" % idx for idx in range(0, 100)] diff --git a/tests/pipeline/test_arrow_sources.py b/tests/pipeline/test_arrow_sources.py index 96159648ea..b16da73868 100644 --- a/tests/pipeline/test_arrow_sources.py +++ b/tests/pipeline/test_arrow_sources.py @@ -9,7 +9,7 @@ import dlt from dlt.common import json, Decimal from dlt.common.utils import uniq_id -from dlt.common.libs.pyarrow import NameNormalizationClash +from dlt.common.libs.pyarrow import NameNormalizationClash, remove_columns, normalize_py_arrow_item from dlt.pipeline.exceptions import PipelineStepFailed @@ -35,7 +35,7 @@ ], ) def test_extract_and_normalize(item_type: TArrowFormat, is_list: bool): - item, records = arrow_table_all_data_types(item_type) + item, records, data = arrow_table_all_data_types(item_type) pipeline = dlt.pipeline("arrow_" + uniq_id(), destination="filesystem") @@ -72,19 +72,25 @@ def some_data(): assert normalized_bytes == extracted_bytes f.seek(0) - pq = pa.parquet.ParquetFile(f) - tbl = pq.read() - - # To make tables comparable exactly write the expected data to parquet and read it back - # The spark parquet writer loses timezone info - tbl_expected = pa.Table.from_pandas(pd.DataFrame(records)) - with io.BytesIO() as f: - pa.parquet.write_table(tbl_expected, f, flavor="spark") - f.seek(0) - tbl_expected = pa.parquet.read_table(f) - df_tbl = tbl_expected.to_pandas(ignore_metadata=True) + with pa.parquet.ParquetFile(f) as pq: + tbl = pq.read() + + # use original data to create data frame to preserve timestamp precision, timezones etc. 
+ tbl_expected = pa.Table.from_pandas(pd.DataFrame(data)) + # null is removed by dlt + tbl_expected = remove_columns(tbl_expected, ["null"]) + # we want to normalize column names + tbl_expected = normalize_py_arrow_item( + tbl_expected, + pipeline.default_schema.get_table_columns("some_data"), + pipeline.default_schema.naming, + None, + ) + assert tbl_expected.schema.equals(tbl.schema) + + df_tbl = tbl_expected.to_pandas(ignore_metadata=False) # Data is identical to the original dataframe - df_result = tbl.to_pandas(ignore_metadata=True) + df_result = tbl.to_pandas(ignore_metadata=False) assert df_result.equals(df_tbl) schema = pipeline.default_schema @@ -116,7 +122,7 @@ def some_data(): def test_normalize_jsonl(item_type: TArrowFormat, is_list: bool): os.environ["DUMMY__LOADER_FILE_FORMAT"] = "jsonl" - item, records = arrow_table_all_data_types(item_type) + item, records, _ = arrow_table_all_data_types(item_type, tz="Europe/Berlin") pipeline = dlt.pipeline("arrow_" + uniq_id(), destination="dummy") @@ -136,21 +142,18 @@ def some_data(): job = [j for j in jobs if "some_data" in j][0] with storage.normalized_packages.storage.open_file(job, "r") as f: result = [json.loads(line) for line in f] - for row in result: - row["decimal"] = Decimal(row["decimal"]) - - for record in records: - record["datetime"] = record["datetime"].replace(tzinfo=None) expected = json.loads(json.dumps(records)) - for record in expected: - record["decimal"] = Decimal(record["decimal"]) - assert result == expected + assert len(result) == len(expected) + for res_item, exp_item in zip(result, expected): + res_item["decimal"] = Decimal(res_item["decimal"]) + exp_item["decimal"] = Decimal(exp_item["decimal"]) + assert res_item == exp_item @pytest.mark.parametrize("item_type", ["table", "record_batch"]) def test_add_map(item_type: TArrowFormat): - item, records = arrow_table_all_data_types(item_type, num_rows=200) + item, _, _ = arrow_table_all_data_types(item_type, num_rows=200) @dlt.resource def some_data(): @@ -180,7 +183,7 @@ def test_extract_normalize_file_rotation(item_type: TArrowFormat) -> None: pipeline_name = "arrow_" + uniq_id() pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="dummy") - item, rows = arrow_table_all_data_types(item_type) + item, rows, _ = arrow_table_all_data_types(item_type) @dlt.resource def data_frames(): @@ -209,7 +212,7 @@ def test_arrow_clashing_names(item_type: TArrowFormat) -> None: pipeline_name = "arrow_" + uniq_id() pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="dummy") - item, _ = arrow_table_all_data_types(item_type, include_name_clash=True) + item, _, _ = arrow_table_all_data_types(item_type, include_name_clash=True) @dlt.resource def data_frames(): @@ -226,10 +229,10 @@ def test_load_arrow_vary_schema(item_type: TArrowFormat) -> None: pipeline_name = "arrow_" + uniq_id() pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="duckdb") - item, _ = arrow_table_all_data_types(item_type, include_not_normalized_name=False) + item, _, _ = arrow_table_all_data_types(item_type, include_not_normalized_name=False) pipeline.run(item, table_name="data").raise_on_failed_jobs() - item, _ = arrow_table_all_data_types(item_type, include_not_normalized_name=False) + item, _, _ = arrow_table_all_data_types(item_type, include_not_normalized_name=False) # remove int column try: item = item.drop("int") @@ -245,7 +248,7 @@ def test_arrow_as_data_loading(item_type: TArrowFormat) -> None: os.environ["RESTORE_FROM_DESTINATION"] = "False" 
os.environ["DESTINATION__LOADER_FILE_FORMAT"] = "parquet" - item, rows = arrow_table_all_data_types(item_type) + item, rows, _ = arrow_table_all_data_types(item_type) item_resource = dlt.resource(item, name="item") assert id(item) == id(list(item_resource)[0]) @@ -260,7 +263,7 @@ def test_arrow_as_data_loading(item_type: TArrowFormat) -> None: @pytest.mark.parametrize("item_type", ["table"]) # , "pandas", "record_batch" def test_normalize_with_dlt_columns(item_type: TArrowFormat): - item, records = arrow_table_all_data_types(item_type, num_rows=5432) + item, records, _ = arrow_table_all_data_types(item_type, num_rows=5432) os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "True" os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_ID"] = "True" # Test with buffer smaller than the number of batches to be written @@ -316,7 +319,7 @@ def some_data(): pipeline.run(item, table_name="some_data").raise_on_failed_jobs() # should be able to load arrow with a new column - item, records = arrow_table_all_data_types(item_type, num_rows=200) + item, records, _ = arrow_table_all_data_types(item_type, num_rows=200) item = item.append_column("static_int", [[0] * 200]) pipeline.run(item, table_name="some_data").raise_on_failed_jobs() @@ -475,7 +478,7 @@ def test_empty_arrow(item_type: TArrowFormat) -> None: os.environ["DESTINATION__LOADER_FILE_FORMAT"] = "parquet" # always return pandas - item, _ = arrow_table_all_data_types("pandas", num_rows=1) + item, _, _ = arrow_table_all_data_types("pandas", num_rows=1) item_resource = dlt.resource(item, name="items", write_disposition="replace") pipeline_name = "arrow_" + uniq_id()