From 8b6028bcb50df14a4323dc800a5395fedb1f59e2 Mon Sep 17 00:00:00 2001 From: Dave Date: Thu, 11 Jan 2024 13:59:02 +0100 Subject: [PATCH 01/35] start sink --- dlt/destinations/impl/sink/__init__.py | 5 ++ dlt/destinations/impl/sink/configuration.py | 27 +++++++ dlt/destinations/impl/sink/factory.py | 39 +++++++++ dlt/destinations/impl/sink/sink.py | 89 +++++++++++++++++++++ 4 files changed, 160 insertions(+) create mode 100644 dlt/destinations/impl/sink/__init__.py create mode 100644 dlt/destinations/impl/sink/configuration.py create mode 100644 dlt/destinations/impl/sink/factory.py create mode 100644 dlt/destinations/impl/sink/sink.py diff --git a/dlt/destinations/impl/sink/__init__.py b/dlt/destinations/impl/sink/__init__.py new file mode 100644 index 0000000000..2532b254fa --- /dev/null +++ b/dlt/destinations/impl/sink/__init__.py @@ -0,0 +1,5 @@ +from dlt.common.destination import DestinationCapabilitiesContext + + +def capabilities() -> DestinationCapabilitiesContext: + return DestinationCapabilitiesContext.generic_capabilities("parquet") diff --git a/dlt/destinations/impl/sink/configuration.py b/dlt/destinations/impl/sink/configuration.py new file mode 100644 index 0000000000..3d5f90bfbc --- /dev/null +++ b/dlt/destinations/impl/sink/configuration.py @@ -0,0 +1,27 @@ +from typing import TYPE_CHECKING, Optional, Final + +from dlt.common.configuration import configspec +from dlt.common.destination import TLoaderFileFormat +from dlt.common.destination.reference import ( + DestinationClientConfiguration, + CredentialsConfiguration, +) + + +@configspec +class SinkClientCredentials(CredentialsConfiguration): + callable_name: str = None + + +@configspec +class SinkClientConfiguration(DestinationClientConfiguration): + destination_type: Final[str] = "sink" # type: ignore + credentials: SinkClientCredentials = None + + if TYPE_CHECKING: + + def __init__( + self, + *, + credentials: Optional[CredentialsConfiguration] = None, + ) -> None: ... 
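The modules above only register the destination shell: parquet-only capabilities, a credentials spec that names a callable, and the client configuration. For orientation, the way this sink is meant to be driven (per the decorator and tests added later in this series) is that a plain Python function becomes the destination and receives each batch of items together with the table schema. A minimal usage sketch, assuming the full patch series is applied; the function and pipeline names are illustrative:

import dlt
from dlt.common.typing import TDataItems
from dlt.common.schema import TTableSchema


# A plain function acting as the sink: it receives a batch of items
# (or a single item when batch_size=1) plus the schema of their table.
def print_sink(items: TDataItems, table: TTableSchema) -> None:
    if table["name"].startswith("_dlt"):
        return  # ignore dlt's internal tables
    print(table["name"], items)


# dlt.sink() (introduced in a later patch) wraps the callable into a destination;
# the pipeline then "loads" by invoking the function batch by batch.
pipeline = dlt.pipeline("sink_demo", destination=dlt.sink(batch_size=10)(print_sink))
pipeline.run([{"id": i} for i in range(25)], table_name="items")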
diff --git a/dlt/destinations/impl/sink/factory.py b/dlt/destinations/impl/sink/factory.py new file mode 100644 index 0000000000..a69f29598d --- /dev/null +++ b/dlt/destinations/impl/sink/factory.py @@ -0,0 +1,39 @@ +import typing as t + +from dlt.common.destination import Destination, DestinationCapabilitiesContext + +from dlt.destinations.impl.sink.configuration import ( + SinkClientConfiguration, + SinkClientCredentials, +) +from dlt.destinations.impl.sink import capabilities + +if t.TYPE_CHECKING: + from dlt.destinations.impl.sink.sink import SinkClient + + +class sink(Destination[SinkClientConfiguration, "SinkClient"]): + spec = SinkClientConfiguration + + def capabilities(self) -> DestinationCapabilitiesContext: + return capabilities() + + @property + def client_class(self) -> t.Type["SinkClient"]: + from dlt.destinations.impl.sink.sink import SinkClient + + return SinkClient + + def __init__( + self, + credentials: SinkClientCredentials = None, + destination_name: t.Optional[str] = None, + environment: t.Optional[str] = None, + **kwargs: t.Any, + ) -> None: + super().__init__( + credentials=credentials, + destination_name=destination_name, + environment=environment, + **kwargs, + ) diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py new file mode 100644 index 0000000000..df1acf744b --- /dev/null +++ b/dlt/destinations/impl/sink/sink.py @@ -0,0 +1,89 @@ +import random +from copy import copy +from types import TracebackType +from typing import ClassVar, Dict, Optional, Sequence, Type, Iterable, List + +from dlt.destinations.job_impl import EmptyLoadJob + +from dlt.common.schema import Schema, TTableSchema, TSchemaTables +from dlt.common.storages import FileStorage +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.reference import ( + FollowupJob, + NewLoadJob, + TLoadJobState, + LoadJob, + JobClientBase, +) + +from dlt.destinations.exceptions import ( + LoadJobNotExistsException, +) + +from dlt.destinations.impl.Sink import capabilities +from dlt.destinations.impl.Sink.configuration import SinkClientConfiguration + + +class LoadSinkJob(LoadJob, FollowupJob): + def __init__(self, file_path: str, config: SinkClientConfiguration) -> None: + self._file_path = file_path + self._config = config + + def state(self) -> TLoadJobState: + return "completed" + + def exception(self) -> str: + raise NotImplementedError() + + +JOBS: Dict[str, LoadSinkJob] = {} + + +class SinkClient(JobClientBase): + """Sink client storing jobs in memory""" + + capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() + + def __init__(self, schema: Schema, config: SinkClientConfiguration) -> None: + super().__init__(schema, config) + self.config: SinkClientConfiguration = config + + def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: + pass + + def is_storage_initialized(self) -> bool: + return True + + def drop_storage(self) -> None: + pass + + def update_stored_schema( + self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None + ) -> Optional[TSchemaTables]: + return super().update_stored_schema(only_tables, expected_update) + + def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + return LoadSinkJob(file_path, config=self.config) + + def restore_file_load(self, file_path: str) -> LoadJob: + return EmptyLoadJob.from_file_path(file_path, "completed") + + def create_table_chain_completed_followup_jobs( + self, table_chain: 
Sequence[TTableSchema] + ) -> List[NewLoadJob]: + """Creates a list of followup jobs that should be executed after a table chain is completed""" + return [] + + def complete_load(self, load_id: str) -> None: + pass + + def __enter__(self) -> "SinkClient": + return self + + def __exit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + ) -> None: + pass + + def _create_job(self, job_id: str) -> LoadSinkJob: + return LoadSinkJob(job_id, config=self.config) From 6a924d9e8cb7150788bc97426a58e4189c7682ec Mon Sep 17 00:00:00 2001 From: Dave Date: Thu, 11 Jan 2024 21:25:56 +0100 Subject: [PATCH 02/35] parquet sink prototype --- dlt/__init__.py | 3 +++ dlt/destinations/decorators.py | 11 +++++++++++ dlt/destinations/impl/sink/configuration.py | 16 ++++++++++++---- dlt/destinations/impl/sink/factory.py | 1 + dlt/destinations/impl/sink/sink.py | 13 +++++++++++-- tests/load/sink/__init__.py | 0 tests/load/sink/test_simple_sink.py | 14 ++++++++++++++ 7 files changed, 52 insertions(+), 6 deletions(-) create mode 100644 dlt/destinations/decorators.py create mode 100644 tests/load/sink/__init__.py create mode 100644 tests/load/sink/test_simple_sink.py diff --git a/dlt/__init__.py b/dlt/__init__.py index e2a6b1a3a7..6b567b3398 100644 --- a/dlt/__init__.py +++ b/dlt/__init__.py @@ -29,6 +29,8 @@ from dlt import sources from dlt.extract.decorators import source, resource, transformer, defer +from dlt.destinations.decorators import sink + from dlt.pipeline import ( pipeline as _pipeline, run, @@ -62,6 +64,7 @@ "resource", "transformer", "defer", + "sink", "pipeline", "run", "attach", diff --git a/dlt/destinations/decorators.py b/dlt/destinations/decorators.py new file mode 100644 index 0000000000..d9d543e813 --- /dev/null +++ b/dlt/destinations/decorators.py @@ -0,0 +1,11 @@ +from typing import Any, Callable +from dlt.destinations.impl.sink.factory import sink as _sink +from dlt.destinations.impl.sink.configuration import SinkClientConfiguration, TSinkCallable +from dlt.common.destination import TDestinationReferenceArg + + +def sink() -> Any: + def decorator(f: TSinkCallable) -> TDestinationReferenceArg: + return _sink(credentials=f) + + return decorator diff --git a/dlt/destinations/impl/sink/configuration.py b/dlt/destinations/impl/sink/configuration.py index 3d5f90bfbc..7a1a2d45a8 100644 --- a/dlt/destinations/impl/sink/configuration.py +++ b/dlt/destinations/impl/sink/configuration.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Final +from typing import TYPE_CHECKING, Optional, Final, Callable, Union, Any from dlt.common.configuration import configspec from dlt.common.destination import TLoaderFileFormat @@ -6,22 +6,30 @@ DestinationClientConfiguration, CredentialsConfiguration, ) +from dlt.common.typing import TDataItems + + +TSinkCallable = Callable[[TDataItems], None] @configspec class SinkClientCredentials(CredentialsConfiguration): - callable_name: str = None + callable_name: Optional[str] = None + + def parse_native_representation(self, native_value: Any) -> None: + if callable(native_value): + self.callable: TSinkCallable = native_value @configspec class SinkClientConfiguration(DestinationClientConfiguration): destination_type: Final[str] = "sink" # type: ignore - credentials: SinkClientCredentials = None + credentials: Union[SinkClientCredentials, TSinkCallable] = None if TYPE_CHECKING: def __init__( self, *, - credentials: Optional[CredentialsConfiguration] = None, + credentials: Union[SinkClientCredentials, TSinkCallable] = None, ) -> 
None: ... diff --git a/dlt/destinations/impl/sink/factory.py b/dlt/destinations/impl/sink/factory.py index a69f29598d..488913b85a 100644 --- a/dlt/destinations/impl/sink/factory.py +++ b/dlt/destinations/impl/sink/factory.py @@ -5,6 +5,7 @@ from dlt.destinations.impl.sink.configuration import ( SinkClientConfiguration, SinkClientCredentials, + TSinkCallable, ) from dlt.destinations.impl.sink import capabilities diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py index df1acf744b..6b3769d3c0 100644 --- a/dlt/destinations/impl/sink/sink.py +++ b/dlt/destinations/impl/sink/sink.py @@ -20,15 +20,24 @@ LoadJobNotExistsException, ) -from dlt.destinations.impl.Sink import capabilities -from dlt.destinations.impl.Sink.configuration import SinkClientConfiguration +from dlt.destinations.impl.sink import capabilities +from dlt.destinations.impl.sink.configuration import SinkClientConfiguration, TSinkCallable class LoadSinkJob(LoadJob, FollowupJob): def __init__(self, file_path: str, config: SinkClientConfiguration) -> None: + super().__init__(FileStorage.get_file_name_from_file_path(file_path)) self._file_path = file_path self._config = config + # stream items + from dlt.common.libs.pyarrow import pyarrow + + with pyarrow.parquet.ParquetFile(file_path) as reader: + for record_batch in reader.iter_batches(batch_size=10): + for d in record_batch.to_pylist(): + self._config.credentials.callable(d) + def state(self) -> TLoadJobState: return "completed" diff --git a/tests/load/sink/__init__.py b/tests/load/sink/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/load/sink/test_simple_sink.py b/tests/load/sink/test_simple_sink.py new file mode 100644 index 0000000000..25cfb6791a --- /dev/null +++ b/tests/load/sink/test_simple_sink.py @@ -0,0 +1,14 @@ +import dlt +from dlt.common.typing import TDataItems + + +def test_datasink() -> None: + @dlt.sink() + def test_sink(items: TDataItems) -> None: + print("CALL") + print(items) + + p = dlt.pipeline("sink_test", destination=test_sink) + + p.run([{"a": "b "}], table_name="items") + assert False From 83c1af63b6caecc4376961ae525ca811e4867588 Mon Sep 17 00:00:00 2001 From: Dave Date: Fri, 12 Jan 2024 13:02:21 +0100 Subject: [PATCH 03/35] some more sink implementations --- dlt/destinations/decorators.py | 5 +- dlt/destinations/impl/sink/__init__.py | 7 +- dlt/destinations/impl/sink/configuration.py | 7 +- dlt/destinations/impl/sink/factory.py | 9 +- dlt/destinations/impl/sink/sink.py | 98 ++++++++++++++++++--- tests/load/sink/test_simple_sink.py | 14 --- tests/load/sink/test_sink.py | 77 ++++++++++++++++ 7 files changed, 181 insertions(+), 36 deletions(-) delete mode 100644 tests/load/sink/test_simple_sink.py create mode 100644 tests/load/sink/test_sink.py diff --git a/dlt/destinations/decorators.py b/dlt/destinations/decorators.py index d9d543e813..a99690cff7 100644 --- a/dlt/destinations/decorators.py +++ b/dlt/destinations/decorators.py @@ -2,10 +2,11 @@ from dlt.destinations.impl.sink.factory import sink as _sink from dlt.destinations.impl.sink.configuration import SinkClientConfiguration, TSinkCallable from dlt.common.destination import TDestinationReferenceArg +from dlt.common.destination import TLoaderFileFormat -def sink() -> Any: +def sink(loader_file_format: TLoaderFileFormat = None, batch_size: int = 10) -> Any: def decorator(f: TSinkCallable) -> TDestinationReferenceArg: - return _sink(credentials=f) + return _sink(credentials=f, loader_file_format=loader_file_format, batch_size=batch_size) 
return decorator diff --git a/dlt/destinations/impl/sink/__init__.py b/dlt/destinations/impl/sink/__init__.py index 2532b254fa..a5f7bd268b 100644 --- a/dlt/destinations/impl/sink/__init__.py +++ b/dlt/destinations/impl/sink/__init__.py @@ -1,5 +1,8 @@ from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.data_writers import TLoaderFileFormat -def capabilities() -> DestinationCapabilitiesContext: - return DestinationCapabilitiesContext.generic_capabilities("parquet") +def capabilities( + preferred_loader_file_format: TLoaderFileFormat = "parquet", +) -> DestinationCapabilitiesContext: + return DestinationCapabilitiesContext.generic_capabilities(preferred_loader_file_format) diff --git a/dlt/destinations/impl/sink/configuration.py b/dlt/destinations/impl/sink/configuration.py index 7a1a2d45a8..591f31fcce 100644 --- a/dlt/destinations/impl/sink/configuration.py +++ b/dlt/destinations/impl/sink/configuration.py @@ -7,9 +7,10 @@ CredentialsConfiguration, ) from dlt.common.typing import TDataItems +from dlt.common.schema import TTableSchema -TSinkCallable = Callable[[TDataItems], None] +TSinkCallable = Callable[[TDataItems, TTableSchema], None] @configspec @@ -24,7 +25,9 @@ def parse_native_representation(self, native_value: Any) -> None: @configspec class SinkClientConfiguration(DestinationClientConfiguration): destination_type: Final[str] = "sink" # type: ignore - credentials: Union[SinkClientCredentials, TSinkCallable] = None + credentials: SinkClientCredentials = None + loader_file_format: TLoaderFileFormat = "parquet" + batch_size: int = 10 if TYPE_CHECKING: diff --git a/dlt/destinations/impl/sink/factory.py b/dlt/destinations/impl/sink/factory.py index 488913b85a..f51c3386ab 100644 --- a/dlt/destinations/impl/sink/factory.py +++ b/dlt/destinations/impl/sink/factory.py @@ -8,6 +8,7 @@ TSinkCallable, ) from dlt.destinations.impl.sink import capabilities +from dlt.common.data_writers import TLoaderFileFormat if t.TYPE_CHECKING: from dlt.destinations.impl.sink.sink import SinkClient @@ -17,7 +18,7 @@ class sink(Destination[SinkClientConfiguration, "SinkClient"]): spec = SinkClientConfiguration def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities() + return capabilities(self.config_params.get("loader_file_format", "parquet")) @property def client_class(self) -> t.Type["SinkClient"]: @@ -27,14 +28,18 @@ def client_class(self) -> t.Type["SinkClient"]: def __init__( self, - credentials: SinkClientCredentials = None, + credentials: t.Union[SinkClientCredentials, TSinkCallable] = None, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, + loader_file_format: TLoaderFileFormat = None, + batch_size: int = 10, **kwargs: t.Any, ) -> None: super().__init__( credentials=credentials, destination_name=destination_name, environment=environment, + loader_file_format=loader_file_format, + batch_size=batch_size, **kwargs, ) diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py index 6b3769d3c0..7ee51972b9 100644 --- a/dlt/destinations/impl/sink/sink.py +++ b/dlt/destinations/impl/sink/sink.py @@ -4,8 +4,10 @@ from typing import ClassVar, Dict, Optional, Sequence, Type, Iterable, List from dlt.destinations.job_impl import EmptyLoadJob +from dlt.common.typing import TDataItems from dlt.common.schema import Schema, TTableSchema, TSchemaTables +from dlt.common.schema.typing import TTableSchema from dlt.common.storages import FileStorage from dlt.common.destination import DestinationCapabilitiesContext 
from dlt.common.destination.reference import ( @@ -19,24 +21,30 @@ from dlt.destinations.exceptions import ( LoadJobNotExistsException, ) - from dlt.destinations.impl.sink import capabilities from dlt.destinations.impl.sink.configuration import SinkClientConfiguration, TSinkCallable -class LoadSinkJob(LoadJob, FollowupJob): - def __init__(self, file_path: str, config: SinkClientConfiguration) -> None: +class SinkLoadJob(LoadJob, FollowupJob): + def __init__( + self, table: TTableSchema, file_path: str, config: SinkClientConfiguration + ) -> None: super().__init__(FileStorage.get_file_name_from_file_path(file_path)) self._file_path = file_path self._config = config + self._table = table + self.run() - # stream items - from dlt.common.libs.pyarrow import pyarrow + def run(self) -> None: + pass - with pyarrow.parquet.ParquetFile(file_path) as reader: - for record_batch in reader.iter_batches(batch_size=10): - for d in record_batch.to_pylist(): - self._config.credentials.callable(d) + def call_callable_with_items(self, items: TDataItems) -> None: + if not items: + return + if self._config.credentials.callable: + self._config.credentials.callable( + items[0] if self._config.batch_size == 1 else items, self._table + ) def state(self) -> TLoadJobState: return "completed" @@ -45,7 +53,66 @@ def exception(self) -> str: raise NotImplementedError() -JOBS: Dict[str, LoadSinkJob] = {} +class SinkParquetLoadJob(SinkLoadJob): + def run(self) -> None: + # stream items + from dlt.common.libs.pyarrow import pyarrow + + with pyarrow.parquet.ParquetFile(self._file_path) as reader: + for record_batch in reader.iter_batches(batch_size=self._config.batch_size): + batch = record_batch.to_pylist() + self.call_callable_with_items(batch) + + +class SinkJsonlLoadJob(SinkLoadJob): + def run(self) -> None: + from dlt.common import json + + # stream items + with FileStorage.open_zipsafe_ro(self._file_path) as f: + current_batch: TDataItems = [] + for line in f: + current_batch.append(json.loads(line)) + if len(current_batch) == self._config.batch_size: + self.call_callable_with_items(current_batch) + current_batch = [] + self.call_callable_with_items(current_batch) + + +class SinkInsertValueslLoadJob(SinkLoadJob): + def run(self) -> None: + from dlt.common import json + + # stream items + with FileStorage.open_zipsafe_ro(self._file_path) as f: + current_batch: TDataItems = [] + column_names: List[str] = [] + for line in f: + line = line.strip() + + # TODO respect inserts with multiline values + + # extract column names + if line.startswith("INSERT INTO") and line.endswith(")"): + line = line[15:-1] + column_names = line.split(",") + continue + + # not a valid values line + if not line.startswith("(") or not line.endswith(");"): + continue + + # extract values + line = line[1:-2] + values = line.split(",") + + # zip and send to callable + current_batch.append(dict(zip(column_names, values))) + if len(current_batch) == self._config.batch_size: + self.call_callable_with_items(current_batch) + current_batch = [] + + self.call_callable_with_items(current_batch) class SinkClient(JobClientBase): @@ -72,7 +139,13 @@ def update_stored_schema( return super().update_stored_schema(only_tables, expected_update) def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - return LoadSinkJob(file_path, config=self.config) + if file_path.endswith("parquet"): + return SinkParquetLoadJob(table, file_path, self.config) + if file_path.endswith("jsonl"): + return SinkJsonlLoadJob(table, file_path, self.config) 
+ if file_path.endswith("insert_values"): + return SinkInsertValueslLoadJob(table, file_path, self.config) + return EmptyLoadJob.from_file_path(file_path, "completed") def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") @@ -93,6 +166,3 @@ def __exit__( self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType ) -> None: pass - - def _create_job(self, job_id: str) -> LoadSinkJob: - return LoadSinkJob(job_id, config=self.config) diff --git a/tests/load/sink/test_simple_sink.py b/tests/load/sink/test_simple_sink.py deleted file mode 100644 index 25cfb6791a..0000000000 --- a/tests/load/sink/test_simple_sink.py +++ /dev/null @@ -1,14 +0,0 @@ -import dlt -from dlt.common.typing import TDataItems - - -def test_datasink() -> None: - @dlt.sink() - def test_sink(items: TDataItems) -> None: - print("CALL") - print(items) - - p = dlt.pipeline("sink_test", destination=test_sink) - - p.run([{"a": "b "}], table_name="items") - assert False diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py new file mode 100644 index 0000000000..0b022fb312 --- /dev/null +++ b/tests/load/sink/test_sink.py @@ -0,0 +1,77 @@ +from typing import List, Tuple + +import dlt +import pytest + +from copy import deepcopy +from dlt.common.typing import TDataItems +from dlt.common.schema import TTableSchema +from dlt.common.data_writers.writers import TLoaderFileFormat + +from tests.load.utils import ( + TABLE_ROW_ALL_DATA_TYPES, + TABLE_UPDATE_COLUMNS_SCHEMA, + assert_all_data_types_row, + delete_dataset, +) + +SUPPORTED_LOADER_FORMATS = ["parquet", "jsonl", "insert_values"] + + +def _run_through_sink( + items: TDataItems, + loader_file_format: TLoaderFileFormat, + columns=None, + filter_dlt_tables: bool = True, +) -> List[Tuple[TDataItems, TTableSchema]]: + """ + runs a list of items through the sink destination and returns colleceted calls + """ + calls: List[Tuple[TDataItems, TTableSchema]] = [] + + @dlt.sink(loader_file_format=loader_file_format, batch_size=1) + def test_sink(items: TDataItems, table: TTableSchema) -> None: + nonlocal calls + if table["name"].startswith("_dlt") and filter_dlt_tables: + return + calls.append((items, table)) + + @dlt.resource(columns=columns, table_name="items") + def items_resource() -> TDataItems: + nonlocal items + yield items + + p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True) + p.run([items_resource()]) + + return calls + + +@pytest.mark.parametrize("loader_file_format", SUPPORTED_LOADER_FORMATS) +def test_all_datatypes(loader_file_format: TLoaderFileFormat) -> None: + data_types = deepcopy(TABLE_ROW_ALL_DATA_TYPES) + column_schemas = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA) + + sink_calls = _run_through_sink(data_types, loader_file_format, columns=column_schemas) + + # inspect result + assert len(sink_calls) == 1 + + item = sink_calls[0][0] + # filter out _dlt columns + item = {k: v for k, v in item.items() if not k.startswith("_dlt")} + + # null values are not saved in jsonl (TODO: is this correct?) 
+ if loader_file_format == "jsonl": + data_types = {k: v for k, v in data_types.items() if v is not None} + + # check keys are the same + assert set(item.keys()) == set(data_types.keys()) + + # TODO: check actual types + # assert_all_data_types_row + + +@pytest.mark.parametrize("loader_file_format", SUPPORTED_LOADER_FORMATS) +def test_batch_size(loader_file_format: TLoaderFileFormat) -> None: + pass From b972f237129ab2a856dc0a72f0531ef72ffedb59 Mon Sep 17 00:00:00 2001 From: Dave Date: Fri, 12 Jan 2024 17:16:44 +0100 Subject: [PATCH 04/35] finish first batch of helpers --- dlt/common/data_types/type_helpers.py | 10 ++- dlt/destinations/impl/sink/sink.py | 106 +++++++++++++++----------- tests/cases.py | 25 ++++-- tests/load/sink/test_sink.py | 59 +++++++++++--- 4 files changed, 138 insertions(+), 62 deletions(-) diff --git a/dlt/common/data_types/type_helpers.py b/dlt/common/data_types/type_helpers.py index 9e1cd2278d..d84821b217 100644 --- a/dlt/common/data_types/type_helpers.py +++ b/dlt/common/data_types/type_helpers.py @@ -6,7 +6,7 @@ from enum import Enum from dlt.common import pendulum, json, Decimal, Wei -from dlt.common.json import custom_pua_remove +from dlt.common.json import custom_pua_remove, json from dlt.common.json._simplejson import custom_encode as json_custom_encode from dlt.common.arithmetics import InvalidOperation from dlt.common.data_types.typing import TDataType @@ -105,6 +105,14 @@ def coerce_value(to_type: TDataType, from_type: TDataType, value: Any) -> Any: return int(value.value) return value + if to_type == "complex": + # try to coerce from text + if from_type == "text": + try: + return json.loads(value) + except Exception as e: + pass + if to_type == "text": if from_type == "complex": return complex_to_str(value) diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py index 7ee51972b9..ed0d1d4686 100644 --- a/dlt/destinations/impl/sink/sink.py +++ b/dlt/destinations/impl/sink/sink.py @@ -27,12 +27,13 @@ class SinkLoadJob(LoadJob, FollowupJob): def __init__( - self, table: TTableSchema, file_path: str, config: SinkClientConfiguration + self, table: TTableSchema, file_path: str, config: SinkClientConfiguration, schema: Schema ) -> None: super().__init__(FileStorage.get_file_name_from_file_path(file_path)) self._file_path = file_path self._config = config self._table = table + self._schema = schema self.run() def run(self) -> None: @@ -41,10 +42,19 @@ def run(self) -> None: def call_callable_with_items(self, items: TDataItems) -> None: if not items: return - if self._config.credentials.callable: - self._config.credentials.callable( - items[0] if self._config.batch_size == 1 else items, self._table - ) + + # coerce items into correct format specified by schema + coerced_items: TDataItems = [] + for item in items: + coerced_item, table_update = self._schema.coerce_row(self._table["name"], None, item) + assert not table_update + coerced_items.append(coerced_item) + + # send single item on batch size 1 + if self._config.batch_size == 1: + coerced_items = coerced_items[0] + + self._config.credentials.callable(coerced_items, self._table) def state(self) -> TLoadJobState: return "completed" @@ -79,40 +89,50 @@ def run(self) -> None: self.call_callable_with_items(current_batch) -class SinkInsertValueslLoadJob(SinkLoadJob): - def run(self) -> None: - from dlt.common import json - - # stream items - with FileStorage.open_zipsafe_ro(self._file_path) as f: - current_batch: TDataItems = [] - column_names: List[str] = [] - for line in f: - line = 
line.strip() - - # TODO respect inserts with multiline values - - # extract column names - if line.startswith("INSERT INTO") and line.endswith(")"): - line = line[15:-1] - column_names = line.split(",") - continue - - # not a valid values line - if not line.startswith("(") or not line.endswith(");"): - continue - - # extract values - line = line[1:-2] - values = line.split(",") - - # zip and send to callable - current_batch.append(dict(zip(column_names, values))) - if len(current_batch) == self._config.batch_size: - self.call_callable_with_items(current_batch) - current_batch = [] - - self.call_callable_with_items(current_batch) +# class SinkInsertValueslLoadJob(SinkLoadJob): +# def run(self) -> None: +# from dlt.common import json + +# # stream items +# with FileStorage.open_zipsafe_ro(self._file_path) as f: +# header = f.readline().strip() +# values_mark = f.readline() + +# # properly formatted file has a values marker at the beginning +# assert values_mark == "VALUES\n" + +# # extract column names +# assert header.startswith("INSERT INTO") and header.endswith(")") +# header = header[15:-1] +# column_names = header.split(",") + +# # build batches +# current_batch: TDataItems = [] +# current_row: str = "" +# for line in f: +# current_row += line +# if line.endswith(");"): +# current_row = current_row[1:-2] +# elif line.endswith("),\n"): +# current_row = current_row[1:-3] +# else: +# continue + +# values = current_row.split(",") +# values = [None if v == "NULL" else v for v in values] +# current_row = "" +# print(values) +# print(current_row) + +# # zip and send to callable +# current_batch.append(dict(zip(column_names, values))) +# d = dict(zip(column_names, values)) +# print(json.dumps(d, pretty=True)) +# if len(current_batch) == self._config.batch_size: +# self.call_callable_with_items(current_batch) +# current_batch = [] + +# self.call_callable_with_items(current_batch) class SinkClient(JobClientBase): @@ -140,11 +160,11 @@ def update_stored_schema( def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: if file_path.endswith("parquet"): - return SinkParquetLoadJob(table, file_path, self.config) + return SinkParquetLoadJob(table, file_path, self.config, self.schema) if file_path.endswith("jsonl"): - return SinkJsonlLoadJob(table, file_path, self.config) - if file_path.endswith("insert_values"): - return SinkInsertValueslLoadJob(table, file_path, self.config) + return SinkJsonlLoadJob(table, file_path, self.config, self.schema) + # if file_path.endswith("insert_values"): + # return SinkInsertValueslLoadJob(table, file_path, self.config, self.schema) return EmptyLoadJob.from_file_path(file_path, "completed") def restore_file_load(self, file_path: str) -> LoadJob: diff --git a/tests/cases.py b/tests/cases.py index 8653f999c6..a52d68b230 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Any, Sequence, Tuple, Literal +from typing import Dict, List, Any, Sequence, Tuple, Literal, Union import base64 from hexbytes import HexBytes from copy import deepcopy @@ -7,7 +7,7 @@ from dlt.common import Decimal, pendulum, json from dlt.common.data_types import TDataType -from dlt.common.typing import StrAny +from dlt.common.typing import StrAny, TDataItems from dlt.common.wei import Wei from dlt.common.time import ( ensure_pendulum_datetime, @@ -161,18 +161,23 @@ def table_update_and_row( def assert_all_data_types_row( - db_row: List[Any], + db_row: Union[List[Any], TDataItems], parse_complex_strings: bool = False, 
allow_base64_binary: bool = False, timestamp_precision: int = 6, schema: TTableSchemaColumns = None, + expect_filtered_null_columns=False, ) -> None: # content must equal # print(db_row) schema = schema or TABLE_UPDATE_COLUMNS_SCHEMA # Include only columns requested in schema - db_mapping = {col_name: db_row[i] for i, col_name in enumerate(schema)} + if isinstance(db_row, dict): + db_mapping = db_row.copy() + else: + db_mapping = {col_name: db_row[i] for i, col_name in enumerate(schema)} + expected_rows = {key: value for key, value in TABLE_ROW_ALL_DATA_TYPES.items() if key in schema} # prepare date to be compared: convert into pendulum instance, adjust microsecond precision if "col4" in expected_rows: @@ -226,8 +231,16 @@ def assert_all_data_types_row( if "col11" in db_mapping: db_mapping["col11"] = db_mapping["col11"].isoformat() - for expected, actual in zip(expected_rows.values(), db_mapping.values()): - assert expected == actual + if expect_filtered_null_columns: + for key, expected in expected_rows.items(): + if expected is None: + assert db_mapping.get(key, None) == None + db_mapping[key] = None + + for key, expected in expected_rows.items(): + actual = db_mapping[key] + assert expected == actual, f"Expected {expected} but got {actual} for column {key}" + assert db_mapping == expected_rows diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py index 0b022fb312..a7aed033d5 100644 --- a/tests/load/sink/test_sink.py +++ b/tests/load/sink/test_sink.py @@ -15,7 +15,7 @@ delete_dataset, ) -SUPPORTED_LOADER_FORMATS = ["parquet", "jsonl", "insert_values"] +SUPPORTED_LOADER_FORMATS = ["parquet", "jsonl"] def _run_through_sink( @@ -23,13 +23,14 @@ def _run_through_sink( loader_file_format: TLoaderFileFormat, columns=None, filter_dlt_tables: bool = True, + batch_size: int = 10, ) -> List[Tuple[TDataItems, TTableSchema]]: """ runs a list of items through the sink destination and returns colleceted calls """ calls: List[Tuple[TDataItems, TTableSchema]] = [] - @dlt.sink(loader_file_format=loader_file_format, batch_size=1) + @dlt.sink(loader_file_format=loader_file_format, batch_size=batch_size) def test_sink(items: TDataItems, table: TTableSchema) -> None: nonlocal calls if table["name"].startswith("_dlt") and filter_dlt_tables: @@ -52,26 +53,60 @@ def test_all_datatypes(loader_file_format: TLoaderFileFormat) -> None: data_types = deepcopy(TABLE_ROW_ALL_DATA_TYPES) column_schemas = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA) - sink_calls = _run_through_sink(data_types, loader_file_format, columns=column_schemas) + sink_calls = _run_through_sink( + [data_types, data_types, data_types], + loader_file_format, + columns=column_schemas, + batch_size=1, + ) # inspect result - assert len(sink_calls) == 1 + assert len(sink_calls) == 3 item = sink_calls[0][0] # filter out _dlt columns - item = {k: v for k, v in item.items() if not k.startswith("_dlt")} + item = {k: v for k, v in item.items() if not k.startswith("_dlt")} # type: ignore - # null values are not saved in jsonl (TODO: is this correct?) 
- if loader_file_format == "jsonl": - data_types = {k: v for k, v in data_types.items() if v is not None} + # null values are not emitted + data_types = {k: v for k, v in data_types.items() if v is not None} # check keys are the same assert set(item.keys()) == set(data_types.keys()) - # TODO: check actual types - # assert_all_data_types_row + assert_all_data_types_row(item, expect_filtered_null_columns=True) @pytest.mark.parametrize("loader_file_format", SUPPORTED_LOADER_FORMATS) -def test_batch_size(loader_file_format: TLoaderFileFormat) -> None: - pass +@pytest.mark.parametrize("batch_size", [1, 10, 23]) +def test_batch_size(loader_file_format: TLoaderFileFormat, batch_size: int) -> None: + items = [{"id": i, "value": str(i)} for i in range(100)] + + sink_calls = _run_through_sink(items, loader_file_format, batch_size=batch_size) + + if batch_size == 1: + assert len(sink_calls) == 100 + # one item per call + assert sink_calls[0][0].items() > {"id": 0, "value": "0"}.items() # type: ignore + elif batch_size == 10: + assert len(sink_calls) == 10 + # ten items in first call + assert len(sink_calls[0][0]) == 10 + assert sink_calls[0][0][0].items() > {"id": 0, "value": "0"}.items() + elif batch_size == 23: + assert len(sink_calls) == 5 + # 23 items in first call + assert len(sink_calls[0][0]) == 23 + assert sink_calls[0][0][0].items() > {"id": 0, "value": "0"}.items() + + # check all items are present + all_items = set() + for call in sink_calls: + item = call[0] + if batch_size == 1: + item = [item] + for entry in item: + all_items.add(entry["value"]) + + assert len(all_items) == 100 + for i in range(100): + assert str(i) in all_items From 9dabeffc9cfd6841cdea4d6fd354bdbfce4497a1 Mon Sep 17 00:00:00 2001 From: Dave Date: Sat, 13 Jan 2024 17:40:15 +0100 Subject: [PATCH 05/35] add missing tests and fix linting --- dlt/common/data_types/type_helpers.py | 2 +- tests/cases.py | 2 +- tests/common/schema/test_coercion.py | 10 ++++++++++ tests/load/weaviate/utils.py | 2 ++ 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/dlt/common/data_types/type_helpers.py b/dlt/common/data_types/type_helpers.py index d84821b217..9b26207cf1 100644 --- a/dlt/common/data_types/type_helpers.py +++ b/dlt/common/data_types/type_helpers.py @@ -110,7 +110,7 @@ def coerce_value(to_type: TDataType, from_type: TDataType, value: Any) -> Any: if from_type == "text": try: return json.loads(value) - except Exception as e: + except Exception: pass if to_type == "text": diff --git a/tests/cases.py b/tests/cases.py index a52d68b230..85caec4b8d 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -234,7 +234,7 @@ def assert_all_data_types_row( if expect_filtered_null_columns: for key, expected in expected_rows.items(): if expected is None: - assert db_mapping.get(key, None) == None + assert db_mapping.get(key, None) is None db_mapping[key] = None for key, expected in expected_rows.items(): diff --git a/tests/common/schema/test_coercion.py b/tests/common/schema/test_coercion.py index 922024a89b..34b62f9564 100644 --- a/tests/common/schema/test_coercion.py +++ b/tests/common/schema/test_coercion.py @@ -377,10 +377,16 @@ def test_coerce_type_complex() -> None: assert coerce_value("complex", "complex", v_list) == v_list assert coerce_value("text", "complex", v_dict) == json.dumps(v_dict) assert coerce_value("text", "complex", v_list) == json.dumps(v_list) + assert coerce_value("complex", "text", json.dumps(v_dict)) == v_dict + assert coerce_value("complex", "text", json.dumps(v_list)) == v_list + # all other coercions fail 
with pytest.raises(ValueError): coerce_value("binary", "complex", v_list) + with pytest.raises(ValueError): + coerce_value("complex", "text", "not a json string") + def test_coerce_type_complex_with_pua() -> None: v_dict = { @@ -395,6 +401,10 @@ def test_coerce_type_complex_with_pua() -> None: } assert coerce_value("complex", "complex", copy(v_dict)) == exp_v assert coerce_value("text", "complex", copy(v_dict)) == json.dumps(exp_v) + + # TODO: what to test for this case if at all? + # assert coerce_value("complex", "text", json.dumps(v_dict)) == exp_v + # also decode recursively custom_pua_decode_nested(v_dict) # restores datetime type diff --git a/tests/load/weaviate/utils.py b/tests/load/weaviate/utils.py index ed378191e6..1b2a74fcb8 100644 --- a/tests/load/weaviate/utils.py +++ b/tests/load/weaviate/utils.py @@ -79,6 +79,8 @@ def delete_classes(p, class_list): def drop_active_pipeline_data() -> None: def schema_has_classes(client): + if not hasattr(client, "db_client"): + return None schema = client.db_client.schema.get() return schema["classes"] From af6defd0bc5a58bee12498bbdce479397686e4af Mon Sep 17 00:00:00 2001 From: Dave Date: Mon, 15 Jan 2024 14:18:38 +0100 Subject: [PATCH 06/35] make configuratio more versatile --- dlt/destinations/__init__.py | 1 + dlt/destinations/impl/sink/configuration.py | 36 ++++++++++- dlt/destinations/impl/sink/sink.py | 7 +-- tests/load/sink/test_sink.py | 70 +++++++++++++++++++++ tests/utils.py | 5 +- 5 files changed, 110 insertions(+), 9 deletions(-) diff --git a/dlt/destinations/__init__.py b/dlt/destinations/__init__.py index 980c4ce7f2..b45daf0513 100644 --- a/dlt/destinations/__init__.py +++ b/dlt/destinations/__init__.py @@ -10,6 +10,7 @@ from dlt.destinations.impl.qdrant.factory import qdrant from dlt.destinations.impl.motherduck.factory import motherduck from dlt.destinations.impl.weaviate.factory import weaviate +from dlt.destinations.impl.sink.factory import sink __all__ = [ diff --git a/dlt/destinations/impl/sink/configuration.py b/dlt/destinations/impl/sink/configuration.py index 591f31fcce..32a1410c43 100644 --- a/dlt/destinations/impl/sink/configuration.py +++ b/dlt/destinations/impl/sink/configuration.py @@ -1,4 +1,5 @@ from typing import TYPE_CHECKING, Optional, Final, Callable, Union, Any +from importlib import import_module from dlt.common.configuration import configspec from dlt.common.destination import TLoaderFileFormat @@ -8,6 +9,7 @@ ) from dlt.common.typing import TDataItems from dlt.common.schema import TTableSchema +from dlt.common.configuration.exceptions import ConfigurationValueError TSinkCallable = Callable[[TDataItems, TTableSchema], None] @@ -15,11 +17,37 @@ @configspec class SinkClientCredentials(CredentialsConfiguration): - callable_name: Optional[str] = None + callable: Optional[str] = None # noqa: A003 def parse_native_representation(self, native_value: Any) -> None: + # a callable was passed in if callable(native_value): - self.callable: TSinkCallable = native_value + self.resolved_callable: TSinkCallable = native_value + # a path to a callable was passed in + if isinstance(native_value, str): + self.callable = native_value + + def to_native_representation(self) -> Any: + return self.resolved_callable + + def on_resolved(self) -> None: + if self.callable: + try: + module_path, attr_name = self.callable.rsplit(".", 1) + dest_module = import_module(module_path) + except ModuleNotFoundError as e: + raise ConfigurationValueError( + f"Could not find callable module at {module_path}" + ) from e + try: + 
self.resolved_callable = getattr(dest_module, attr_name) + except AttributeError as e: + raise ConfigurationValueError( + f"Could not find callable function at {self.callable}" + ) from e + + if not hasattr(self, "resolved_callable"): + raise ConfigurationValueError("Please specify callable for sink destination.") @configspec @@ -34,5 +62,7 @@ class SinkClientConfiguration(DestinationClientConfiguration): def __init__( self, *, - credentials: Union[SinkClientCredentials, TSinkCallable] = None, + credentials: Union[SinkClientCredentials, TSinkCallable, str] = None, + loader_file_format: TLoaderFileFormat = "parquet", + batch_size: int = 10, ) -> None: ... diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py index ed0d1d4686..55dd0d9b0b 100644 --- a/dlt/destinations/impl/sink/sink.py +++ b/dlt/destinations/impl/sink/sink.py @@ -18,9 +18,6 @@ JobClientBase, ) -from dlt.destinations.exceptions import ( - LoadJobNotExistsException, -) from dlt.destinations.impl.sink import capabilities from dlt.destinations.impl.sink.configuration import SinkClientConfiguration, TSinkCallable @@ -54,7 +51,7 @@ def call_callable_with_items(self, items: TDataItems) -> None: if self._config.batch_size == 1: coerced_items = coerced_items[0] - self._config.credentials.callable(coerced_items, self._table) + self._config.credentials.resolved_callable(coerced_items, self._table) def state(self) -> TLoadJobState: return "completed" @@ -136,7 +133,7 @@ def run(self) -> None: class SinkClient(JobClientBase): - """Sink client storing jobs in memory""" + """Sink Client""" capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py index a7aed033d5..151eea6a45 100644 --- a/tests/load/sink/test_sink.py +++ b/tests/load/sink/test_sink.py @@ -7,6 +7,8 @@ from dlt.common.typing import TDataItems from dlt.common.schema import TTableSchema from dlt.common.data_writers.writers import TLoaderFileFormat +from dlt.common.destination.reference import Destination +from dlt.common.configuration.exceptions import ConfigurationValueError from tests.load.utils import ( TABLE_ROW_ALL_DATA_TYPES, @@ -110,3 +112,71 @@ def test_batch_size(loader_file_format: TLoaderFileFormat, batch_size: int) -> N assert len(all_items) == 100 for i in range(100): assert str(i) in all_items + + +global_calls: List[Tuple[TDataItems, TTableSchema]] = [] + + +def global_sink_func(items: TDataItems, table: TTableSchema) -> None: + global global_calls + if table["name"].startswith("_dlt"): + return + global_calls.append((items, table)) + + +def test_instantiation() -> None: + calls: List[Tuple[TDataItems, TTableSchema]] = [] + + def local_sink_func(items: TDataItems, table: TTableSchema) -> None: + nonlocal calls + if table["name"].startswith("_dlt"): + return + calls.append((items, table)) + + # test decorator + calls = [] + p = dlt.pipeline("sink_test", destination=dlt.sink()(local_sink_func), full_refresh=True) + p.run([1, 2, 3], table_name="items") + assert len(calls) == 1 + + # test passing via credentials + calls = [] + p = dlt.pipeline( + "sink_test", destination="sink", credentials=local_sink_func, full_refresh=True + ) + p.run([1, 2, 3], table_name="items") + assert len(calls) == 1 + + # test passing via from_reference + calls = [] + p = dlt.pipeline( + "sink_test", + destination=Destination.from_reference("sink", credentials=local_sink_func), # type: ignore + full_refresh=True, + ) + p.run([1, 2, 3], table_name="items") + assert len(calls) 
== 1 + + # test passing string reference + global global_calls + global_calls = [] + p = dlt.pipeline( + "sink_test", + destination="sink", + credentials="tests.load.sink.test_sink.global_sink_func", + full_refresh=True, + ) + p.run([1, 2, 3], table_name="items") + assert len(global_calls) == 1 + + # pass None credentials reference + p = dlt.pipeline("sink_test", destination="sink", credentials=None, full_refresh=True) + with pytest.raises(ConfigurationValueError): + p.run([1, 2, 3], table_name="items") + + # pass invalid credentials module + p = dlt.pipeline( + "sink_test", destination="sink", credentials="does.not.exist.callable", full_refresh=True + ) + with pytest.raises(ConfigurationValueError): + p.run([1, 2, 3], table_name="items") diff --git a/tests/utils.py b/tests/utils.py index cf172f9733..7090b49b55 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,8 +45,9 @@ "motherduck", "mssql", "qdrant", + "sink", } -NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck", "qdrant"} +NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck", "qdrant", "sink"} SQL_DESTINATIONS = IMPLEMENTED_DESTINATIONS - NON_SQL_DESTINATIONS # exclude destination configs (for now used for athena and athena iceberg separation) @@ -58,6 +59,8 @@ # filter out active destinations for current tests ACTIVE_DESTINATIONS = set(dlt.config.get("ACTIVE_DESTINATIONS", list) or IMPLEMENTED_DESTINATIONS) +ACTIVE_DESTINATIONS = {"sink"} + ACTIVE_SQL_DESTINATIONS = SQL_DESTINATIONS.intersection(ACTIVE_DESTINATIONS) ACTIVE_NON_SQL_DESTINATIONS = NON_SQL_DESTINATIONS.intersection(ACTIVE_DESTINATIONS) From 4d730e8f2d541c8531371bfe1a9e2f68bed06907 Mon Sep 17 00:00:00 2001 From: Dave Date: Mon, 15 Jan 2024 17:14:42 +0100 Subject: [PATCH 07/35] implement sink function progress state --- dlt/destinations/impl/sink/configuration.py | 8 ++ dlt/destinations/impl/sink/sink.py | 82 ++++++++++++++------ tests/load/sink/test_sink.py | 86 ++++++++++++++++++++- 3 files changed, 152 insertions(+), 24 deletions(-) diff --git a/dlt/destinations/impl/sink/configuration.py b/dlt/destinations/impl/sink/configuration.py index 32a1410c43..bdff6cf58d 100644 --- a/dlt/destinations/impl/sink/configuration.py +++ b/dlt/destinations/impl/sink/configuration.py @@ -18,6 +18,8 @@ @configspec class SinkClientCredentials(CredentialsConfiguration): callable: Optional[str] = None # noqa: A003 + # name provides namespace for callable state saving + name: Optional[str] = None def parse_native_representation(self, native_value: Any) -> None: # a callable was passed in @@ -49,6 +51,12 @@ def on_resolved(self) -> None: if not hasattr(self, "resolved_callable"): raise ConfigurationValueError("Please specify callable for sink destination.") + if not callable(self.resolved_callable): + raise ConfigurationValueError("Resolved Sink destination callable is not a callable.") + + if not self.name: + self.name = self.resolved_callable.__name__ + @configspec class SinkClientConfiguration(DestinationClientConfiguration): diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py index 55dd0d9b0b..1ef54ba962 100644 --- a/dlt/destinations/impl/sink/sink.py +++ b/dlt/destinations/impl/sink/sink.py @@ -1,18 +1,17 @@ -import random -from copy import copy +from abc import ABC, abstractmethod from types import TracebackType from typing import ClassVar, Dict, Optional, Sequence, Type, Iterable, List from dlt.destinations.job_impl import EmptyLoadJob from dlt.common.typing import TDataItems +from dlt.common import json 
+from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.typing import TTableSchema from dlt.common.storages import FileStorage from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( - FollowupJob, - NewLoadJob, TLoadJobState, LoadJob, JobClientBase, @@ -22,18 +21,38 @@ from dlt.destinations.impl.sink.configuration import SinkClientConfiguration, TSinkCallable -class SinkLoadJob(LoadJob, FollowupJob): +# TODO: implement proper state storage somewhere, can this somehow go into the loadpackage? +job_execution_storage: Dict[str, int] = {} + + +class SinkLoadJob(LoadJob, ABC): def __init__( - self, table: TTableSchema, file_path: str, config: SinkClientConfiguration, schema: Schema + self, + table: TTableSchema, + file_path: str, + config: SinkClientConfiguration, + schema: Schema, + job_execution_storage: Dict[str, int], ) -> None: super().__init__(FileStorage.get_file_name_from_file_path(file_path)) self._file_path = file_path self._config = config self._table = table self._schema = schema - self.run() - - def run(self) -> None: + self._job_execution_storage = job_execution_storage + + # TODO: is this the correct way to tell dlt to retry this job in the next attempt? + self._state: TLoadJobState = "running" + try: + start_index = self._job_execution_storage.get(self._parsed_file_name.file_id, 0) + self.run(start_index) + self._state = "completed" + except Exception as e: + self._state = "retry" + raise e + + @abstractmethod + def run(self, start_index: int) -> None: pass def call_callable_with_items(self, items: TDataItems) -> None: @@ -51,34 +70,51 @@ def call_callable_with_items(self, items: TDataItems) -> None: if self._config.batch_size == 1: coerced_items = coerced_items[0] + # call callable self._config.credentials.resolved_callable(coerced_items, self._table) + # if there was no exception we assume the callable call was successful and we advance the index + current_index = self._job_execution_storage.get(self._parsed_file_name.file_id, 0) + self._job_execution_storage[self._parsed_file_name.file_id] = current_index + len(items) + def state(self) -> TLoadJobState: - return "completed" + return self._state def exception(self) -> str: raise NotImplementedError() class SinkParquetLoadJob(SinkLoadJob): - def run(self) -> None: + def run(self, start_index: int) -> None: # stream items from dlt.common.libs.pyarrow import pyarrow + # guard against changed batch size after restart of loadjob + assert ( + start_index % self._config.batch_size + ) == 0, "Batch size was changed during processing of one load package" + + start_batch = start_index / self._config.batch_size with pyarrow.parquet.ParquetFile(self._file_path) as reader: for record_batch in reader.iter_batches(batch_size=self._config.batch_size): + if start_batch > 0: + start_batch -= 1 + continue batch = record_batch.to_pylist() self.call_callable_with_items(batch) class SinkJsonlLoadJob(SinkLoadJob): - def run(self) -> None: - from dlt.common import json + def run(self, start_index: int) -> None: + current_batch: TDataItems = [] # stream items with FileStorage.open_zipsafe_ro(self._file_path) as f: - current_batch: TDataItems = [] for line in f: + # find correct start position + if start_index > 0: + start_index -= 1 + continue current_batch.append(json.loads(line)) if len(current_batch) == self._config.batch_size: self.call_callable_with_items(current_batch) @@ -140,6 +176,8 @@ 
class SinkClient(JobClientBase): def __init__(self, schema: Schema, config: SinkClientConfiguration) -> None: super().__init__(schema, config) self.config: SinkClientConfiguration = config + global job_execution_storage + self.job_execution_storage = job_execution_storage def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: pass @@ -157,22 +195,20 @@ def update_stored_schema( def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: if file_path.endswith("parquet"): - return SinkParquetLoadJob(table, file_path, self.config, self.schema) + return SinkParquetLoadJob( + table, file_path, self.config, self.schema, job_execution_storage + ) if file_path.endswith("jsonl"): - return SinkJsonlLoadJob(table, file_path, self.config, self.schema) + return SinkJsonlLoadJob( + table, file_path, self.config, self.schema, job_execution_storage + ) # if file_path.endswith("insert_values"): # return SinkInsertValueslLoadJob(table, file_path, self.config, self.schema) - return EmptyLoadJob.from_file_path(file_path, "completed") + return None def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") - def create_table_chain_completed_followup_jobs( - self, table_chain: Sequence[TTableSchema] - ) -> List[NewLoadJob]: - """Creates a list of followup jobs that should be executed after a table chain is completed""" - return [] - def complete_load(self, load_id: str) -> None: pass diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py index 151eea6a45..2be7af5513 100644 --- a/tests/load/sink/test_sink.py +++ b/tests/load/sink/test_sink.py @@ -1,7 +1,8 @@ -from typing import List, Tuple +from typing import List, Tuple, Dict import dlt import pytest +import pytest from copy import deepcopy from dlt.common.typing import TDataItems @@ -9,6 +10,7 @@ from dlt.common.data_writers.writers import TLoaderFileFormat from dlt.common.destination.reference import Destination from dlt.common.configuration.exceptions import ConfigurationValueError +from dlt.pipeline.exceptions import PipelineStepFailed from tests.load.utils import ( TABLE_ROW_ALL_DATA_TYPES, @@ -180,3 +182,85 @@ def local_sink_func(items: TDataItems, table: TTableSchema) -> None: ) with pytest.raises(ConfigurationValueError): p.run([1, 2, 3], table_name="items") + + +@pytest.mark.parametrize("loader_file_format", ["jsonl"]) +@pytest.mark.parametrize("batch_size", [1, 10, 23]) +def test_batched_transactions(loader_file_format: TLoaderFileFormat, batch_size: int) -> None: + calls: Dict[str, List[TDataItems]] = {} + # provoke errors on resources + provoke_error: Dict[str, int] = {} + + @dlt.sink(loader_file_format=loader_file_format, batch_size=batch_size) + def test_sink(items: TDataItems, table: TTableSchema) -> None: + nonlocal calls + table_name = table["name"] + if table_name.startswith("_dlt"): + return + + # provoke error if configured + if table_name in provoke_error: + for item in items if batch_size > 1 else [items]: + if provoke_error[table_name] == item["id"]: + raise AssertionError("Oh no!") + + calls.setdefault(table_name, []).append(items if batch_size > 1 else [items]) + + @dlt.resource() + def items() -> TDataItems: + for i in range(100): + yield {"id": i, "value": str(i)} + + @dlt.resource() + def items2() -> TDataItems: + for i in range(100): + yield {"id": i, "value": str(i)} + + def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: + """ + Ensure all items where called and no duplicates are 
present + """ + collected_items = set() + for call in c: + for item in call: + assert item["value"] not in collected_items + collected_items.add(item["value"]) + assert len(collected_items) == end - start + for i in range(start, end): + assert str(i) in collected_items + + # no errors are set, all items should be processed + p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True) + p.run([items(), items2()]) + assert_items_in_range(calls["items"], 0, 100) + assert_items_in_range(calls["items2"], 0, 100) + + # provoke errors + calls = {} + provoke_error = {"items": 25, "items2": 45} + p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True) + with pytest.raises(PipelineStepFailed): + p.run([items(), items2()]) + + # partly loaded + if batch_size == 1: + assert_items_in_range(calls["items"], 0, 25) + assert_items_in_range(calls["items2"], 0, 45) + elif batch_size == 10: + assert_items_in_range(calls["items"], 0, 20) + assert_items_in_range(calls["items2"], 0, 40) + elif batch_size == 23: + assert_items_in_range(calls["items"], 0, 23) + assert_items_in_range(calls["items2"], 0, 23) + else: + raise AssertionError("Unknown batch size") + + # load the rest + first_calls = deepcopy(calls) + provoke_error = {} + calls = {} + p.load() + + # both calls combined should have every item called just once + assert_items_in_range(calls["items"] + first_calls["items"], 0, 100) + assert_items_in_range(calls["items2"] + first_calls["items2"], 0, 100) From 3b577b0296b95c8c0fbb185ba8c3bdd80c13a0b3 Mon Sep 17 00:00:00 2001 From: Dave Date: Mon, 15 Jan 2024 17:52:12 +0100 Subject: [PATCH 08/35] move to iterator --- dlt/destinations/impl/sink/sink.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py index 1ef54ba962..a89e14f2cf 100644 --- a/dlt/destinations/impl/sink/sink.py +++ b/dlt/destinations/impl/sink/sink.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from types import TracebackType -from typing import ClassVar, Dict, Optional, Sequence, Type, Iterable, List +from typing import ClassVar, Dict, Optional, Sequence, Type, Iterable, Iterable from dlt.destinations.job_impl import EmptyLoadJob from dlt.common.typing import TDataItems @@ -44,15 +44,18 @@ def __init__( # TODO: is this the correct way to tell dlt to retry this job in the next attempt? 
self._state: TLoadJobState = "running" try: - start_index = self._job_execution_storage.get(self._parsed_file_name.file_id, 0) - self.run(start_index) + current_index = self._job_execution_storage.get(self._parsed_file_name.file_id, 0) + for batch in self.run(current_index): + self.call_callable_with_items(batch) + current_index += len(batch) + self._job_execution_storage[self._parsed_file_name.file_id] = current_index self._state = "completed" except Exception as e: self._state = "retry" raise e @abstractmethod - def run(self, start_index: int) -> None: + def run(self, start_index: int) -> Iterable[TDataItems]: pass def call_callable_with_items(self, items: TDataItems) -> None: @@ -73,10 +76,6 @@ def call_callable_with_items(self, items: TDataItems) -> None: # call callable self._config.credentials.resolved_callable(coerced_items, self._table) - # if there was no exception we assume the callable call was successful and we advance the index - current_index = self._job_execution_storage.get(self._parsed_file_name.file_id, 0) - self._job_execution_storage[self._parsed_file_name.file_id] = current_index + len(items) - def state(self) -> TLoadJobState: return self._state @@ -85,7 +84,7 @@ def exception(self) -> str: class SinkParquetLoadJob(SinkLoadJob): - def run(self, start_index: int) -> None: + def run(self, start_index: int) -> Iterable[TDataItems]: # stream items from dlt.common.libs.pyarrow import pyarrow @@ -101,11 +100,11 @@ def run(self, start_index: int) -> None: start_batch -= 1 continue batch = record_batch.to_pylist() - self.call_callable_with_items(batch) + yield batch class SinkJsonlLoadJob(SinkLoadJob): - def run(self, start_index: int) -> None: + def run(self, start_index: int) -> Iterable[TDataItems]: current_batch: TDataItems = [] # stream items @@ -117,9 +116,9 @@ def run(self, start_index: int) -> None: continue current_batch.append(json.loads(line)) if len(current_batch) == self._config.batch_size: - self.call_callable_with_items(current_batch) + yield current_batch current_batch = [] - self.call_callable_with_items(current_batch) + yield current_batch # class SinkInsertValueslLoadJob(SinkLoadJob): From 5527689cd51182b9faa82193fe47df7ca71f0c70 Mon Sep 17 00:00:00 2001 From: Dave Date: Thu, 25 Jan 2024 20:06:57 +0100 Subject: [PATCH 09/35] persist sink load state in pipeline state --- .../specs/config_section_context.py | 2 + dlt/common/pipeline.py | 51 +++++++++++- dlt/destinations/impl/sink/sink.py | 83 ++++--------------- dlt/load/load.py | 16 +++- dlt/pipeline/pipeline.py | 20 ++++- tests/load/sink/test_sink.py | 15 +++- 6 files changed, 112 insertions(+), 75 deletions(-) diff --git a/dlt/common/configuration/specs/config_section_context.py b/dlt/common/configuration/specs/config_section_context.py index a656a2b0fe..b4d0bc7731 100644 --- a/dlt/common/configuration/specs/config_section_context.py +++ b/dlt/common/configuration/specs/config_section_context.py @@ -12,6 +12,7 @@ class ConfigSectionContext(ContainerInjectableContext): sections: Tuple[str, ...] = () merge_style: TMergeFunc = None source_state_key: str = None + destination_state_key: str = None def merge(self, existing: "ConfigSectionContext") -> None: """Merges existing context into incoming using a merge style function""" @@ -79,4 +80,5 @@ def __init__( sections: Tuple[str, ...] = (), merge_style: TMergeFunc = None, source_state_key: str = None, + destination_state_key: str = None, ) -> None: ... 
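Together with the sink changes above, the new destination_state_key turns the sink's progress tracking into a resumable cursor: each load job records how many items it has already handed to the user callable, keyed by the job's file id, and (in the following diff) that record moves from a module-level dict into the pipeline state, namespaced per destination and per load_id. A plain-Python sketch of the resume behaviour, with illustrative names rather than the actual classes:

from typing import Callable, Dict, Iterable, List


def run_job_with_resume(
    file_id: str,
    batches: Iterable[List[dict]],
    load_state: Dict[str, int],  # persisted per load_id in pipeline state
    sink_callable: Callable[[List[dict]], None],
) -> None:
    # Progress is only recorded after a whole batch succeeds, so on a retry
    # the stored index always lands on a batch boundary and full batches are skipped.
    current_index = load_state.get(file_id, 0)
    to_skip = current_index
    for batch in batches:
        if to_skip >= len(batch):
            to_skip -= len(batch)
            continue
        sink_callable(batch)  # may raise -> the job is marked for retry
        current_index += len(batch)
        load_state[file_id] = current_index  # survives into the next attempt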
diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 6b7b308b44..1971b83410 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -470,6 +470,7 @@ class TPipelineState(TypedDict, total=False): """A section of state that is not synchronized with the destination and does not participate in change merging and version control""" sources: NotRequired[Dict[str, Dict[str, Any]]] + destinations: NotRequired[Dict[str, Dict[str, Any]]] class TSourceState(TPipelineState): @@ -594,9 +595,13 @@ class StateInjectableContext(ContainerInjectableContext): can_create_default: ClassVar[bool] = False + commit: Optional[Callable[[], None]] = None + if TYPE_CHECKING: - def __init__(self, state: TPipelineState = None) -> None: ... + def __init__( + self, state: TPipelineState = None, commit: Optional[Callable[[], None]] = None + ) -> None: ... def pipeline_state( @@ -679,6 +684,50 @@ def source_state() -> DictStrAny: _last_full_state: TPipelineState = None +def destination_state() -> DictStrAny: + container = Container() + + # get the destination name from the section context + destination_state_key: str = None + with contextlib.suppress(ContextDefaultCannotBeCreated): + sections_context = container[ConfigSectionContext] + destination_state_key = sections_context.destination_state_key + + if not destination_state_key: + raise SourceSectionNotAvailable() + + state, _ = pipeline_state(Container()) + + destination_state: DictStrAny = state.setdefault("destinations", {}).setdefault( + destination_state_key, {} + ) + return destination_state + + +def reset_destination_state() -> None: + container = Container() + + # get the destination name from the section context + destination_state_key: str = None + with contextlib.suppress(ContextDefaultCannotBeCreated): + sections_context = container[ConfigSectionContext] + destination_state_key = sections_context.destination_state_key + + if not destination_state_key: + raise SourceSectionNotAvailable() + + state, _ = pipeline_state(Container()) + + state.setdefault("destinations", {}).pop(destination_state_key) + + +def commit_pipeline_state() -> None: + container = Container() + # get injected state if present. injected state is typically "managed" so changes will be persisted + state_ctx = container[StateInjectableContext] + state_ctx.commit() + + def _delete_source_state_keys( key: TAnyJsonPath, source_state_: Optional[DictStrAny] = None, / ) -> None: diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py index a89e14f2cf..b9e0904e80 100644 --- a/dlt/destinations/impl/sink/sink.py +++ b/dlt/destinations/impl/sink/sink.py @@ -5,8 +5,10 @@ from dlt.destinations.job_impl import EmptyLoadJob from dlt.common.typing import TDataItems from dlt.common import json +from dlt.common.configuration.container import Container +from dlt.common.pipeline import StateInjectableContext +from dlt.common.pipeline import destination_state, reset_destination_state, commit_pipeline_state -from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.typing import TTableSchema from dlt.common.storages import FileStorage @@ -21,10 +23,6 @@ from dlt.destinations.impl.sink.configuration import SinkClientConfiguration, TSinkCallable -# TODO: implement proper state storage somewhere, can this somehow go into the loadpackage? 
-job_execution_storage: Dict[str, int] = {} - - class SinkLoadJob(LoadJob, ABC): def __init__( self, @@ -32,27 +30,28 @@ def __init__( file_path: str, config: SinkClientConfiguration, schema: Schema, - job_execution_storage: Dict[str, int], + load_state: Dict[str, int], ) -> None: super().__init__(FileStorage.get_file_name_from_file_path(file_path)) self._file_path = file_path self._config = config self._table = table self._schema = schema - self._job_execution_storage = job_execution_storage - # TODO: is this the correct way to tell dlt to retry this job in the next attempt? self._state: TLoadJobState = "running" try: - current_index = self._job_execution_storage.get(self._parsed_file_name.file_id, 0) + current_index = load_state.get(self._parsed_file_name.file_id, 0) for batch in self.run(current_index): self.call_callable_with_items(batch) current_index += len(batch) - self._job_execution_storage[self._parsed_file_name.file_id] = current_index + load_state[self._parsed_file_name.file_id] = current_index self._state = "completed" except Exception as e: self._state = "retry" raise e + finally: + # save progress + commit_pipeline_state() @abstractmethod def run(self, start_index: int) -> Iterable[TDataItems]: @@ -121,52 +120,6 @@ def run(self, start_index: int) -> Iterable[TDataItems]: yield current_batch -# class SinkInsertValueslLoadJob(SinkLoadJob): -# def run(self) -> None: -# from dlt.common import json - -# # stream items -# with FileStorage.open_zipsafe_ro(self._file_path) as f: -# header = f.readline().strip() -# values_mark = f.readline() - -# # properly formatted file has a values marker at the beginning -# assert values_mark == "VALUES\n" - -# # extract column names -# assert header.startswith("INSERT INTO") and header.endswith(")") -# header = header[15:-1] -# column_names = header.split(",") - -# # build batches -# current_batch: TDataItems = [] -# current_row: str = "" -# for line in f: -# current_row += line -# if line.endswith(");"): -# current_row = current_row[1:-2] -# elif line.endswith("),\n"): -# current_row = current_row[1:-3] -# else: -# continue - -# values = current_row.split(",") -# values = [None if v == "NULL" else v for v in values] -# current_row = "" -# print(values) -# print(current_row) - -# # zip and send to callable -# current_batch.append(dict(zip(column_names, values))) -# d = dict(zip(column_names, values)) -# print(json.dumps(d, pretty=True)) -# if len(current_batch) == self._config.batch_size: -# self.call_callable_with_items(current_batch) -# current_batch = [] - -# self.call_callable_with_items(current_batch) - - class SinkClient(JobClientBase): """Sink Client""" @@ -175,8 +128,6 @@ class SinkClient(JobClientBase): def __init__(self, schema: Schema, config: SinkClientConfiguration) -> None: super().__init__(schema, config) self.config: SinkClientConfiguration = config - global job_execution_storage - self.job_execution_storage = job_execution_storage def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: pass @@ -193,23 +144,21 @@ def update_stored_schema( return super().update_stored_schema(only_tables, expected_update) def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + load_state = destination_state().setdefault(load_id, {}) if file_path.endswith("parquet"): - return SinkParquetLoadJob( - table, file_path, self.config, self.schema, job_execution_storage - ) + return SinkParquetLoadJob(table, file_path, self.config, self.schema, load_state) if file_path.endswith("jsonl"): - return 
SinkJsonlLoadJob( - table, file_path, self.config, self.schema, job_execution_storage - ) - # if file_path.endswith("insert_values"): - # return SinkInsertValueslLoadJob(table, file_path, self.config, self.schema) + return SinkJsonlLoadJob(table, file_path, self.config, self.schema, load_state) return None def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") def complete_load(self, load_id: str) -> None: - pass + # pop all state for this load on success + state = destination_state() + state.pop(load_id, None) + commit_pipeline_state() def __enter__(self) -> "SinkClient": return self diff --git a/dlt/load/load.py b/dlt/load/load.py index b0b52d61d6..beb52cc296 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -7,6 +7,7 @@ from dlt.common import sleep, logger from dlt.common.configuration import with_config, known_sections +from dlt.common.configuration.resolve import inject_section from dlt.common.configuration.accessors import config from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo from dlt.common.schema.utils import get_child_tables, get_top_level_table @@ -35,6 +36,7 @@ SupportsStagingDestination, TDestination, ) +from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.destinations.job_impl import EmptyLoadJob @@ -558,10 +560,16 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: # get top load id and mark as being processed with self.collector(f"Load {schema.name} in {load_id}"): - # the same load id may be processed across multiple runs - if not self.current_load_id: - self._step_info_start_load_id(load_id) - self.load_single_package(load_id, schema) + with inject_section( + ConfigSectionContext( + sections=(known_sections.LOAD,), + destination_state_key=self.destination.destination_name, + ) + ): + # the same load id may be processed across multiple runs + if not self.current_load_id: + self._step_info_start_load_id(load_id) + self.load_single_package(load_id, schema) return TRunMetrics(False, len(self.load_storage.list_normalized_packages())) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 73c8f076d1..613e4794ce 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -143,11 +143,20 @@ def decorator(f: TFun) -> TFun: def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: # activate pipeline so right state is always provided self.activate() + # backup and restore state should_extract_state = may_extract_state and self.config.restore_from_destination with self.managed_state(extract_state=should_extract_state) as state: + # commit hook + def commit_state() -> None: + # save the state + bump_version_if_modified(state) + self._save_state(state) + # add the state to container as a context - with self._container.injectable_context(StateInjectableContext(state=state)): + with self._container.injectable_context( + StateInjectableContext(state=state, commit=commit_state) + ): return f(self, *args, **kwargs) return _wrap # type: ignore @@ -246,7 +255,14 @@ class Pipeline(SupportsPipeline): STATE_FILE: ClassVar[str] = "state.json" STATE_PROPS: ClassVar[List[str]] = list( set(get_type_hints(TPipelineState).keys()) - - {"sources", "destination_type", "destination_name", "staging_type", "staging_name"} + - { + "sources", + "destination_type", + "destination_name", + "staging_type", + "staging_name", + "destinations", + } ) LOCAL_STATE_PROPS: ClassVar[List[str]] = 
list(get_type_hints(TPipelineLocalState).keys()) DEFAULT_DATASET_SUFFIX: ClassVar[str] = "_dataset" diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py index 2be7af5513..d96fb2d9eb 100644 --- a/tests/load/sink/test_sink.py +++ b/tests/load/sink/test_sink.py @@ -234,6 +234,8 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: p.run([items(), items2()]) assert_items_in_range(calls["items"], 0, 100) assert_items_in_range(calls["items2"], 0, 100) + # destination state should be cleared after load + assert p.state["destinations"]["sink"] == {} # provoke errors calls = {} @@ -242,16 +244,25 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: with pytest.raises(PipelineStepFailed): p.run([items(), items2()]) - # partly loaded + # we should have data for one load id saved here + assert len(p.state["destinations"]["sink"]) == 1 + # get saved indexes + values = list(list(p.state["destinations"]["sink"].values())[0].values()) + + # partly loaded, pointers in state should be right if batch_size == 1: assert_items_in_range(calls["items"], 0, 25) assert_items_in_range(calls["items2"], 0, 45) + # one pointer for state, one for items, one for items2... + assert values == [1, 25, 45] elif batch_size == 10: assert_items_in_range(calls["items"], 0, 20) assert_items_in_range(calls["items2"], 0, 40) + assert values == [1, 20, 40] elif batch_size == 23: assert_items_in_range(calls["items"], 0, 23) assert_items_in_range(calls["items2"], 0, 23) + assert values == [1, 23, 23] else: raise AssertionError("Unknown batch size") @@ -260,6 +271,8 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: provoke_error = {} calls = {} p.load() + # state should be cleared again + assert p.state["destinations"]["sink"] == {} # both calls combined should have every item called just once assert_items_in_range(calls["items"] + first_calls["items"], 0, 100) From 0657034dd7876c9bfb4ad81a68281cb253c54949 Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 30 Jan 2024 13:52:24 +0100 Subject: [PATCH 10/35] fix unrelated typo --- docs/website/docs/dlt-ecosystem/verified-sources/pipedrive.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/pipedrive.md b/docs/website/docs/dlt-ecosystem/verified-sources/pipedrive.md index 7e26999f3c..8336fe850a 100644 --- a/docs/website/docs/dlt-ecosystem/verified-sources/pipedrive.md +++ b/docs/website/docs/dlt-ecosystem/verified-sources/pipedrive.md @@ -265,7 +265,7 @@ verified source. 
```python load_data = pipedrive_source() # calls the source function - load_info = pipeline.run(load_info) #runs the pipeline with selected source configuration + load_info = pipeline.run(load_data) #runs the pipeline with selected source configuration print(load_info) ``` From b5db5b841735ae7a2e4b2ad01d02dd9b5637968e Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 30 Jan 2024 15:12:14 +0100 Subject: [PATCH 11/35] move sink state storage to loadpackage state --- dlt/common/pipeline.py | 63 +++++++++++------------------ dlt/common/storages/load_package.py | 30 ++++++++++++++ dlt/common/storages/load_storage.py | 8 ++++ dlt/destinations/impl/sink/sink.py | 25 ++++++------ dlt/load/load.py | 35 ++++++++++++---- dlt/pipeline/pipeline.py | 16 +++----- docs/website/docs/intro-snippets.py | 11 ++--- tests/load/sink/test_sink.py | 22 +++++----- 8 files changed, 124 insertions(+), 86 deletions(-) diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 1971b83410..a4d54eeb2a 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -595,12 +595,21 @@ class StateInjectableContext(ContainerInjectableContext): can_create_default: ClassVar[bool] = False - commit: Optional[Callable[[], None]] = None + if TYPE_CHECKING: + + def __init__(self, state: TPipelineState = None) -> None: ... + + +@configspec +class LoadPackageStateInjectableContext(ContainerInjectableContext): + state: DictStrAny + commit: Optional[Callable[[], None]] + can_create_default: ClassVar[bool] = False if TYPE_CHECKING: def __init__( - self, state: TPipelineState = None, commit: Optional[Callable[[], None]] = None + self, state: DictStrAny = None, commit: Optional[Callable[[], None]] = None ) -> None: ... @@ -684,47 +693,23 @@ def source_state() -> DictStrAny: _last_full_state: TPipelineState = None -def destination_state() -> DictStrAny: +def load_package_state() -> DictStrAny: container = Container() - - # get the destination name from the section context - destination_state_key: str = None - with contextlib.suppress(ContextDefaultCannotBeCreated): - sections_context = container[ConfigSectionContext] - destination_state_key = sections_context.destination_state_key - - if not destination_state_key: - raise SourceSectionNotAvailable() - - state, _ = pipeline_state(Container()) - - destination_state: DictStrAny = state.setdefault("destinations", {}).setdefault( - destination_state_key, {} - ) - return destination_state - - -def reset_destination_state() -> None: - container = Container() - - # get the destination name from the section context - destination_state_key: str = None - with contextlib.suppress(ContextDefaultCannotBeCreated): - sections_context = container[ConfigSectionContext] - destination_state_key = sections_context.destination_state_key - - if not destination_state_key: - raise SourceSectionNotAvailable() - - state, _ = pipeline_state(Container()) - - state.setdefault("destinations", {}).pop(destination_state_key) + # get injected state if present. injected load package state is typically "managed" so changes will be persisted + # if you need to save the load package state during a load, you need to call commit_load_package_state + try: + state_ctx = container[LoadPackageStateInjectableContext] + except ContextDefaultCannotBeCreated: + raise Exception("Load package state not available") + return state_ctx.state -def commit_pipeline_state() -> None: +def commit_load_package_state() -> None: container = Container() - # get injected state if present. 
injected state is typically "managed" so changes will be persisted - state_ctx = container[StateInjectableContext] + try: + state_ctx = container[LoadPackageStateInjectableContext] + except ContextDefaultCannotBeCreated: + raise Exception("Load package state not available") state_ctx.commit() diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index 2860364cd0..ebd7feaa45 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -16,6 +16,7 @@ Set, get_args, cast, + Any, ) from dlt.common import pendulum, json @@ -200,6 +201,9 @@ class PackageStorage: PACKAGE_COMPLETED_FILE_NAME = ( # completed package marker file, currently only to store data with os.stat "package_completed.json" ) + LOAD_PACKAGE_STATE_FILE_NAME = ( # internal state of the load package, will not be synced to the destination + "load_package_state.json" + ) def __init__(self, storage: FileStorage, initial_state: TLoadPackageState) -> None: """Creates storage that manages load packages with root at `storage` and initial package state `initial_state`""" @@ -335,6 +339,8 @@ def create_package(self, load_id: str) -> None: self.storage.create_folder(os.path.join(load_id, PackageStorage.COMPLETED_JOBS_FOLDER)) self.storage.create_folder(os.path.join(load_id, PackageStorage.FAILED_JOBS_FOLDER)) self.storage.create_folder(os.path.join(load_id, PackageStorage.STARTED_JOBS_FOLDER)) + # create new (and empty) state + self.save_load_package_state(load_id, {}) def complete_loading_package(self, load_id: str, load_state: TLoadPackageState) -> str: """Completes loading the package by writing marker file with`package_state. Returns path to the completed package""" @@ -380,6 +386,30 @@ def save_schema_updates(self, load_id: str, schema_update: TSchemaTables) -> Non ) as f: json.dump(schema_update, f) + # + # Loadpackage state + # + def get_load_package_state(self, load_id: str) -> DictStrAny: + package_path = self.get_package_path(load_id) + if not self.storage.has_folder(package_path): + raise LoadPackageNotFound(load_id) + try: + state = self.storage.load( + os.path.join(package_path, PackageStorage.LOAD_PACKAGE_STATE_FILE_NAME) + ) + return cast(DictStrAny, json.loads(state)) + except FileNotFoundError: + return {} + + def save_load_package_state(self, load_id: str, state: DictStrAny) -> None: + package_path = self.get_package_path(load_id) + if not self.storage.has_folder(package_path): + raise LoadPackageNotFound(load_id) + self.storage.save( + os.path.join(package_path, PackageStorage.LOAD_PACKAGE_STATE_FILE_NAME), + json.dumps(state), + ) + # # Get package info # diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index a83502cb9b..926d13f732 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -1,6 +1,7 @@ from os.path import join from typing import Iterable, Optional, Sequence +from dlt.common.typing import DictStrAny from dlt.common import json from dlt.common.configuration import known_sections from dlt.common.configuration.inject import with_config @@ -184,3 +185,10 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: return self.loaded_packages.get_load_package_info(load_id) except LoadPackageNotFound: return self.normalized_packages.get_load_package_info(load_id) + + def get_load_package_state(self, load_id: str) -> DictStrAny: + """Gets state of normlized or loaded package with given load_id, all jobs and their statuses.""" + try: + return 
self.loaded_packages.get_load_package_state(load_id) + except LoadPackageNotFound: + return self.normalized_packages.get_load_package_state(load_id) diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py index b9e0904e80..21ad09f1f5 100644 --- a/dlt/destinations/impl/sink/sink.py +++ b/dlt/destinations/impl/sink/sink.py @@ -1,13 +1,11 @@ from abc import ABC, abstractmethod from types import TracebackType -from typing import ClassVar, Dict, Optional, Sequence, Type, Iterable, Iterable +from typing import ClassVar, Dict, Optional, Type, Iterable, Iterable from dlt.destinations.job_impl import EmptyLoadJob from dlt.common.typing import TDataItems from dlt.common import json -from dlt.common.configuration.container import Container -from dlt.common.pipeline import StateInjectableContext -from dlt.common.pipeline import destination_state, reset_destination_state, commit_pipeline_state +from dlt.common.pipeline import load_package_state, commit_load_package_state from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.typing import TTableSchema @@ -30,7 +28,7 @@ def __init__( file_path: str, config: SinkClientConfiguration, schema: Schema, - load_state: Dict[str, int], + load_package_state: Dict[str, int], ) -> None: super().__init__(FileStorage.get_file_name_from_file_path(file_path)) self._file_path = file_path @@ -39,19 +37,21 @@ def __init__( self._schema = schema self._state: TLoadJobState = "running" + self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}" try: - current_index = load_state.get(self._parsed_file_name.file_id, 0) + current_index = load_package_state.get(self._storage_id, 0) for batch in self.run(current_index): self.call_callable_with_items(batch) current_index += len(batch) - load_state[self._parsed_file_name.file_id] = current_index + load_package_state[self._storage_id] = current_index + self._state = "completed" except Exception as e: self._state = "retry" raise e finally: # save progress - commit_pipeline_state() + commit_load_package_state() @abstractmethod def run(self, start_index: int) -> Iterable[TDataItems]: @@ -144,7 +144,8 @@ def update_stored_schema( return super().update_stored_schema(only_tables, expected_update) def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - load_state = destination_state().setdefault(load_id, {}) + # save our state in destination name scope + load_state = load_package_state().setdefault(self.config.destination_name, {}) if file_path.endswith("parquet"): return SinkParquetLoadJob(table, file_path, self.config, self.schema, load_state) if file_path.endswith("jsonl"): @@ -156,9 +157,9 @@ def restore_file_load(self, file_path: str) -> LoadJob: def complete_load(self, load_id: str) -> None: # pop all state for this load on success - state = destination_state() - state.pop(load_id, None) - commit_pipeline_state() + state = load_package_state() + state.pop(self.config.destination_name, None) + commit_load_package_state() def __enter__(self) -> "SinkClient": return self diff --git a/dlt/load/load.py b/dlt/load/load.py index beb52cc296..024e42c745 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -1,4 +1,4 @@ -import contextlib +import contextlib, threading from functools import reduce import datetime # noqa: 251 from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable, Callable @@ -9,7 +9,13 @@ from dlt.common.configuration import with_config, known_sections from 
dlt.common.configuration.resolve import inject_section from dlt.common.configuration.accessors import config -from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo +from dlt.common.pipeline import ( + LoadInfo, + LoadMetrics, + SupportsPipeline, + WithStepInfo, + LoadPackageStateInjectableContext, +) from dlt.common.schema.utils import get_child_tables, get_top_level_table from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState from dlt.common.runners import TRunMetrics, Runnable, workermethod, NullExecutor @@ -20,8 +26,10 @@ DestinationTerminalException, DestinationTransientException, ) +from dlt.common.configuration.container import Container + from dlt.common.schema import Schema, TSchemaTables -from dlt.common.schema.typing import TTableSchema, TWriteDisposition +from dlt.common.schema.typing import TTableSchema from dlt.common.storages import LoadStorage from dlt.common.destination.reference import ( DestinationClientDwhConfiguration, @@ -520,7 +528,7 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: failed_job.job_file_info.job_id(), failed_job.failed_message, ) - # possibly raise on too many retires + # possibly raise on too many retries if self.config.raise_on_max_retries: for new_job in package_info.jobs["new_jobs"]: r_c = new_job.job_file_info.retry_count @@ -558,12 +566,23 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: schema = self.load_storage.normalized_packages.load_schema(load_id) logger.info(f"Loaded schema name {schema.name} and version {schema.stored_version}") + # prepare load package state context + load_package_state = self.load_storage.normalized_packages.get_load_package_state(load_id) + state_save_lock = threading.Lock() + + def commit_load_package_state() -> None: + with state_save_lock: + self.load_storage.normalized_packages.save_load_package_state( + load_id, load_package_state + ) + + container = Container() # get top load id and mark as being processed with self.collector(f"Load {schema.name} in {load_id}"): - with inject_section( - ConfigSectionContext( - sections=(known_sections.LOAD,), - destination_state_key=self.destination.destination_name, + with container.injectable_context( + LoadPackageStateInjectableContext( + state=load_package_state, + commit=commit_load_package_state, ) ): # the same load id may be processed across multiple runs diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 613e4794ce..5ccb518fe7 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -47,7 +47,7 @@ ) from dlt.common.schema.utils import normalize_schema_name from dlt.common.storages.exceptions import LoadPackageNotFound -from dlt.common.typing import DictStrStr, TFun, TSecretValue, is_optional_type +from dlt.common.typing import DictStrAny, TFun, TSecretValue, is_optional_type from dlt.common.runners import pool_runner as runner from dlt.common.storages import ( LiveSchemaStorage, @@ -147,16 +147,8 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: # backup and restore state should_extract_state = may_extract_state and self.config.restore_from_destination with self.managed_state(extract_state=should_extract_state) as state: - # commit hook - def commit_state() -> None: - # save the state - bump_version_if_modified(state) - self._save_state(state) - # add the state to container as a context - with self._container.injectable_context( - StateInjectableContext(state=state, commit=commit_state) - ): + with 
self._container.injectable_context(StateInjectableContext(state=state)): return f(self, *args, **kwargs) return _wrap # type: ignore @@ -840,6 +832,10 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: except LoadPackageNotFound: return self._get_normalize_storage().extracted_packages.get_load_package_info(load_id) + def get_load_package_state(self, load_id: str) -> DictStrAny: + """Returns information on extracted/normalized/completed package with given load_id, all jobs and their statuses.""" + return self._get_load_storage().get_load_package_state(load_id) + def list_failed_jobs_in_package(self, load_id: str) -> Sequence[LoadJobInfo]: """List all failed jobs and associated error messages for a specified `load_id`""" return self._get_load_storage().get_load_package_info(load_id).jobs.get("failed_jobs", []) diff --git a/docs/website/docs/intro-snippets.py b/docs/website/docs/intro-snippets.py index 340a6ff262..f270dcee6e 100644 --- a/docs/website/docs/intro-snippets.py +++ b/docs/website/docs/intro-snippets.py @@ -18,14 +18,13 @@ def intro_snippet() -> None: response.raise_for_status() data.append(response.json()) # Extract, normalize, and load the data - load_info = pipeline.run(data, table_name='player') + load_info = pipeline.run(data, table_name="player") # @@@DLT_SNIPPET_END api assert_load_info(load_info) def csv_snippet() -> None: - # @@@DLT_SNIPPET_START csv import dlt import pandas as pd @@ -50,8 +49,8 @@ def csv_snippet() -> None: assert_load_info(load_info) -def db_snippet() -> None: +def db_snippet() -> None: # @@@DLT_SNIPPET_START db import dlt from sqlalchemy import create_engine @@ -74,13 +73,9 @@ def db_snippet() -> None: ) # Convert the rows into dictionaries on the fly with a map function - load_info = pipeline.run( - map(lambda row: dict(row._mapping), rows), - table_name="genome" - ) + load_info = pipeline.run(map(lambda row: dict(row._mapping), rows), table_name="genome") print(load_info) # @@@DLT_SNIPPET_END db assert_load_info(load_info) - diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py index d96fb2d9eb..a82b7baef6 100644 --- a/tests/load/sink/test_sink.py +++ b/tests/load/sink/test_sink.py @@ -231,11 +231,11 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: # no errors are set, all items should be processed p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True) - p.run([items(), items2()]) + load_id = p.run([items(), items2()]).loads_ids[0] assert_items_in_range(calls["items"], 0, 100) assert_items_in_range(calls["items2"], 0, 100) # destination state should be cleared after load - assert p.state["destinations"]["sink"] == {} + assert p.get_load_package_state(load_id) == {} # provoke errors calls = {} @@ -245,24 +245,27 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: p.run([items(), items2()]) # we should have data for one load id saved here - assert len(p.state["destinations"]["sink"]) == 1 - # get saved indexes - values = list(list(p.state["destinations"]["sink"].values())[0].values()) + load_id = p.list_normalized_load_packages()[0] + load_package_state = p.get_load_package_state(load_id) + + assert len(load_package_state) == 1 + # get saved indexes mapped to table (this test will only work for one job per table) + values = {k.split(".")[0]: v for k, v in list(load_package_state.values())[0].items()} # partly loaded, pointers in state should be right if batch_size == 1: assert_items_in_range(calls["items"], 0, 25) 
assert_items_in_range(calls["items2"], 0, 45) # one pointer for state, one for items, one for items2... - assert values == [1, 25, 45] + assert values == {"_dlt_pipeline_state": 1, "items": 25, "items2": 45} elif batch_size == 10: assert_items_in_range(calls["items"], 0, 20) assert_items_in_range(calls["items2"], 0, 40) - assert values == [1, 20, 40] + assert values == {"_dlt_pipeline_state": 1, "items": 20, "items2": 40} elif batch_size == 23: assert_items_in_range(calls["items"], 0, 23) assert_items_in_range(calls["items2"], 0, 23) - assert values == [1, 23, 23] + assert values == {"_dlt_pipeline_state": 1, "items": 23, "items2": 23} else: raise AssertionError("Unknown batch size") @@ -272,7 +275,8 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: calls = {} p.load() # state should be cleared again - assert p.state["destinations"]["sink"] == {} + load_package_state = p.get_load_package_state(load_id) + assert load_package_state == {} # both calls combined should have every item called just once assert_items_in_range(calls["items"] + first_calls["items"], 0, 100) From 189d24bb084b3c619fb04961a3460ef8cac47594 Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 30 Jan 2024 15:33:12 +0100 Subject: [PATCH 12/35] additional pr fixes --- .../configuration/specs/config_section_context.py | 2 -- dlt/common/data_types/type_helpers.py | 2 +- dlt/destinations/decorators.py | 12 ++++++++++-- dlt/destinations/impl/sink/__init__.py | 5 ++++- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/dlt/common/configuration/specs/config_section_context.py b/dlt/common/configuration/specs/config_section_context.py index b4d0bc7731..a656a2b0fe 100644 --- a/dlt/common/configuration/specs/config_section_context.py +++ b/dlt/common/configuration/specs/config_section_context.py @@ -12,7 +12,6 @@ class ConfigSectionContext(ContainerInjectableContext): sections: Tuple[str, ...] = () merge_style: TMergeFunc = None source_state_key: str = None - destination_state_key: str = None def merge(self, existing: "ConfigSectionContext") -> None: """Merges existing context into incoming using a merge style function""" @@ -80,5 +79,4 @@ def __init__( sections: Tuple[str, ...] = (), merge_style: TMergeFunc = None, source_state_key: str = None, - destination_state_key: str = None, ) -> None: ... 
diff --git a/dlt/common/data_types/type_helpers.py b/dlt/common/data_types/type_helpers.py index 9b26207cf1..800fa8a680 100644 --- a/dlt/common/data_types/type_helpers.py +++ b/dlt/common/data_types/type_helpers.py @@ -111,7 +111,7 @@ def coerce_value(to_type: TDataType, from_type: TDataType, value: Any) -> Any: try: return json.loads(value) except Exception: - pass + raise ValueError("Cannot load text as json for complex type") if to_type == "text": if from_type == "complex": diff --git a/dlt/destinations/decorators.py b/dlt/destinations/decorators.py index a99690cff7..a21e8eaca8 100644 --- a/dlt/destinations/decorators.py +++ b/dlt/destinations/decorators.py @@ -3,10 +3,18 @@ from dlt.destinations.impl.sink.configuration import SinkClientConfiguration, TSinkCallable from dlt.common.destination import TDestinationReferenceArg from dlt.common.destination import TLoaderFileFormat +from dlt.common.utils import get_callable_name -def sink(loader_file_format: TLoaderFileFormat = None, batch_size: int = 10) -> Any: +def sink( + loader_file_format: TLoaderFileFormat = None, batch_size: int = 10, name: str = None +) -> Any: def decorator(f: TSinkCallable) -> TDestinationReferenceArg: - return _sink(credentials=f, loader_file_format=loader_file_format, batch_size=batch_size) + nonlocal name + if name is None: + name = get_callable_name(f) + return _sink( + credentials=f, loader_file_format=loader_file_format, batch_size=batch_size, name=name + ) return decorator diff --git a/dlt/destinations/impl/sink/__init__.py b/dlt/destinations/impl/sink/__init__.py index a5f7bd268b..2902fb8b03 100644 --- a/dlt/destinations/impl/sink/__init__.py +++ b/dlt/destinations/impl/sink/__init__.py @@ -5,4 +5,7 @@ def capabilities( preferred_loader_file_format: TLoaderFileFormat = "parquet", ) -> DestinationCapabilitiesContext: - return DestinationCapabilitiesContext.generic_capabilities(preferred_loader_file_format) + caps = DestinationCapabilitiesContext.generic_capabilities(preferred_loader_file_format) + caps.supports_ddl_transactions = False + caps.supports_transactions = False + return caps From 57ed0900921197a73c7cd3c1ee71c4ca8ba3a96f Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 30 Jan 2024 16:06:39 +0100 Subject: [PATCH 13/35] disable creating empty state file on loadpackage init --- dlt/common/storages/load_package.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index ebd7feaa45..ea8b3280f9 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -340,7 +340,7 @@ def create_package(self, load_id: str) -> None: self.storage.create_folder(os.path.join(load_id, PackageStorage.FAILED_JOBS_FOLDER)) self.storage.create_folder(os.path.join(load_id, PackageStorage.STARTED_JOBS_FOLDER)) # create new (and empty) state - self.save_load_package_state(load_id, {}) + # self.save_load_package_state(load_id, {}) def complete_loading_package(self, load_id: str, load_state: TLoadPackageState) -> str: """Completes loading the package by writing marker file with`package_state. 
Returns path to the completed package""" From 4f53bc41d807e3a943320fff00beb350812705f9 Mon Sep 17 00:00:00 2001 From: Dave Date: Wed, 31 Jan 2024 17:09:02 +0100 Subject: [PATCH 14/35] add sink docs page --- .../docs/dlt-ecosystem/destinations/sink.md | 109 ++++++++++++++++++ docs/website/sidebars.js | 1 + 2 files changed, 110 insertions(+) create mode 100644 docs/website/docs/dlt-ecosystem/destinations/sink.md diff --git a/docs/website/docs/dlt-ecosystem/destinations/sink.md b/docs/website/docs/dlt-ecosystem/destinations/sink.md new file mode 100644 index 0000000000..161132b902 --- /dev/null +++ b/docs/website/docs/dlt-ecosystem/destinations/sink.md @@ -0,0 +1,109 @@ +--- +title: Sink / Reverse ETL +description: Sink function `dlt` destination for reverse ETL +keywords: [reverse etl, sink, function] +--- + +# Sink function / Reverse ETL + +## Install dlt for Sink / reverse ETL +** To install the DLT without additional dependencies ** +``` +pip install dlt +``` + +## Setup Guide +### 1. Initialize the dlt project + +Let's start by initializing a new dlt project as follows: + +```bash +dlt init chess sink +``` +> 💡 This command will initialize your pipeline with chess as the source and Sink as the destination. + +The above command generates several files and directories, including `.dlt/secrets.toml`. + +### 2. Set up a sink function for your pipeline +The sink destination differs from other destinations in that you do not need to provide connection credentials, but rather you provide a function which +gets called for all items loaded during a pipeline run or load operation. For the chess example, you can add the following lines at the top of the file. +With the @dlt.sink decorator you can convert any function that takes two arguments into a dlt destination. + +```python +from dlt.common.typing import TDataItems +from dlt.common.schema import TTableSchema + +@dlt.sink(batch_size=10) +def sink(items: TDataItems, table: TTableSchema) -> None: + print(table["name"]) + print(items) +``` + +To enable this sink destination in your chess example, replace the line `destination='sink'` with `destination=sink` (without the quotes) to directly reference +the sink from your pipeline constructor. Now you can run your pipeline and see the output of all the items coming from the chess pipeline to your console. + +:::tip +1. You can also remove the typing information (TDataItems and TTableSchema) from this example, typing generally are useful to know the shape of the incoming objects though. +2. There are a few other ways for declaring sink functions for your pipeline described below. +::: + +## Sink decorator function and signature + +The full signature of the sink decorator and a function is + +```python +@dlt.sink(batch_size=10, loader_file_format="jsonl", name="my_sink") +def sink(items: TDataItems, table: TTableSchema) -> None: + ... +``` + +#### Decorator +* The `batch_size` parameter on the sink decorator defines how many items per function call are batched together and sent as an array. If batch_size is set to one, +there will be one item without an array per call. +* The `loader_file_format` parameter on the sink decorator defines in which format files are stored in the load package before being sent to the sink function, +this can be `jsonl` or `parquet`. +* The `name` parameter on the sink decorator defines the name of the destination that get's created by the sink decorator. + +#### Sink function +* The `items` parameter on the sink function contains the items being sent into the sink function. 
+* The `table` parameter contains the schema table the current call belongs to including all table hints and columns. For example the table name can be access with table["name"]. Keep in mind that dlt also created special tables prefixed with `_dlt` which you may want to ignore when processing data. + +## Sink destination state +The sink destination keeps a local record of how many DataItems were processed, so if you, for example, use the sink destination to push DataItems to a remote api, and this +api becomes unavailable during the load resulting in a failed dlt pipeline run, you can repeat the run of your pipeline at a later stage and the sink destination will continue +where it left of. For this reason it makes sense to choose a batch size that you can process in one transaction (say one api request or one database transaction) so that if this +request or transaction fail repeatedly you can repeat it at the next run without pushing duplicate data to your remote location. + +## Concurrency +Calls to the sink function by default will be executed on multiple threads, so you need to make sure you are not using any non-thread-safe nonlocal or global variables from outside +your sink function. If, for whichever reason, you need to have all calls be executed from the same thread, you can set the `workers` config variable of the load step to 1. For performance +reasons we recommend to keep the multithreaded approach and make sure that you, for example, are using threadsafe connection pools to a remote database or queue. + +## Referencing the sink function +There are multiple ways to reference the sink function you want to use. These are: + +```python +# file my_pipeline.py + +@dlt.sink(batch_size=10) +def local_sink_func(items: TDataItems, table: TTableSchema) -> None: + ... + +# reference function directly +p = dlt.pipeline(name="my_pipe", destination=local_sink_func) + +# fully qualified string to function location (can be used from config.toml or env vars) +p = dlt.pipeline(name="my_pipe", destination="sink", credentials="my_pipeline.local_sink_func") + +# via destination reference +p = dlt.pipeline(name="my_pipe", destination=Destination.from_reference("sink", credentials=local_sink_func, environment="staging")) +``` + +## Write disposition + +The sink destination will forward all normalized DataItems encountered during a pipeline run to the sink function, so there is no notion of write dispositions for the sink. + +## Staging support + +The sink destination does not support staging files in remote locations before being called at this time. If you need this feature, please let us know. 
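To complement the Concurrency section above, a hedged sketch (not part of the committed docs) of a sink function that follows its advice: the shared counter is the only non-thread-safe piece, so only it is guarded with a lock while the calls themselves stay parallel. The decorator, typing imports, and `table["name"]` access are the ones the docs use; the counter and lock are illustrative.

```python
import threading
from typing import Dict

import dlt
from dlt.common.typing import TDataItems
from dlt.common.schema import TTableSchema

_counter_lock = threading.Lock()
_rows_per_table: Dict[str, int] = {}


@dlt.sink(batch_size=10)
def counting_sink(items: TDataItems, table: TTableSchema) -> None:
    # the sink may be called from several load worker threads at once,
    # so only the mutation of shared state is serialized
    with _counter_lock:
        _rows_per_table[table["name"]] = _rows_per_table.get(table["name"], 0) + len(items)


pipeline = dlt.pipeline("concurrency_demo", destination=counting_sink)
```

If per-call isolation is not possible (for example a client library that is not thread-safe), the docs' alternative of setting the load `workers` config to 1 avoids the issue at the cost of parallelism.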
+ diff --git a/docs/website/sidebars.js b/docs/website/sidebars.js index 2e23c5ca45..b24cb7b944 100644 --- a/docs/website/sidebars.js +++ b/docs/website/sidebars.js @@ -96,6 +96,7 @@ const sidebars = { 'dlt-ecosystem/destinations/motherduck', 'dlt-ecosystem/destinations/weaviate', 'dlt-ecosystem/destinations/qdrant', + 'dlt-ecosystem/destinations/sink', ] }, ], From c6c06ba7f7a715568059f167314fc7fc44e6513a Mon Sep 17 00:00:00 2001 From: Dave Date: Thu, 1 Feb 2024 20:39:46 +0100 Subject: [PATCH 15/35] small changes --- dlt/common/pipeline.py | 1 - dlt/destinations/impl/sink/configuration.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index a4d54eeb2a..5d9c62152b 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -470,7 +470,6 @@ class TPipelineState(TypedDict, total=False): """A section of state that is not synchronized with the destination and does not participate in change merging and version control""" sources: NotRequired[Dict[str, Dict[str, Any]]] - destinations: NotRequired[Dict[str, Dict[str, Any]]] class TSourceState(TPipelineState): diff --git a/dlt/destinations/impl/sink/configuration.py b/dlt/destinations/impl/sink/configuration.py index bdff6cf58d..bb9caab294 100644 --- a/dlt/destinations/impl/sink/configuration.py +++ b/dlt/destinations/impl/sink/configuration.py @@ -62,7 +62,7 @@ def on_resolved(self) -> None: class SinkClientConfiguration(DestinationClientConfiguration): destination_type: Final[str] = "sink" # type: ignore credentials: SinkClientCredentials = None - loader_file_format: TLoaderFileFormat = "parquet" + loader_file_format: TLoaderFileFormat = "jsonl" batch_size: int = 10 if TYPE_CHECKING: @@ -71,6 +71,6 @@ def __init__( self, *, credentials: Union[SinkClientCredentials, TSinkCallable, str] = None, - loader_file_format: TLoaderFileFormat = "parquet", + loader_file_format: TLoaderFileFormat = "jsonl", batch_size: int = 10, ) -> None: ... 
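Patch 15 above settles the configuration defaults (`loader_file_format="jsonl"`, `batch_size=10`), and patch 12 earlier added the `name` parameter to the decorator. A hedged usage sketch of how these knobs combine when declaring a sink; the function body and table name are illustrative, while the parameters and their accepted values ("jsonl" or "parquet") come from this series:

```python
import dlt
from dlt.common.typing import TDataItems
from dlt.common.schema import TTableSchema


@dlt.sink(loader_file_format="parquet", batch_size=25, name="my_warehouse_sink")
def my_warehouse_sink(items: TDataItems, table: TTableSchema) -> None:
    # receives up to 25 items per call, read back from parquet load files
    print(f"{table['name']}: {len(items)} items")


pipeline = dlt.pipeline("sink_config_demo", destination=my_warehouse_sink)
pipeline.run([{"value": i} for i in range(100)], table_name="numbers")
```

With `loader_file_format` left unset, the configuration default from patch 15 (`jsonl`) applies; `name` controls the name of the destination created by the decorator, per the docs added in patch 14.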
From 374b267fdf91ef20db341de7eda55efbe304081a Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 6 Feb 2024 15:27:06 +0100 Subject: [PATCH 16/35] make loadstorage state versioned and separate out common base functions --- dlt/common/pipeline.py | 13 +- dlt/common/storages/load_package.py | 82 +++++++++++-- dlt/common/storages/load_storage.py | 3 +- dlt/common/versioned_state.py | 45 +++++++ dlt/destinations/__init__.py | 2 +- dlt/destinations/impl/athena/athena.py | 16 ++- dlt/destinations/impl/bigquery/bigquery.py | 10 +- .../impl/databricks/databricks.py | 28 +++-- dlt/destinations/impl/sink/sink.py | 8 +- dlt/destinations/impl/snowflake/snowflake.py | 6 +- dlt/helpers/streamlit_helper.py | 4 +- dlt/pipeline/pipeline.py | 37 +++--- dlt/pipeline/state_sync.py | 113 +++++++++--------- docs/examples/chess_production/chess.py | 12 +- docs/examples/connector_x_arrow/load_arrow.py | 2 + docs/examples/google_sheets/google_sheets.py | 5 +- docs/examples/incremental_loading/zendesk.py | 8 +- docs/examples/nested_data/nested_data.py | 2 + .../pdf_to_weaviate/pdf_to_weaviate.py | 5 +- docs/examples/qdrant_zendesk/qdrant.py | 9 +- docs/examples/transformers/pokemon.py | 4 +- tests/load/pipeline/test_drop.py | 4 +- tests/load/pipeline/test_restore_state.py | 22 ++-- tests/load/sink/test_sink.py | 7 +- tests/pipeline/test_pipeline.py | 4 +- tests/pipeline/test_pipeline_state.py | 47 ++++---- 26 files changed, 317 insertions(+), 181 deletions(-) create mode 100644 dlt/common/versioned_state.py diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 5d9c62152b..0d53a4f3a8 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -45,6 +45,8 @@ from dlt.common.jsonpath import delete_matches, TAnyJsonPath from dlt.common.data_writers.writers import DataWriterMetrics, TLoaderFileFormat from dlt.common.utils import RowCounts, merge_row_counts +from dlt.common.versioned_state import TVersionedState +from dlt.common.storages.load_package import TLoadPackageState class _StepInfo(NamedTuple): @@ -448,7 +450,7 @@ class TPipelineLocalState(TypedDict, total=False): """Hash of state that was recently synced with destination""" -class TPipelineState(TypedDict, total=False): +class TPipelineState(TVersionedState, total=False): """Schema for a pipeline state that is stored within the pipeline working directory""" pipeline_name: str @@ -463,9 +465,6 @@ class TPipelineState(TypedDict, total=False): staging_type: Optional[str] # properties starting with _ are not automatically applied to pipeline object when state is restored - _state_version: int - _version_hash: str - _state_engine_version: int _local: TPipelineLocalState """A section of state that is not synchronized with the destination and does not participate in change merging and version control""" @@ -601,14 +600,14 @@ def __init__(self, state: TPipelineState = None) -> None: ... @configspec class LoadPackageStateInjectableContext(ContainerInjectableContext): - state: DictStrAny + state: TLoadPackageState commit: Optional[Callable[[], None]] can_create_default: ClassVar[bool] = False if TYPE_CHECKING: def __init__( - self, state: DictStrAny = None, commit: Optional[Callable[[], None]] = None + self, state: TLoadPackageState = None, commit: Optional[Callable[[], None]] = None ) -> None: ... @@ -692,7 +691,7 @@ def source_state() -> DictStrAny: _last_full_state: TPipelineState = None -def load_package_state() -> DictStrAny: +def load_package_state() -> TLoadPackageState: container = Container() # get injected state if present. 
injected load package state is typically "managed" so changes will be persisted # if you need to save the load package state during a load, you need to call commit_load_package_state diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index ea8b3280f9..b83eae0dc4 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -17,6 +17,7 @@ get_args, cast, Any, + Tuple, ) from dlt.common import pendulum, json @@ -27,13 +28,63 @@ from dlt.common.schema.typing import TStoredSchema, TTableSchemaColumns from dlt.common.storages import FileStorage from dlt.common.storages.exceptions import LoadPackageNotFound -from dlt.common.typing import DictStrAny, StrAny, SupportsHumanize +from dlt.common.typing import DictStrAny, SupportsHumanize from dlt.common.utils import flatten_list_or_items +from dlt.common.versioned_state import ( + generate_state_version_hash, + bump_state_version_if_modified, + TVersionedState, + default_versioned_state, +) +from typing_extensions import NotRequired + + +class TLoadPackageState(TVersionedState, total=False): + created: int + """Timestamp when the loadpackage was created""" + + """A section of state that does not participate in change merging and version control""" + destinations: NotRequired[Dict[str, Dict[str, Any]]] + """private space for destinations to store state relevant only to the load package""" + + +# allows to upgrade state when restored with a new version of state logic/schema +LOADPACKAGE_STATE_ENGINE_VERSION = 1 + + +def generate_loadpackage_state_version_hash(state: TLoadPackageState) -> str: + return generate_state_version_hash(state) + + +def bump_loadpackage_state_version_if_modified(state: TLoadPackageState) -> Tuple[int, str, str]: + return bump_state_version_if_modified(state) + + +def migrate_loadpackage_state( + state: DictStrAny, from_engine: int, to_engine: int +) -> TLoadPackageState: + if from_engine == to_engine: + return cast(TLoadPackageState, state) + + # check state engine + if from_engine != to_engine: + raise Exception("No upgrade path for loadpackage state") + + state["_state_engine_version"] = from_engine + return cast(TLoadPackageState, state) + + +def default_loadpackage_state() -> TLoadPackageState: + return { + **default_versioned_state(), + "_state_engine_version": LOADPACKAGE_STATE_ENGINE_VERSION, + } + # folders to manage load jobs in a single load package TJobState = Literal["new_jobs", "failed_jobs", "started_jobs", "completed_jobs"] WORKING_FOLDERS: Set[TJobState] = set(get_args(TJobState)) -TLoadPackageState = Literal["new", "extracted", "normalized", "loaded", "aborted"] +TLoadPackageStatus = Literal["new", "extracted", "normalized", "loaded", "aborted"] class ParsedLoadJobFileName(NamedTuple): @@ -125,7 +176,7 @@ def __str__(self) -> str: class _LoadPackageInfo(NamedTuple): load_id: str package_path: str - state: TLoadPackageState + state: TLoadPackageStatus schema: Schema schema_update: TSchemaTables completed_at: datetime.datetime @@ -205,7 +256,7 @@ class PackageStorage: "load_package_state.json" ) - def __init__(self, storage: FileStorage, initial_state: TLoadPackageState) -> None: + def __init__(self, storage: FileStorage, initial_state: TLoadPackageStatus) -> None: """Creates storage that manages load packages with root at `storage` and initial package state `initial_state`""" self.storage = storage self.initial_state = initial_state @@ -339,10 +390,13 @@ def create_package(self, load_id: str) -> None: 
self.storage.create_folder(os.path.join(load_id, PackageStorage.COMPLETED_JOBS_FOLDER)) self.storage.create_folder(os.path.join(load_id, PackageStorage.FAILED_JOBS_FOLDER)) self.storage.create_folder(os.path.join(load_id, PackageStorage.STARTED_JOBS_FOLDER)) - # create new (and empty) state - # self.save_load_package_state(load_id, {}) + # ensure created timestamp is set in state when load package is created + state = self.get_load_package_state(load_id) + if not state.get("created"): + state["created"] = pendulum.now().timestamp() + self.save_load_package_state(load_id, state) - def complete_loading_package(self, load_id: str, load_state: TLoadPackageState) -> str: + def complete_loading_package(self, load_id: str, load_state: TLoadPackageStatus) -> str: """Completes loading the package by writing marker file with`package_state. Returns path to the completed package""" load_path = self.get_package_path(load_id) # save marker file @@ -389,22 +443,26 @@ def save_schema_updates(self, load_id: str, schema_update: TSchemaTables) -> Non # # Loadpackage state # - def get_load_package_state(self, load_id: str) -> DictStrAny: + def get_load_package_state(self, load_id: str) -> TLoadPackageState: package_path = self.get_package_path(load_id) if not self.storage.has_folder(package_path): raise LoadPackageNotFound(load_id) try: - state = self.storage.load( + state_dump = self.storage.load( os.path.join(package_path, PackageStorage.LOAD_PACKAGE_STATE_FILE_NAME) ) - return cast(DictStrAny, json.loads(state)) + state = json.loads(state_dump) + return migrate_loadpackage_state( + state, state["_state_engine_version"], LOADPACKAGE_STATE_ENGINE_VERSION + ) except FileNotFoundError: - return {} + return default_loadpackage_state() - def save_load_package_state(self, load_id: str, state: DictStrAny) -> None: + def save_load_package_state(self, load_id: str, state: TLoadPackageState) -> None: package_path = self.get_package_path(load_id) if not self.storage.has_folder(package_path): raise LoadPackageNotFound(load_id) + bump_loadpackage_state_version_if_modified(state) self.storage.save( os.path.join(package_path, PackageStorage.LOAD_PACKAGE_STATE_FILE_NAME), json.dumps(state), diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index 926d13f732..080a30b5a6 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -19,6 +19,7 @@ PackageStorage, ParsedLoadJobFileName, TJobState, + TLoadPackageState, ) from dlt.common.storages.exceptions import JobWithUnsupportedWriterException, LoadPackageNotFound @@ -186,7 +187,7 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: except LoadPackageNotFound: return self.normalized_packages.get_load_package_info(load_id) - def get_load_package_state(self, load_id: str) -> DictStrAny: + def get_load_package_state(self, load_id: str) -> TLoadPackageState: """Gets state of normlized or loaded package with given load_id, all jobs and their statuses.""" try: return self.loaded_packages.get_load_package_state(load_id) diff --git a/dlt/common/versioned_state.py b/dlt/common/versioned_state.py new file mode 100644 index 0000000000..a051a6660c --- /dev/null +++ b/dlt/common/versioned_state.py @@ -0,0 +1,45 @@ +import base64 +import hashlib +from copy import copy + +import datetime # noqa: 251 +from dlt.common import json +from typing import TypedDict, Dict, Any, List, Tuple, cast + + +class TVersionedState(TypedDict, total=False): + _state_version: int + _version_hash: str + _state_engine_version: 
int + + +def generate_state_version_hash(state: TVersionedState, exclude_attrs: List[str] = None) -> str: + # generates hash out of stored schema content, excluding hash itself, version and local state + state_copy = copy(state) + exclude_attrs = exclude_attrs or [] + exclude_attrs.extend(["_state_version", "_state_engine_version", "_version_hash"]) + for attr in exclude_attrs: + state_copy.pop(attr, None) # type: ignore + content = json.typed_dumpb(state_copy, sort_keys=True) # type: ignore + h = hashlib.sha3_256(content) + return base64.b64encode(h.digest()).decode("ascii") + + +def bump_state_version_if_modified( + state: TVersionedState, exclude_attrs: List[str] = None +) -> Tuple[int, str, str]: + """Bumps the `state` version and version hash if content modified, returns (new version, new hash, old hash) tuple""" + hash_ = generate_state_version_hash(state, exclude_attrs) + previous_hash = state.get("_version_hash") + if not previous_hash: + # if hash was not set, set it without bumping the version, that's the initial state + pass + elif hash_ != previous_hash: + state["_state_version"] += 1 + + state["_version_hash"] = hash_ + return state["_state_version"], hash_, previous_hash + + +def default_versioned_state() -> TVersionedState: + return {"_state_version": 0, "_state_engine_version": 1} diff --git a/dlt/destinations/__init__.py b/dlt/destinations/__init__.py index f6637eada9..4502b362d0 100644 --- a/dlt/destinations/__init__.py +++ b/dlt/destinations/__init__.py @@ -30,5 +30,5 @@ "weaviate", "synapse", "databricks", - "sink" + "sink", ] diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index 91525d771c..96e7818d57 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -351,7 +351,9 @@ def _from_db_type( return self.type_mapper.from_db_type(hive_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: - return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}" + return ( + f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}" + ) def _get_table_update_sql( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool @@ -376,19 +378,15 @@ def _get_table_update_sql( # use qualified table names qualified_table_name = self.sql_client.make_qualified_ddl_table_name(table_name) if is_iceberg and not generate_alter: - sql.append( - f"""CREATE TABLE {qualified_table_name} + sql.append(f"""CREATE TABLE {qualified_table_name} ({columns}) LOCATION '{location}' - TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');""" - ) + TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');""") elif not generate_alter: - sql.append( - f"""CREATE EXTERNAL TABLE {qualified_table_name} + sql.append(f"""CREATE EXTERNAL TABLE {qualified_table_name} ({columns}) STORED AS PARQUET - LOCATION '{location}';""" - ) + LOCATION '{location}';""") # alter table to add new columns at the end else: sql.append(f"""ALTER TABLE {qualified_table_name} ADD COLUMNS ({columns});""") diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 1058b1d2c9..16b5d82c61 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -252,9 +252,9 @@ def _get_table_update_sql( elif (c := partition_list[0])["data_type"] == "date": sql[0] = f"{sql[0]}\nPARTITION BY 
{self.capabilities.escape_identifier(c['name'])}" elif (c := partition_list[0])["data_type"] == "timestamp": - sql[ - 0 - ] = f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})" + sql[0] = ( + f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})" + ) # Automatic partitioning of an INT64 type requires us to be prescriptive - we treat the column as a UNIX timestamp. # This is due to the bounds requirement of GENERATE_ARRAY function for partitioning. # The 10,000 partitions limit makes it infeasible to cover the entire `bigint` range. @@ -272,7 +272,9 @@ def _get_table_update_sql( def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: name = self.capabilities.escape_identifier(c["name"]) - return f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}" + return ( + f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}" + ) def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: schema_table: TTableSchemaColumns = {} diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index b5a404302f..07e827cd28 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -166,12 +166,14 @@ def __init__( else: raise LoadJobTerminalException( file_path, - f"Databricks cannot load data from staging bucket {bucket_path}. Only s3 and azure buckets are supported", + f"Databricks cannot load data from staging bucket {bucket_path}. Only s3 and" + " azure buckets are supported", ) else: raise LoadJobTerminalException( file_path, - "Cannot load from local file. Databricks does not support loading from local files. Configure staging with an s3 or azure storage bucket.", + "Cannot load from local file. Databricks does not support loading from local files." + " Configure staging with an s3 or azure storage bucket.", ) # decide on source format, stage_file_path will either be a local file or a bucket path @@ -181,27 +183,33 @@ def __init__( if not config.get("data_writer.disable_compression"): raise LoadJobTerminalException( file_path, - "Databricks loader does not support gzip compressed JSON files. Please disable compression in the data writer configuration: https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression", + "Databricks loader does not support gzip compressed JSON files. Please disable" + " compression in the data writer configuration:" + " https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression", ) if table_schema_has_type(table, "decimal"): raise LoadJobTerminalException( file_path, - "Databricks loader cannot load DECIMAL type columns from json files. Switch to parquet format to load decimals.", + "Databricks loader cannot load DECIMAL type columns from json files. Switch to" + " parquet format to load decimals.", ) if table_schema_has_type(table, "binary"): raise LoadJobTerminalException( file_path, - "Databricks loader cannot load BINARY type columns from json files. Switch to parquet format to load byte values.", + "Databricks loader cannot load BINARY type columns from json files. Switch to" + " parquet format to load byte values.", ) if table_schema_has_type(table, "complex"): raise LoadJobTerminalException( file_path, - "Databricks loader cannot load complex columns (lists and dicts) from json files. 
Switch to parquet format to load complex types.", + "Databricks loader cannot load complex columns (lists and dicts) from json" + " files. Switch to parquet format to load complex types.", ) if table_schema_has_type(table, "date"): raise LoadJobTerminalException( file_path, - "Databricks loader cannot load DATE type columns from json files. Switch to parquet format to load dates.", + "Databricks loader cannot load DATE type columns from json files. Switch to" + " parquet format to load dates.", ) source_format = "JSON" @@ -311,7 +319,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non def _get_storage_table_query_columns(self) -> List[str]: fields = super()._get_storage_table_query_columns() - fields[ - 1 - ] = "full_data_type" # Override because this is the only way to get data type with precision + fields[1] = ( # Override because this is the only way to get data type with precision + "full_data_type" + ) return fields diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py index 21ad09f1f5..3b3396d3f6 100644 --- a/dlt/destinations/impl/sink/sink.py +++ b/dlt/destinations/impl/sink/sink.py @@ -145,7 +145,11 @@ def update_stored_schema( def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: # save our state in destination name scope - load_state = load_package_state().setdefault(self.config.destination_name, {}) + load_state = ( + load_package_state() + .setdefault("destinations", {}) + .setdefault(self.config.destination_name, {}) + ) if file_path.endswith("parquet"): return SinkParquetLoadJob(table, file_path, self.config, self.schema, load_state) if file_path.endswith("jsonl"): @@ -158,7 +162,7 @@ def restore_file_load(self, file_path: str) -> LoadJob: def complete_load(self, load_id: str) -> None: # pop all state for this load on success state = load_package_state() - state.pop(self.config.destination_name, None) + state["destinations"].pop(self.config.destination_name, None) commit_load_package_state() def __enter__(self) -> "SinkClient": diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index fb51ab9d36..7fafbf83b7 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -175,15 +175,13 @@ def __init__( f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,' " AUTO_COMPRESS = FALSE" ) - client.execute_sql( - f"""COPY INTO {qualified_table_name} + client.execute_sql(f"""COPY INTO {qualified_table_name} {from_clause} {files_clause} {credentials_clause} FILE_FORMAT = {source_format} MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE' - """ - ) + """) if stage_file_path and not keep_staged_files: client.execute_sql(f"REMOVE {stage_file_path}") diff --git a/dlt/helpers/streamlit_helper.py b/dlt/helpers/streamlit_helper.py index d3e194b18d..2af5ef12a8 100644 --- a/dlt/helpers/streamlit_helper.py +++ b/dlt/helpers/streamlit_helper.py @@ -12,7 +12,7 @@ from dlt.common.libs.pandas import pandas as pd from dlt.pipeline import Pipeline from dlt.pipeline.exceptions import CannotRestorePipelineException, SqlClientNotAvailable -from dlt.pipeline.state_sync import load_state_from_destination +from dlt.pipeline.state_sync import load_pipeline_state_from_destination try: import streamlit as st @@ -190,7 +190,7 @@ def _query_data_live(query: str, schema_name: str = None) -> pd.DataFrame: st.header("Pipeline state info") with pipeline.destination_client() as client: if isinstance(client, 
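[Editor's note] For context, a sketch of the pattern the sink change above relies on: a destination nests its own entries under a "destinations" key of the load package state, so other destinations stay untouched and complete_load() can drop just its own entry. The destination name "sink" and the checkpoint key and value are illustrative, and the snippet must run inside a load, where the state context has been injected:

from dlt.common.pipeline import load_package_state, commit_load_package_state

scope = load_package_state().setdefault("destinations", {}).setdefault("sink", {})
scope["items.1234.jsonl"] = 100        # e.g. index of the last item flushed by a job
commit_load_package_state()            # persist so a retried load can resume
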
WithStateSync): - remote_state = load_state_from_destination(pipeline.pipeline_name, client) + remote_state = load_pipeline_state_from_destination(pipeline.pipeline_name, client) local_state = pipeline.state col1, col2 = st.columns(2) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 06d8aab537..53b043dcc2 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -126,15 +126,17 @@ ) from dlt.pipeline.typing import TPipelineStep from dlt.pipeline.state_sync import ( - STATE_ENGINE_VERSION, - bump_version_if_modified, - load_state_from_destination, - migrate_state, + PIPELINE_STATE_ENGINE_VERSION, + bump_pipeline_state_version_if_modified, + load_pipeline_state_from_destination, + migrate_pipeline_state, state_resource, json_encode_state, json_decode_state, + default_pipeline_state, ) from dlt.pipeline.warnings import credentials_argument_deprecated +from dlt.common.storages.load_package import TLoadPackageState def with_state_sync(may_extract_state: bool = False) -> Callable[[TFun], TFun]: @@ -732,7 +734,7 @@ def sync_destination( # write the state back self._props_to_state(state) - bump_version_if_modified(state) + bump_pipeline_state_version_if_modified(state) self._save_state(state) except Exception as ex: raise PipelineStepFailed(self, "sync", None, ex, None) from ex @@ -832,7 +834,7 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: except LoadPackageNotFound: return self._get_normalize_storage().extracted_packages.get_load_package_info(load_id) - def get_load_package_state(self, load_id: str) -> DictStrAny: + def get_load_package_state(self, load_id: str) -> TLoadPackageState: """Returns information on extracted/normalized/completed package with given load_id, all jobs and their statuses.""" return self._get_load_storage().get_load_package_state(load_id) @@ -1175,9 +1177,9 @@ def _set_context(self, is_active: bool) -> None: # set destination context on activation if self.destination: # inject capabilities context - self._container[ - DestinationCapabilitiesContext - ] = self._get_destination_capabilities() + self._container[DestinationCapabilitiesContext] = ( + self._get_destination_capabilities() + ) else: # remove destination context on deactivation if DestinationCapabilitiesContext in self._container: @@ -1343,16 +1345,15 @@ def _get_step_info(self, step: WithStepInfo[TStepMetrics, TStepInfo]) -> TStepIn def _get_state(self) -> TPipelineState: try: state = json_decode_state(self._pipeline_storage.load(Pipeline.STATE_FILE)) - return migrate_state( - self.pipeline_name, state, state["_state_engine_version"], STATE_ENGINE_VERSION + return migrate_pipeline_state( + self.pipeline_name, + state, + state["_state_engine_version"], + PIPELINE_STATE_ENGINE_VERSION, ) except FileNotFoundError: # do not set the state hash, this will happen on first merge - return { - "_state_version": 0, - "_state_engine_version": STATE_ENGINE_VERSION, - "_local": {"first_run": True}, - } + return default_pipeline_state() # state["_version_hash"] = generate_version_hash(state) # return state @@ -1382,7 +1383,7 @@ def _restore_state_from_destination(self) -> Optional[TPipelineState]: schema = Schema(schema_name) with self._get_destination_clients(schema)[0] as job_client: if isinstance(job_client, WithStateSync): - state = load_state_from_destination(self.pipeline_name, job_client) + state = load_pipeline_state_from_destination(self.pipeline_name, job_client) if state is None: logger.info( "The state was not found in the destination" @@ -1516,7 +1517,7 @@ 
def _bump_version_and_extract_state( Storage will be created on demand. In that case the extracted package will be immediately committed. """ - _, hash_, _ = bump_version_if_modified(self._props_to_state(state)) + _, hash_, _ = bump_pipeline_state_version_if_modified(self._props_to_state(state)) should_extract = hash_ != state["_local"].get("_last_extracted_hash") if should_extract and extract_state: data = state_resource(state) diff --git a/dlt/pipeline/state_sync.py b/dlt/pipeline/state_sync.py index fa3939969b..8c72a218a4 100644 --- a/dlt/pipeline/state_sync.py +++ b/dlt/pipeline/state_sync.py @@ -1,25 +1,28 @@ -import base64 import binascii from copy import copy -import hashlib -from typing import Tuple, cast +from typing import Tuple, cast, List import pendulum import dlt from dlt.common import json -from dlt.common.pipeline import TPipelineState from dlt.common.typing import DictStrAny from dlt.common.schema.typing import STATE_TABLE_NAME, TTableSchemaColumns from dlt.common.destination.reference import WithStateSync, Destination from dlt.common.utils import compressed_b64decode, compressed_b64encode +from dlt.common.versioned_state import ( + generate_state_version_hash, + bump_state_version_if_modified, + default_versioned_state, +) +from dlt.common.pipeline import TPipelineState from dlt.extract import DltResource -from dlt.pipeline.exceptions import PipelineStateEngineNoUpgradePathException +from dlt.pipeline.exceptions import ( + PipelineStateEngineNoUpgradePathException, +) - -# allows to upgrade state when restored with a new version of state logic/schema -STATE_ENGINE_VERSION = 4 +PIPELINE_STATE_ENGINE_VERSION = 4 # state table columns STATE_TABLE_COLUMNS: TTableSchemaColumns = { @@ -57,59 +60,15 @@ def decompress_state(state_str: str) -> DictStrAny: return json.typed_loadb(state_bytes) # type: ignore[no-any-return] -def generate_version_hash(state: TPipelineState) -> str: - # generates hash out of stored schema content, excluding hash itself, version and local state - state_copy = copy(state) - state_copy.pop("_state_version", None) - state_copy.pop("_state_engine_version", None) - state_copy.pop("_version_hash", None) - state_copy.pop("_local", None) - content = json.typed_dumpb(state_copy, sort_keys=True) - h = hashlib.sha3_256(content) - return base64.b64encode(h.digest()).decode("ascii") - +def generate_pipeline_state_version_hash(state: TPipelineState) -> str: + return generate_state_version_hash(state, exclude_attrs=["_local"]) -def bump_version_if_modified(state: TPipelineState) -> Tuple[int, str, str]: - """Bumps the `state` version and version hash if content modified, returns (new version, new hash, old hash) tuple""" - hash_ = generate_version_hash(state) - previous_hash = state.get("_version_hash") - if not previous_hash: - # if hash was not set, set it without bumping the version, that's initial schema - pass - elif hash_ != previous_hash: - state["_state_version"] += 1 - state["_version_hash"] = hash_ - return state["_state_version"], hash_, previous_hash +def bump_pipeline_state_version_if_modified(state: TPipelineState) -> Tuple[int, str, str]: + return bump_state_version_if_modified(state, exclude_attrs=["_local"]) -def state_resource(state: TPipelineState) -> DltResource: - state = copy(state) - state.pop("_local") - state_str = compress_state(state) - state_doc = { - "version": state["_state_version"], - "engine_version": state["_state_engine_version"], - "pipeline_name": state["pipeline_name"], - "state": state_str, - "created_at": pendulum.now(), - 
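[Editor's note] The exclude_attrs=["_local"] argument above is what keeps purely local bookkeeping from producing a new version hash and therefore from triggering a state extract; a small sketch of the expected behaviour, using only names introduced in this patch:

from dlt.common.pendulum import pendulum
from dlt.pipeline.state_sync import (
    bump_pipeline_state_version_if_modified,
    default_pipeline_state,
)

state = default_pipeline_state()
bump_pipeline_state_version_if_modified(state)           # seeds _version_hash
state["_local"]["_last_extracted_at"] = pendulum.now()   # change under _local only
version, hash_, previous_hash = bump_pipeline_state_version_if_modified(state)
assert version == 0 and hash_ == previous_hash           # excluded from the hash, no bump
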
"version_hash": state["_version_hash"], - } - return dlt.resource( - [state_doc], name=STATE_TABLE_NAME, write_disposition="append", columns=STATE_TABLE_COLUMNS - ) - - -def load_state_from_destination(pipeline_name: str, client: WithStateSync) -> TPipelineState: - # NOTE: if dataset or table holding state does not exist, the sql_client will rise DestinationUndefinedEntity. caller must handle this - state = client.get_stored_state(pipeline_name) - if not state: - return None - s = decompress_state(state.state) - return migrate_state(pipeline_name, s, s["_state_engine_version"], STATE_ENGINE_VERSION) - - -def migrate_state( +def migrate_pipeline_state( pipeline_name: str, state: DictStrAny, from_engine: int, to_engine: int ) -> TPipelineState: if from_engine == to_engine: @@ -119,7 +78,7 @@ def migrate_state( from_engine = 2 if from_engine == 2 and to_engine > 2: # you may want to recompute hash - state["_version_hash"] = generate_version_hash(state) # type: ignore[arg-type] + state["_version_hash"] = generate_pipeline_state_version_hash(state) # type: ignore[arg-type] from_engine = 3 if from_engine == 3 and to_engine > 3: if state.get("destination"): @@ -139,3 +98,41 @@ def migrate_state( ) state["_state_engine_version"] = from_engine return cast(TPipelineState, state) + + +def state_resource(state: TPipelineState) -> DltResource: + state = copy(state) + state.pop("_local") + state_str = compress_state(state) + state_doc = { + "version": state["_state_version"], + "engine_version": state["_state_engine_version"], + "pipeline_name": state["pipeline_name"], + "state": state_str, + "created_at": pendulum.now(), + "version_hash": state["_version_hash"], + } + return dlt.resource( + [state_doc], name=STATE_TABLE_NAME, write_disposition="append", columns=STATE_TABLE_COLUMNS + ) + + +def load_pipeline_state_from_destination( + pipeline_name: str, client: WithStateSync +) -> TPipelineState: + # NOTE: if dataset or table holding state does not exist, the sql_client will rise DestinationUndefinedEntity. 
caller must handle this + state = client.get_stored_state(pipeline_name) + if not state: + return None + s = decompress_state(state.state) + return migrate_pipeline_state( + pipeline_name, s, s["_state_engine_version"], PIPELINE_STATE_ENGINE_VERSION + ) + + +def default_pipeline_state() -> TPipelineState: + return { + **default_versioned_state(), + "_state_engine_version": PIPELINE_STATE_ENGINE_VERSION, + "_local": {"first_run": True}, + } diff --git a/docs/examples/chess_production/chess.py b/docs/examples/chess_production/chess.py index 2e85805781..f7c5849e57 100644 --- a/docs/examples/chess_production/chess.py +++ b/docs/examples/chess_production/chess.py @@ -6,6 +6,7 @@ from dlt.common.typing import StrAny, TDataItems from dlt.sources.helpers.requests import client + @dlt.source def chess( chess_url: str = dlt.config.value, @@ -59,6 +60,7 @@ def players_games(username: Any) -> Iterator[TDataItems]: MAX_PLAYERS = 5 + def load_data_with_retry(pipeline, data): try: for attempt in Retrying( @@ -68,9 +70,7 @@ def load_data_with_retry(pipeline, data): reraise=True, ): with attempt: - logger.info( - f"Running the pipeline, attempt={attempt.retry_state.attempt_number}" - ) + logger.info(f"Running the pipeline, attempt={attempt.retry_state.attempt_number}") load_info = pipeline.run(data) logger.info(str(load_info)) @@ -92,9 +92,7 @@ def load_data_with_retry(pipeline, data): # print the information on the first load package and all jobs inside logger.info(f"First load package info: {load_info.load_packages[0]}") # print the information on the first completed job in first load package - logger.info( - f"First completed job info: {load_info.load_packages[0].jobs['completed_jobs'][0]}" - ) + logger.info(f"First completed job info: {load_info.load_packages[0].jobs['completed_jobs'][0]}") # check for schema updates: schema_updates = [p.schema_update for p in load_info.load_packages] @@ -152,4 +150,4 @@ def load_data_with_retry(pipeline, data): ) # get data for a few famous players data = chess(chess_url="https://api.chess.com/pub/", max_players=MAX_PLAYERS) - load_data_with_retry(pipeline, data) \ No newline at end of file + load_data_with_retry(pipeline, data) diff --git a/docs/examples/connector_x_arrow/load_arrow.py b/docs/examples/connector_x_arrow/load_arrow.py index 24ba2acb0e..307e657514 100644 --- a/docs/examples/connector_x_arrow/load_arrow.py +++ b/docs/examples/connector_x_arrow/load_arrow.py @@ -3,6 +3,7 @@ import dlt from dlt.sources.credentials import ConnectionStringCredentials + def read_sql_x( conn_str: ConnectionStringCredentials = dlt.secrets.value, query: str = dlt.config.value, @@ -14,6 +15,7 @@ def read_sql_x( protocol="binary", ) + def genome_resource(): # create genome resource with merge on `upid` primary key genome = dlt.resource( diff --git a/docs/examples/google_sheets/google_sheets.py b/docs/examples/google_sheets/google_sheets.py index 8a93df9970..1ba330e4ca 100644 --- a/docs/examples/google_sheets/google_sheets.py +++ b/docs/examples/google_sheets/google_sheets.py @@ -9,6 +9,7 @@ ) from dlt.common.typing import DictStrAny, StrAny + def _initialize_sheets( credentials: Union[GcpOAuthCredentials, GcpServiceAccountCredentials] ) -> Any: @@ -16,6 +17,7 @@ def _initialize_sheets( service = build("sheets", "v4", credentials=credentials.to_native_credentials()) return service + @dlt.source def google_spreadsheet( spreadsheet_id: str, @@ -55,6 +57,7 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]: for name in sheet_names ] + if __name__ == "__main__": pipeline = 
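[Editor's note] A worked illustration of the engine 3 to 4 upgrade path handled by migrate_pipeline_state above; the expected outcome is taken from the migration tests further down in this series, and the pipeline name is arbitrary:

from dlt.pipeline.state_sync import migrate_pipeline_state, PIPELINE_STATE_ENGINE_VERSION

state_v3 = {"destination": "dlt.destinations.redshift", "_state_engine_version": 3}
migrate_pipeline_state("example_pipeline", state_v3, 3, PIPELINE_STATE_ENGINE_VERSION)  # type: ignore
assert state_v3["destination_name"] == "redshift"
assert state_v3["destination_type"] == "dlt.destinations.redshift"
assert "destination" not in state_v3
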
dlt.pipeline(destination="duckdb") # see example.secrets.toml to where to put credentials @@ -67,4 +70,4 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]: sheet_names=range_names, ) ) - print(info) \ No newline at end of file + print(info) diff --git a/docs/examples/incremental_loading/zendesk.py b/docs/examples/incremental_loading/zendesk.py index 4b8597886a..6113f98793 100644 --- a/docs/examples/incremental_loading/zendesk.py +++ b/docs/examples/incremental_loading/zendesk.py @@ -6,12 +6,11 @@ from dlt.common.typing import TAnyDateTime from dlt.sources.helpers.requests import client + @dlt.source(max_table_nesting=2) def zendesk_support( credentials: Dict[str, str] = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = pendulum.datetime( # noqa: B008 - year=2000, month=1, day=1 - ), + start_date: Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008 end_date: Optional[TAnyDateTime] = None, ): """ @@ -113,6 +112,7 @@ def get_pages( if not response_json["end_of_stream"]: get_url = response_json["next_page"] + if __name__ == "__main__": # create dlt pipeline pipeline = dlt.pipeline( @@ -120,4 +120,4 @@ def get_pages( ) load_info = pipeline.run(zendesk_support()) - print(load_info) \ No newline at end of file + print(load_info) diff --git a/docs/examples/nested_data/nested_data.py b/docs/examples/nested_data/nested_data.py index 3464448de6..7f85f0522e 100644 --- a/docs/examples/nested_data/nested_data.py +++ b/docs/examples/nested_data/nested_data.py @@ -13,6 +13,7 @@ CHUNK_SIZE = 10000 + # You can limit how deep dlt goes when generating child tables. # By default, the library will descend and generate child tables # for all nested lists, without a limit. @@ -81,6 +82,7 @@ def load_documents(self) -> Iterator[TDataItem]: while docs_slice := list(islice(cursor, CHUNK_SIZE)): yield map_nested_in_place(convert_mongo_objs, docs_slice) + def convert_mongo_objs(value: Any) -> Any: if isinstance(value, (ObjectId, Decimal128)): return str(value) diff --git a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py index 8f7833e7d7..e7f57853ed 100644 --- a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py +++ b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py @@ -4,6 +4,7 @@ from dlt.destinations.impl.weaviate import weaviate_adapter from PyPDF2 import PdfReader + @dlt.resource(selected=False) def list_files(folder_path: str): folder_path = os.path.abspath(folder_path) @@ -15,6 +16,7 @@ def list_files(folder_path: str): "mtime": os.path.getmtime(file_path), } + @dlt.transformer(primary_key="page_id", write_disposition="merge") def pdf_to_text(file_item, separate_pages: bool = False): if not separate_pages: @@ -28,6 +30,7 @@ def pdf_to_text(file_item, separate_pages: bool = False): page_item["page_id"] = file_item["file_name"] + "_" + str(page_no) yield page_item + pipeline = dlt.pipeline(pipeline_name="pdf_to_text", destination="weaviate") # this constructs a simple pipeline that: (1) reads files from "invoices" folder (2) filters only those ending with ".pdf" @@ -51,4 +54,4 @@ def pdf_to_text(file_item, separate_pages: bool = False): client = weaviate.Client("http://localhost:8080") # get text of all the invoices in InvoiceText class we just created above -print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do()) \ No newline at end of file +print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do()) diff --git a/docs/examples/qdrant_zendesk/qdrant.py 
b/docs/examples/qdrant_zendesk/qdrant.py index 300d8dc6ad..bd0cbafc99 100644 --- a/docs/examples/qdrant_zendesk/qdrant.py +++ b/docs/examples/qdrant_zendesk/qdrant.py @@ -10,13 +10,12 @@ from dlt.common.configuration.inject import with_config + # function from: https://github.com/dlt-hub/verified-sources/tree/master/sources/zendesk @dlt.source(max_table_nesting=2) def zendesk_support( credentials: Dict[str, str] = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = pendulum.datetime( # noqa: B008 - year=2000, month=1, day=1 - ), + start_date: Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008 end_date: Optional[TAnyDateTime] = None, ): """ @@ -80,6 +79,7 @@ def _parse_date_or_none(value: Optional[str]) -> Optional[pendulum.DateTime]: return None return ensure_pendulum_datetime(value) + # modify dates to return datetime objects instead def _fix_date(ticket): ticket["updated_at"] = _parse_date_or_none(ticket["updated_at"]) @@ -87,6 +87,7 @@ def _fix_date(ticket): ticket["due_at"] = _parse_date_or_none(ticket["due_at"]) return ticket + # function from: https://github.com/dlt-hub/verified-sources/tree/master/sources/zendesk def get_pages( url: str, @@ -127,6 +128,7 @@ def get_pages( if not response_json["end_of_stream"]: get_url = response_json["next_page"] + if __name__ == "__main__": # create a pipeline with an appropriate name pipeline = dlt.pipeline( @@ -146,7 +148,6 @@ def get_pages( print(load_info) - # running the Qdrant client to connect to your Qdrant database @with_config(sections=("destination", "qdrant", "credentials")) diff --git a/docs/examples/transformers/pokemon.py b/docs/examples/transformers/pokemon.py index c17beff6a8..97b9a98b11 100644 --- a/docs/examples/transformers/pokemon.py +++ b/docs/examples/transformers/pokemon.py @@ -1,6 +1,7 @@ import dlt from dlt.sources.helpers import requests + @dlt.source(max_table_nesting=2) def source(pokemon_api_url: str): """""" @@ -46,6 +47,7 @@ def species(pokemon_details): return (pokemon_list | pokemon, pokemon_list | pokemon | species) + if __name__ == "__main__": # build duck db pipeline pipeline = dlt.pipeline( @@ -54,4 +56,4 @@ def species(pokemon_details): # the pokemon_list resource does not need to be loaded load_info = pipeline.run(source("https://pokeapi.co/api/v2/pokemon")) - print(load_info) \ No newline at end of file + print(load_info) diff --git a/tests/load/pipeline/test_drop.py b/tests/load/pipeline/test_drop.py index cd18454d7c..8614af4734 100644 --- a/tests/load/pipeline/test_drop.py +++ b/tests/load/pipeline/test_drop.py @@ -106,7 +106,9 @@ def assert_destination_state_loaded(pipeline: Pipeline) -> None: """Verify stored destination state matches the local pipeline state""" client: SqlJobClientBase with pipeline.destination_client() as client: # type: ignore[assignment] - destination_state = state_sync.load_state_from_destination(pipeline.pipeline_name, client) + destination_state = state_sync.load_pipeline_state_from_destination( + pipeline.pipeline_name, client + ) pipeline_state = dict(pipeline.state) del pipeline_state["_local"] assert pipeline_state == destination_state diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index 381068f1e1..5cb59405a3 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -13,7 +13,11 @@ from dlt.pipeline.exceptions import SqlClientNotAvailable from dlt.pipeline.pipeline import Pipeline -from dlt.pipeline.state_sync import STATE_TABLE_COLUMNS, 
load_state_from_destination, state_resource +from dlt.pipeline.state_sync import ( + STATE_TABLE_COLUMNS, + load_pipeline_state_from_destination, + state_resource, +) from dlt.destinations.job_client_impl import SqlJobClientBase from tests.utils import TEST_STORAGE_ROOT @@ -54,14 +58,14 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - job_client: SqlJobClientBase with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] with pytest.raises(DestinationUndefinedEntity): - load_state_from_destination(p.pipeline_name, job_client) + load_pipeline_state_from_destination(p.pipeline_name, job_client) # sync the schema p.sync_schema() exists, _ = job_client.get_storage_table(schema.version_table_name) assert exists is True # dataset exists, still no table with pytest.raises(DestinationUndefinedEntity): - load_state_from_destination(p.pipeline_name, job_client) + load_pipeline_state_from_destination(p.pipeline_name, job_client) initial_state = p._get_state() # now add table to schema and sync initial_state["_local"]["_last_extracted_at"] = pendulum.now() @@ -84,14 +88,14 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - exists, _ = job_client.get_storage_table(schema.state_table_name) assert exists is True # table is there but no state - assert load_state_from_destination(p.pipeline_name, job_client) is None + assert load_pipeline_state_from_destination(p.pipeline_name, job_client) is None # extract state with p.managed_state(extract_state=True): pass # just run the existing extract p.normalize(loader_file_format=destination_config.file_format) p.load() - stored_state = load_state_from_destination(p.pipeline_name, job_client) + stored_state = load_pipeline_state_from_destination(p.pipeline_name, job_client) local_state = p._get_state() local_state.pop("_local") assert stored_state == local_state @@ -101,7 +105,7 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - managed_state["sources"] = {"source": dict(JSON_TYPED_DICT_DECODED)} p.normalize(loader_file_format=destination_config.file_format) p.load() - stored_state = load_state_from_destination(p.pipeline_name, job_client) + stored_state = load_pipeline_state_from_destination(p.pipeline_name, job_client) assert stored_state["sources"] == {"source": JSON_TYPED_DICT_DECODED} local_state = p._get_state() local_state.pop("_local") @@ -116,7 +120,7 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - p.normalize(loader_file_format=destination_config.file_format) info = p.load() assert len(info.loads_ids) == 0 - new_stored_state = load_state_from_destination(p.pipeline_name, job_client) + new_stored_state = load_pipeline_state_from_destination(p.pipeline_name, job_client) # new state should not be stored assert new_stored_state == stored_state @@ -147,7 +151,7 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - p.normalize(loader_file_format=destination_config.file_format) info = p.load() assert len(info.loads_ids) == 1 - new_stored_state_2 = load_state_from_destination(p.pipeline_name, job_client) + new_stored_state_2 = load_pipeline_state_from_destination(p.pipeline_name, job_client) # the stored state changed to next version assert new_stored_state != new_stored_state_2 assert new_stored_state["_state_version"] + 1 == new_stored_state_2["_state_version"] @@ -405,7 +409,7 @@ def complete_package_mock(self, load_id: str, schema: Schema, aborted: bool = Fa 
job_client: SqlJobClientBase with p._get_destination_clients(p.default_schema)[0] as job_client: # type: ignore[assignment] # state without completed load id is not visible - state = load_state_from_destination(pipeline_name, job_client) + state = load_pipeline_state_from_destination(pipeline_name, job_client) assert state is None diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py index a82b7baef6..407a156d6f 100644 --- a/tests/load/sink/test_sink.py +++ b/tests/load/sink/test_sink.py @@ -235,7 +235,7 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: assert_items_in_range(calls["items"], 0, 100) assert_items_in_range(calls["items2"], 0, 100) # destination state should be cleared after load - assert p.get_load_package_state(load_id) == {} + assert p.get_load_package_state(load_id)["destinations"] == {} # provoke errors calls = {} @@ -246,9 +246,10 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: # we should have data for one load id saved here load_id = p.list_normalized_load_packages()[0] - load_package_state = p.get_load_package_state(load_id) + load_package_state = p.get_load_package_state(load_id)["destinations"] assert len(load_package_state) == 1 + # get saved indexes mapped to table (this test will only work for one job per table) values = {k.split(".")[0]: v for k, v in list(load_package_state.values())[0].items()} @@ -275,7 +276,7 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: calls = {} p.load() # state should be cleared again - load_package_state = p.get_load_package_state(load_id) + load_package_state = p.get_load_package_state(load_id)["destinations"] assert load_package_state == {} # both calls combined should have every item called just once diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 7e99027e08..921a81eaa5 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -1333,11 +1333,11 @@ def test_resource_state_name_not_normalized() -> None: pipeline.load() # get state from destination - from dlt.pipeline.state_sync import load_state_from_destination + from dlt.pipeline.state_sync import load_pipeline_state_from_destination client: WithStateSync with pipeline.destination_client() as client: # type: ignore[assignment] - state = load_state_from_destination(pipeline.pipeline_name, client) + state = load_pipeline_state_from_destination(pipeline.pipeline_name, client) assert "airtable_emojis" in state["sources"] assert state["sources"]["airtable_emojis"]["resources"] == {"🦚Peacock": {"🦚🦚🦚": "🦚"}} diff --git a/tests/pipeline/test_pipeline_state.py b/tests/pipeline/test_pipeline_state.py index ee788367e1..de0a98d9b1 100644 --- a/tests/pipeline/test_pipeline_state.py +++ b/tests/pipeline/test_pipeline_state.py @@ -14,7 +14,11 @@ from dlt.pipeline.exceptions import PipelineStateEngineNoUpgradePathException, PipelineStepFailed from dlt.pipeline.pipeline import Pipeline -from dlt.pipeline.state_sync import generate_version_hash, migrate_state, STATE_ENGINE_VERSION +from dlt.pipeline.state_sync import ( + generate_state_version_hash, + migrate_pipeline_state, + PIPELINE_STATE_ENGINE_VERSION, +) from tests.utils import test_storage from tests.pipeline.utils import json_case_path, load_json_case @@ -482,21 +486,21 @@ def transform(item): ) -def test_migrate_state(test_storage: FileStorage) -> None: +def test_migrate_pipeline_state(test_storage: FileStorage) -> None: # test generation of version hash on 
migration to v3 state_v1 = load_json_case("state/state.v1") - state = migrate_state("test_pipeline", state_v1, state_v1["_state_engine_version"], 3) + state = migrate_pipeline_state("test_pipeline", state_v1, state_v1["_state_engine_version"], 3) assert state["_state_engine_version"] == 3 assert "_local" in state assert "_version_hash" in state - assert state["_version_hash"] == generate_version_hash(state) + assert state["_version_hash"] == generate_state_version_hash(state) # full migration state_v1 = load_json_case("state/state.v1") - state = migrate_state( - "test_pipeline", state_v1, state_v1["_state_engine_version"], STATE_ENGINE_VERSION + state = migrate_pipeline_state( + "test_pipeline", state_v1, state_v1["_state_engine_version"], PIPELINE_STATE_ENGINE_VERSION ) - assert state["_state_engine_version"] == STATE_ENGINE_VERSION + assert state["_state_engine_version"] == PIPELINE_STATE_ENGINE_VERSION # check destination migration assert state["destination_name"] == "postgres" @@ -505,12 +509,15 @@ def test_migrate_state(test_storage: FileStorage) -> None: with pytest.raises(PipelineStateEngineNoUpgradePathException) as py_ex: state_v1 = load_json_case("state/state.v1") - migrate_state( - "test_pipeline", state_v1, state_v1["_state_engine_version"], STATE_ENGINE_VERSION + 1 + migrate_pipeline_state( + "test_pipeline", + state_v1, + state_v1["_state_engine_version"], + PIPELINE_STATE_ENGINE_VERSION + 1, ) assert py_ex.value.init_engine == state_v1["_state_engine_version"] - assert py_ex.value.from_engine == STATE_ENGINE_VERSION - assert py_ex.value.to_engine == STATE_ENGINE_VERSION + 1 + assert py_ex.value.from_engine == PIPELINE_STATE_ENGINE_VERSION + assert py_ex.value.to_engine == PIPELINE_STATE_ENGINE_VERSION + 1 # also test pipeline init where state is old test_storage.create_folder("debug_pipeline") @@ -522,7 +529,7 @@ def test_migrate_state(test_storage: FileStorage) -> None: assert p.dataset_name == "debug_pipeline_data" assert p.default_schema_name == "example_source" state = p.state - assert state["_version_hash"] == generate_version_hash(state) + assert state["_version_hash"] == generate_state_version_hash(state) # specifically check destination v3 to v4 migration state_v3 = { @@ -530,8 +537,8 @@ def test_migrate_state(test_storage: FileStorage) -> None: "staging": "dlt.destinations.filesystem", "_state_engine_version": 3, } - migrate_state( - "test_pipeline", state_v3, state_v3["_state_engine_version"], STATE_ENGINE_VERSION # type: ignore + migrate_pipeline_state( + "test_pipeline", state_v3, state_v3["_state_engine_version"], PIPELINE_STATE_ENGINE_VERSION # type: ignore ) assert state_v3["destination_name"] == "redshift" assert state_v3["destination_type"] == "dlt.destinations.redshift" @@ -544,8 +551,8 @@ def test_migrate_state(test_storage: FileStorage) -> None: "destination": "dlt.destinations.redshift", "_state_engine_version": 3, } - migrate_state( - "test_pipeline", state_v3, state_v3["_state_engine_version"], STATE_ENGINE_VERSION # type: ignore + migrate_pipeline_state( + "test_pipeline", state_v3, state_v3["_state_engine_version"], PIPELINE_STATE_ENGINE_VERSION # type: ignore ) assert state_v3["destination_name"] == "redshift" assert state_v3["destination_type"] == "dlt.destinations.redshift" @@ -554,8 +561,8 @@ def test_migrate_state(test_storage: FileStorage) -> None: assert "staging_type" not in state_v3 state_v3 = {"destination": None, "staging": None, "_state_engine_version": 3} - migrate_state( - "test_pipeline", state_v3, state_v3["_state_engine_version"], 
STATE_ENGINE_VERSION # type: ignore + migrate_pipeline_state( + "test_pipeline", state_v3, state_v3["_state_engine_version"], PIPELINE_STATE_ENGINE_VERSION # type: ignore ) assert "destination_name" not in state_v3 assert "destination_type" not in state_v3 @@ -563,8 +570,8 @@ def test_migrate_state(test_storage: FileStorage) -> None: assert "staging_type" not in state_v3 state_v3 = {"_state_engine_version": 2} - migrate_state( - "test_pipeline", state_v3, state_v3["_state_engine_version"], STATE_ENGINE_VERSION # type: ignore + migrate_pipeline_state( + "test_pipeline", state_v3, state_v3["_state_engine_version"], PIPELINE_STATE_ENGINE_VERSION # type: ignore ) assert "destination_name" not in state_v3 assert "destination_type" not in state_v3 From daae33e82d9dbaa102c3d01c1f9525f3a7cf0156 Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 6 Feb 2024 16:03:54 +0100 Subject: [PATCH 17/35] restrict access of destinations to load package state in accessor functions --- dlt/common/pipeline.py | 37 +++++++++++++++++++++++++++--- dlt/destinations/impl/sink/sink.py | 16 ++++++------- dlt/load/load.py | 15 +++--------- 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 0d53a4f3a8..89bce6930b 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -3,6 +3,8 @@ import datetime # noqa: 251 import humanize import contextlib +import threading + from typing import ( Any, Callable, @@ -40,6 +42,8 @@ from dlt.common.schema.typing import TColumnNames, TColumnSchema, TWriteDisposition, TSchemaContract from dlt.common.source import get_current_pipe_name from dlt.common.storages.load_storage import LoadPackageInfo +from dlt.common.storages.load_package import PackageStorage + from dlt.common.time import ensure_pendulum_datetime, precise_time from dlt.common.typing import DictStrAny, REPattern, StrAny, SupportsHumanize from dlt.common.jsonpath import delete_matches, TAnyJsonPath @@ -600,14 +604,23 @@ def __init__(self, state: TPipelineState = None) -> None: ... @configspec class LoadPackageStateInjectableContext(ContainerInjectableContext): - state: TLoadPackageState - commit: Optional[Callable[[], None]] + storage: PackageStorage + load_id: str + destination_name: Optional[str] can_create_default: ClassVar[bool] = False + def commit(self) -> None: + with self.state_save_lock: + self.storage.save_load_package_state(self.load_id, self.state) + + def on_resolved(self) -> None: + self.state_save_lock = threading.Lock() + self.state = self.storage.get_load_package_state(self.load_id) + if TYPE_CHECKING: def __init__( - self, state: TLoadPackageState = None, commit: Optional[Callable[[], None]] = None + self, load_id: str, storage: PackageStorage, destination_name: Optional[str] ) -> None: ... @@ -692,6 +705,7 @@ def source_state() -> DictStrAny: def load_package_state() -> TLoadPackageState: + """Get full load package state present in current context. Across all threads this will be the same in memory dict.""" container = Container() # get injected state if present. injected load package state is typically "managed" so changes will be persisted # if you need to save the load package state during a load, you need to call commit_load_package_state @@ -703,6 +717,7 @@ def load_package_state() -> TLoadPackageState: def commit_load_package_state() -> None: + """Commit load package state present in current context. 
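[Editor's note] A hedged sketch of how the new injectable context is meant to be used end to end; the load step below and the tests added in this series follow the same pattern. load_storage, load_id and client_config are assumed to exist as they do in dlt/load/load.py, and at this point in the series the names still live in dlt.common.pipeline:

from dlt.common.configuration.container import Container
from dlt.common.pipeline import (
    LoadPackageStateInjectableContext,
    load_package_state,
    commit_load_package_state,
)

with Container().injectable_context(
    LoadPackageStateInjectableContext(
        storage=load_storage.normalized_packages,   # PackageStorage holding the package
        load_id=load_id,
        destination_name=client_config.destination_name,
    )
):
    state = load_package_state()     # the same in-memory dict across all job threads
    # ... jobs may read and mutate the state here ...
    commit_load_package_state()      # written back under the context's save lock
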
This is thread safe.""" container = Container() try: state_ctx = container[LoadPackageStateInjectableContext] @@ -711,6 +726,22 @@ def commit_load_package_state() -> None: state_ctx.commit() +def load_package_destination_state() -> DictStrAny: + """Get segment of load package state that is specific to the current destination.""" + lp_state = load_package_state() + destination_name = Container()[LoadPackageStateInjectableContext].destination_name + return lp_state.setdefault("destinations", {}).setdefault(destination_name, {}) + + +def clear_loadpackage_destination_state(commit: bool = True) -> None: + """Clear segment of load package state that is specific to the current destination. Optionally commit to load package.""" + lp_state = load_package_state() + destination_name = Container()[LoadPackageStateInjectableContext].destination_name + lp_state["destinations"].pop(destination_name, None) + if commit: + commit_load_package_state() + + def _delete_source_state_keys( key: TAnyJsonPath, source_state_: Optional[DictStrAny] = None, / ) -> None: diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py index 3b3396d3f6..31981e322c 100644 --- a/dlt/destinations/impl/sink/sink.py +++ b/dlt/destinations/impl/sink/sink.py @@ -5,7 +5,11 @@ from dlt.destinations.job_impl import EmptyLoadJob from dlt.common.typing import TDataItems from dlt.common import json -from dlt.common.pipeline import load_package_state, commit_load_package_state +from dlt.common.pipeline import ( + load_package_destination_state, + commit_load_package_state, + clear_loadpackage_destination_state, +) from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.typing import TTableSchema @@ -145,11 +149,7 @@ def update_stored_schema( def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: # save our state in destination name scope - load_state = ( - load_package_state() - .setdefault("destinations", {}) - .setdefault(self.config.destination_name, {}) - ) + load_state = load_package_destination_state() if file_path.endswith("parquet"): return SinkParquetLoadJob(table, file_path, self.config, self.schema, load_state) if file_path.endswith("jsonl"): @@ -161,9 +161,7 @@ def restore_file_load(self, file_path: str) -> LoadJob: def complete_load(self, load_id: str) -> None: # pop all state for this load on success - state = load_package_state() - state["destinations"].pop(self.config.destination_name, None) - commit_load_package_state() + clear_loadpackage_destination_state() def __enter__(self) -> "SinkClient": return self diff --git a/dlt/load/load.py b/dlt/load/load.py index 024e42c745..072b4f1d44 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -566,23 +566,14 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: schema = self.load_storage.normalized_packages.load_schema(load_id) logger.info(f"Loaded schema name {schema.name} and version {schema.stored_version}") - # prepare load package state context - load_package_state = self.load_storage.normalized_packages.get_load_package_state(load_id) - state_save_lock = threading.Lock() - - def commit_load_package_state() -> None: - with state_save_lock: - self.load_storage.normalized_packages.save_load_package_state( - load_id, load_package_state - ) - container = Container() # get top load id and mark as being processed with self.collector(f"Load {schema.name} in {load_id}"): with container.injectable_context( LoadPackageStateInjectableContext( - state=load_package_state, - 
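[Editor's note] The two accessors above replace the manual setdefault chain previously used in the sink. A load job written against them might look roughly like this; the checkpoint key and value are illustrative, and the code must run inside a load where the state context has been injected:

from dlt.common.pipeline import (
    load_package_destination_state,
    commit_load_package_state,
    clear_loadpackage_destination_state,
)

progress = load_package_destination_state()    # dict scoped to the active destination only
last_index = progress.get("items.jsonl", 0)    # resume point left by a previously failed attempt
# ... flush items starting at last_index ...
progress["items.jsonl"] = 100
commit_load_package_state()                    # persist the checkpoint for retries

# later, in complete_load(), the destination drops only its own scope:
clear_loadpackage_destination_state()
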
commit=commit_load_package_state, + storage=self.load_storage.normalized_packages, + load_id=load_id, + destination_name=self.initial_client_config.destination_name, ) ): # the same load id may be processed across multiple runs From 644e6f3c5c2ebc62c008c8987e7c902abd3f6f8e Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 6 Feb 2024 16:12:50 +0100 Subject: [PATCH 18/35] fix tests --- dlt/common/storages/normalize_storage.py | 4 +++- tests/pipeline/test_pipeline_state.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/dlt/common/storages/normalize_storage.py b/dlt/common/storages/normalize_storage.py index 8a247c2021..2b90b7c088 100644 --- a/dlt/common/storages/normalize_storage.py +++ b/dlt/common/storages/normalize_storage.py @@ -51,7 +51,9 @@ def list_files_to_normalize_sorted(self) -> Sequence[str]: [ file for file in files - if not file.endswith(PackageStorage.SCHEMA_FILE_NAME) and os.path.isfile(file) + if not file.endswith(PackageStorage.SCHEMA_FILE_NAME) + and os.path.isfile(file) + and not file.endswith(PackageStorage.LOAD_PACKAGE_STATE_FILE_NAME) ] ) diff --git a/tests/pipeline/test_pipeline_state.py b/tests/pipeline/test_pipeline_state.py index de0a98d9b1..f0bcda2717 100644 --- a/tests/pipeline/test_pipeline_state.py +++ b/tests/pipeline/test_pipeline_state.py @@ -15,7 +15,7 @@ from dlt.pipeline.exceptions import PipelineStateEngineNoUpgradePathException, PipelineStepFailed from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.state_sync import ( - generate_state_version_hash, + generate_pipeline_state_version_hash, migrate_pipeline_state, PIPELINE_STATE_ENGINE_VERSION, ) @@ -493,7 +493,7 @@ def test_migrate_pipeline_state(test_storage: FileStorage) -> None: assert state["_state_engine_version"] == 3 assert "_local" in state assert "_version_hash" in state - assert state["_version_hash"] == generate_state_version_hash(state) + assert state["_version_hash"] == generate_pipeline_state_version_hash(state) # full migration state_v1 = load_json_case("state/state.v1") @@ -529,7 +529,7 @@ def test_migrate_pipeline_state(test_storage: FileStorage) -> None: assert p.dataset_name == "debug_pipeline_data" assert p.default_schema_name == "example_source" state = p.state - assert state["_version_hash"] == generate_state_version_hash(state) + assert state["_version_hash"] == generate_pipeline_state_version_hash(state) # specifically check destination v3 to v4 migration state_v3 = { From 9930ad69f807853a83f93dbffc5ac550b2902e15 Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 6 Feb 2024 17:23:19 +0100 Subject: [PATCH 19/35] add tests for state and new injectable context --- dlt/common/storages/load_package.py | 2 + tests/common/storages/test_load_package.py | 81 ++++++++++++++++++++++ tests/common/test_versioned_state.py | 43 ++++++++++++ tests/load/sink/test_sink.py | 2 + tests/utils.py | 2 - 5 files changed, 128 insertions(+), 2 deletions(-) create mode 100644 tests/common/test_versioned_state.py diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index b83eae0dc4..483011a2e7 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -63,6 +63,8 @@ def bump_loadpackage_state_version_if_modified(state: TLoadPackageState) -> Tupl def migrate_loadpackage_state( state: DictStrAny, from_engine: int, to_engine: int ) -> TLoadPackageState: + # TODO: if you start adding new versions, we need proper tests for these migrations! 
+ # NOTE: do not touch destinations state, it is not versioned if from_engine == to_engine: return cast(TLoadPackageState, state) diff --git a/tests/common/storages/test_load_package.py b/tests/common/storages/test_load_package.py index f671ddcf32..a8d333d431 100644 --- a/tests/common/storages/test_load_package.py +++ b/tests/common/storages/test_load_package.py @@ -9,6 +9,15 @@ from tests.common.storages.utils import start_loading_file, assert_package_info, load_storage from tests.utils import autouse_test_storage +from dlt.common.pendulum import pendulum +from dlt.common.configuration.container import Container +from dlt.common.pipeline import ( + LoadPackageStateInjectableContext, + load_package_destination_state, + load_package_state, + commit_load_package_state, + clear_loadpackage_destination_state, +) def test_is_partially_loaded(load_storage: LoadStorage) -> None: @@ -57,6 +66,78 @@ def test_save_load_schema(load_storage: LoadStorage) -> None: assert schema.stored_version == schema_copy.stored_version +def test_create_and_update_loadpackage_state(load_storage: LoadStorage) -> None: + load_storage.new_packages.create_package("copy") + state = load_storage.new_packages.get_load_package_state("copy") + assert state["_state_version"] == 0 + assert state["_version_hash"] is not None + assert state["created"] is not None + old_state = state.copy() + + state["new_key"] = "new_value" # type: ignore + load_storage.new_packages.save_load_package_state("copy", state) + + state = load_storage.new_packages.get_load_package_state("copy") + assert state["new_key"] == "new_value" # type: ignore + assert state["_state_version"] == 1 + assert state["_version_hash"] != old_state["_version_hash"] + # created timestamp should be conserved + assert state["created"] == old_state["created"] + + # check timestamp + time = pendulum.from_timestamp(state["created"]) + now = pendulum.now() + (now - time).in_seconds() < 2 + + +def test_loadpackage_state_injectable_context(load_storage: LoadStorage) -> None: + load_storage.new_packages.create_package("copy") + + container = Container() + with container.injectable_context( + LoadPackageStateInjectableContext( + storage=load_storage.new_packages, + load_id="copy", + destination_name="some_destination_name", + ) + ): + # test general load package state + injected_state = load_package_state() + assert injected_state["_state_version"] == 0 + injected_state["new_key"] = "new_value" # type: ignore + + # not persisted yet + assert load_storage.new_packages.get_load_package_state("copy").get("new_key") is None + # commit + commit_load_package_state() + + # now it should be persisted + assert ( + load_storage.new_packages.get_load_package_state("copy").get("new_key") == "new_value" + ) + assert load_storage.new_packages.get_load_package_state("copy").get("_state_version") == 1 + + # check that second injection is the same as first + second_injected_instance = load_package_state() + assert second_injected_instance == injected_state + + # check scoped destination states + assert load_storage.new_packages.get_load_package_state("copy").get("destinations") is None + destination_state = load_package_destination_state() + destination_state["new_key"] = "new_value" + commit_load_package_state() + assert load_storage.new_packages.get_load_package_state("copy").get("destinations") == { + "some_destination_name": {"new_key": "new_value"} + } + + # this also shows up on the previously injected state + assert injected_state["destinations"]["some_destination_name"]["new_key"] == 
"new_value" + + # clear destination state + clear_loadpackage_destination_state() + assert load_storage.new_packages.get_load_package_state("copy").get("destinations") == {} + + def test_job_elapsed_time_seconds(load_storage: LoadStorage) -> None: load_id, fn = start_loading_file(load_storage, "test file") # type: ignore[arg-type] fp = load_storage.normalized_packages.storage.make_full_path( diff --git a/tests/common/test_versioned_state.py b/tests/common/test_versioned_state.py new file mode 100644 index 0000000000..4a84a258a5 --- /dev/null +++ b/tests/common/test_versioned_state.py @@ -0,0 +1,43 @@ +from dlt.common.versioned_state import ( + generate_state_version_hash, + bump_state_version_if_modified, + default_versioned_state, +) + + +def test_versioned_state() -> None: + state = default_versioned_state() + assert state["_state_version"] == 0 + assert state["_state_engine_version"] == 1 + + # first hash generation does not change version, attrs are not modified + version, hash, previous_hash = bump_state_version_if_modified(state) + assert version == 0 + assert hash is not None + assert previous_hash is None + assert state["_version_hash"] == hash + + # change attr, but exclude while generating + state["foo"] = "bar" # type: ignore + version, hash, previous_hash = bump_state_version_if_modified(state, exclude_attrs=["foo"]) + assert version == 0 + assert hash == previous_hash + + # now don't exclude (remember old hash to compare return vars) + old_hash = state["_version_hash"] + version, hash, previous_hash = bump_state_version_if_modified(state) + assert version == 1 + assert hash != previous_hash + assert old_hash != hash + assert previous_hash == old_hash + + # messing with state engine version will not change hash + state["_state_engine_version"] = 5 + version, hash, previous_hash = bump_state_version_if_modified(state) + assert version == 1 + assert hash == previous_hash + + # make sure state object is not modified while bumping with no effect + old_state = state.copy() + version, hash, previous_hash = bump_state_version_if_modified(state) + assert old_state == state diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py index 407a156d6f..f1979d52bd 100644 --- a/tests/load/sink/test_sink.py +++ b/tests/load/sink/test_sink.py @@ -248,7 +248,9 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: load_id = p.list_normalized_load_packages()[0] load_package_state = p.get_load_package_state(load_id)["destinations"] + # should be only one destination: sink assert len(load_package_state) == 1 + assert "sink" in load_package_state # get saved indexes mapped to table (this test will only work for one job per table) values = {k.split(".")[0]: v for k, v in list(load_package_state.values())[0].items()} diff --git a/tests/utils.py b/tests/utils.py index f368957155..c86ae92a2b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -61,8 +61,6 @@ # filter out active destinations for current tests ACTIVE_DESTINATIONS = set(dlt.config.get("ACTIVE_DESTINATIONS", list) or IMPLEMENTED_DESTINATIONS) -ACTIVE_DESTINATIONS = {"sink"} - ACTIVE_SQL_DESTINATIONS = SQL_DESTINATIONS.intersection(ACTIVE_DESTINATIONS) ACTIVE_NON_SQL_DESTINATIONS = NON_SQL_DESTINATIONS.intersection(ACTIVE_DESTINATIONS) From 678d18734d4443f0f9ad627782b4a8ae21fd6c26 Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 6 Feb 2024 17:45:34 +0100 Subject: [PATCH 20/35] fix linter --- tests/common/test_versioned_state.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 
deletions(-) diff --git a/tests/common/test_versioned_state.py b/tests/common/test_versioned_state.py index 4a84a258a5..e1f31a8a92 100644 --- a/tests/common/test_versioned_state.py +++ b/tests/common/test_versioned_state.py @@ -10,34 +10,34 @@ def test_versioned_state() -> None: assert state["_state_version"] == 0 assert state["_state_engine_version"] == 1 - # first hash generation does not change version, attrs are not modified - version, hash, previous_hash = bump_state_version_if_modified(state) + # first hash_ generation does not change version, attrs are not modified + version, hash_, previous_hash = bump_state_version_if_modified(state) assert version == 0 - assert hash is not None + assert hash_ is not None assert previous_hash is None - assert state["_version_hash"] == hash + assert state["_version_hash"] == hash_ # change attr, but exclude while generating state["foo"] = "bar" # type: ignore - version, hash, previous_hash = bump_state_version_if_modified(state, exclude_attrs=["foo"]) + version, hash_, previous_hash = bump_state_version_if_modified(state, exclude_attrs=["foo"]) assert version == 0 - assert hash == previous_hash + assert hash_ == previous_hash - # now don't exclude (remember old hash to compare return vars) + # now don't exclude (remember old hash_ to compare return vars) old_hash = state["_version_hash"] - version, hash, previous_hash = bump_state_version_if_modified(state) + version, hash_, previous_hash = bump_state_version_if_modified(state) assert version == 1 - assert hash != previous_hash - assert old_hash != hash + assert hash_ != previous_hash + assert old_hash != hash_ assert previous_hash == old_hash - # messing with state engine version will not change hash + # messing with state engine version will not change hash_ state["_state_engine_version"] = 5 - version, hash, previous_hash = bump_state_version_if_modified(state) + version, hash_, previous_hash = bump_state_version_if_modified(state) assert version == 1 - assert hash == previous_hash + assert hash_ == previous_hash # make sure state object is not modified while bumping with no effect old_state = state.copy() - version, hash, previous_hash = bump_state_version_if_modified(state) + version, hash_, previous_hash = bump_state_version_if_modified(state) assert old_state == state From 376832dd8360d113d2485120fed3c4a642d248b1 Mon Sep 17 00:00:00 2001 From: Dave Date: Wed, 7 Feb 2024 09:21:31 +0100 Subject: [PATCH 21/35] fix linter error --- tests/common/storages/test_load_package.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/common/storages/test_load_package.py b/tests/common/storages/test_load_package.py index a8d333d431..651097e619 100644 --- a/tests/common/storages/test_load_package.py +++ b/tests/common/storages/test_load_package.py @@ -87,7 +87,7 @@ def test_create_and_update_loadpackage_state(load_storage: LoadStorage) -> None: # check timestamp time = pendulum.from_timestamp(state["created"]) now = pendulum.now() - (now - time).in_seconds() < 2 + assert (now - time).in_seconds() < 2 def test_loadpackage_state_injectable_context(load_storage: LoadStorage) -> None: From 79fce9ee2754eefa7af61bd1c3f2726786a6701a Mon Sep 17 00:00:00 2001 From: Dave Date: Wed, 7 Feb 2024 18:28:34 +0100 Subject: [PATCH 22/35] some pr fixes --- dlt/common/pipeline.py | 61 -------------------- dlt/common/storages/load_package.py | 66 +++++++++++++++++++++- dlt/destinations/impl/sink/sink.py | 12 ++-- dlt/load/load.py | 3 +- dlt/pipeline/current.py | 6 ++ 
tests/common/storages/test_load_package.py | 31 +++++----- tests/load/sink/test_sink.py | 23 ++++---- 7 files changed, 107 insertions(+), 95 deletions(-) diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 89bce6930b..85cb5613bf 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -3,7 +3,6 @@ import datetime # noqa: 251 import humanize import contextlib -import threading from typing import ( Any, @@ -602,28 +601,6 @@ class StateInjectableContext(ContainerInjectableContext): def __init__(self, state: TPipelineState = None) -> None: ... -@configspec -class LoadPackageStateInjectableContext(ContainerInjectableContext): - storage: PackageStorage - load_id: str - destination_name: Optional[str] - can_create_default: ClassVar[bool] = False - - def commit(self) -> None: - with self.state_save_lock: - self.storage.save_load_package_state(self.load_id, self.state) - - def on_resolved(self) -> None: - self.state_save_lock = threading.Lock() - self.state = self.storage.get_load_package_state(self.load_id) - - if TYPE_CHECKING: - - def __init__( - self, load_id: str, storage: PackageStorage, destination_name: Optional[str] - ) -> None: ... - - def pipeline_state( container: Container, initial_default: TPipelineState = None ) -> Tuple[TPipelineState, bool]: @@ -704,44 +681,6 @@ def source_state() -> DictStrAny: _last_full_state: TPipelineState = None -def load_package_state() -> TLoadPackageState: - """Get full load package state present in current context. Across all threads this will be the same in memory dict.""" - container = Container() - # get injected state if present. injected load package state is typically "managed" so changes will be persisted - # if you need to save the load package state during a load, you need to call commit_load_package_state - try: - state_ctx = container[LoadPackageStateInjectableContext] - except ContextDefaultCannotBeCreated: - raise Exception("Load package state not available") - return state_ctx.state - - -def commit_load_package_state() -> None: - """Commit load package state present in current context. This is thread safe.""" - container = Container() - try: - state_ctx = container[LoadPackageStateInjectableContext] - except ContextDefaultCannotBeCreated: - raise Exception("Load package state not available") - state_ctx.commit() - - -def load_package_destination_state() -> DictStrAny: - """Get segment of load package state that is specific to the current destination.""" - lp_state = load_package_state() - destination_name = Container()[LoadPackageStateInjectableContext].destination_name - return lp_state.setdefault("destinations", {}).setdefault(destination_name, {}) - - -def clear_loadpackage_destination_state(commit: bool = True) -> None: - """Clear segment of load package state that is specific to the current destination. 
Optionally commit to load package.""" - lp_state = load_package_state() - destination_name = Container()[LoadPackageStateInjectableContext].destination_name - lp_state["destinations"].pop(destination_name, None) - if commit: - commit_load_package_state() - - def _delete_source_state_keys( key: TAnyJsonPath, source_state_: Optional[DictStrAny] = None, / ) -> None: diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index 483011a2e7..d6da5c9eaf 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -1,6 +1,8 @@ import contextlib import os from copy import deepcopy +import threading + import datetime # noqa: 251 import humanize from pathlib import Path @@ -18,9 +20,16 @@ cast, Any, Tuple, + TYPE_CHECKING, ) from dlt.common import pendulum, json + +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import ContainerInjectableContext +from dlt.common.configuration.exceptions import ContextDefaultCannotBeCreated +from dlt.common.configuration.container import Container + from dlt.common.data_writers import DataWriter, new_file_id from dlt.common.destination import TLoaderFileFormat from dlt.common.exceptions import TerminalValueError @@ -44,7 +53,7 @@ class TLoadPackageState(TVersionedState, total=False): """Timestamp when the loadpackage was created""" """A section of state that does not participate in change merging and version control""" - destinations: NotRequired[Dict[str, Dict[str, Any]]] + destination_state: NotRequired[Dict[str, Any]] """private space for destinations to store state relevant only to the load package""" @@ -593,3 +602,58 @@ def is_package_partially_loaded(package_info: LoadPackageInfo) -> bool: @staticmethod def _job_elapsed_time_seconds(file_path: str, now_ts: float = None) -> float: return (now_ts or pendulum.now().timestamp()) - os.path.getmtime(file_path) + + +@configspec +class LoadPackageStateInjectableContext(ContainerInjectableContext): + storage: PackageStorage + load_id: str + can_create_default: ClassVar[bool] = False + + def commit(self) -> None: + with self.state_save_lock: + self.storage.save_load_package_state(self.load_id, self.state) + + def on_resolved(self) -> None: + self.state_save_lock = threading.Lock() + self.state = self.storage.get_load_package_state(self.load_id) + + if TYPE_CHECKING: + + def __init__(self, load_id: str, storage: PackageStorage) -> None: ... + + +def load_package_state() -> TLoadPackageState: + """Get full load package state present in current context. Across all threads this will be the same in memory dict.""" + container = Container() + # get injected state if present. injected load package state is typically "managed" so changes will be persisted + # if you need to save the load package state during a load, you need to call commit_load_package_state + try: + state_ctx = container[LoadPackageStateInjectableContext] + except ContextDefaultCannotBeCreated: + raise Exception("Load package state not available") + return state_ctx.state + + +def commit_load_package_state() -> None: + """Commit load package state present in current context. 
This is thread safe.""" + container = Container() + try: + state_ctx = container[LoadPackageStateInjectableContext] + except ContextDefaultCannotBeCreated: + raise Exception("Load package state not available") + state_ctx.commit() + + +def destination_state() -> DictStrAny: + """Get segment of load package state that is specific to the current destination.""" + lp_state = load_package_state() + return lp_state.setdefault("destination_state", {}) + + +def clear_destination_state(commit: bool = True) -> None: + """Clear segment of load package state that is specific to the current destination. Optionally commit to load package.""" + lp_state = load_package_state() + lp_state.pop("destination_state", None) + if commit: + commit_load_package_state() diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py index 31981e322c..3f577e0dca 100644 --- a/dlt/destinations/impl/sink/sink.py +++ b/dlt/destinations/impl/sink/sink.py @@ -5,10 +5,10 @@ from dlt.destinations.job_impl import EmptyLoadJob from dlt.common.typing import TDataItems from dlt.common import json -from dlt.common.pipeline import ( - load_package_destination_state, +from dlt.pipeline.current import ( + destination_state, commit_load_package_state, - clear_loadpackage_destination_state, + clear_destination_state, ) from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -149,7 +149,7 @@ def update_stored_schema( def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: # save our state in destination name scope - load_state = load_package_destination_state() + load_state = destination_state() if file_path.endswith("parquet"): return SinkParquetLoadJob(table, file_path, self.config, self.schema, load_state) if file_path.endswith("jsonl"): @@ -159,9 +159,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") - def complete_load(self, load_id: str) -> None: - # pop all state for this load on success - clear_loadpackage_destination_state() + def complete_load(self, load_id: str) -> None: ... 
def __enter__(self) -> "SinkClient": return self diff --git a/dlt/load/load.py b/dlt/load/load.py index 072b4f1d44..4666aa0851 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -14,10 +14,10 @@ LoadMetrics, SupportsPipeline, WithStepInfo, - LoadPackageStateInjectableContext, ) from dlt.common.schema.utils import get_child_tables, get_top_level_table from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState +from dlt.common.storages.load_package import LoadPackageStateInjectableContext from dlt.common.runners import TRunMetrics, Runnable, workermethod, NullExecutor from dlt.common.runtime.collector import Collector, NULL_COLLECTOR from dlt.common.runtime.logger import pretty_format_exception @@ -573,7 +573,6 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: LoadPackageStateInjectableContext( storage=self.load_storage.normalized_packages, load_id=load_id, - destination_name=self.initial_client_config.destination_name, ) ): # the same load id may be processed across multiple runs diff --git a/dlt/pipeline/current.py b/dlt/pipeline/current.py index f915a30932..601c8faee3 100644 --- a/dlt/pipeline/current.py +++ b/dlt/pipeline/current.py @@ -3,6 +3,12 @@ from dlt.common.pipeline import source_state as _state, resource_state from dlt.pipeline import pipeline as _pipeline from dlt.extract.decorators import get_source_schema +from dlt.common.storages.load_package import ( + load_package_state, + commit_load_package_state, + destination_state, + clear_destination_state, +) pipeline = _pipeline """Alias for dlt.pipeline""" diff --git a/tests/common/storages/test_load_package.py b/tests/common/storages/test_load_package.py index 651097e619..3baf3e2fe1 100644 --- a/tests/common/storages/test_load_package.py +++ b/tests/common/storages/test_load_package.py @@ -11,12 +11,12 @@ from tests.utils import autouse_test_storage from dlt.common.pendulum import pendulum from dlt.common.configuration.container import Container -from dlt.common.pipeline import ( +from dlt.common.storages.load_package import ( LoadPackageStateInjectableContext, - load_package_destination_state, + destination_state, load_package_state, commit_load_package_state, - clear_loadpackage_destination_state, + clear_destination_state, ) @@ -98,7 +98,6 @@ def test_loadpackage_state_injectable_context(load_storage: LoadStorage) -> None LoadPackageStateInjectableContext( storage=load_storage.new_packages, load_id="copy", - destination_name="some_destination_name", ) ): # test general load package state @@ -122,20 +121,26 @@ def test_loadpackage_state_injectable_context(load_storage: LoadStorage) -> None assert second_injected_instance == injected_state # check scoped destination states - assert load_storage.new_packages.get_load_package_state("copy").get("destinations") is None - destination_state = load_package_destination_state() - destination_state["new_key"] = "new_value" + assert ( + load_storage.new_packages.get_load_package_state("copy").get("destination_state") + is None + ) + dstate = destination_state() + dstate["new_key"] = "new_value" commit_load_package_state() - assert load_storage.new_packages.get_load_package_state("copy").get("destinations") == { - "some_destination_name": {"new_key": "new_value"} - } + assert load_storage.new_packages.get_load_package_state("copy").get( + "destination_state" + ) == {"new_key": "new_value"} # this also shows up on the previously injected state - assert injected_state["destinations"]["some_destination_name"]["new_key"] == "new_value" + assert 
injected_state["destination_state"]["new_key"] == "new_value" # clear destination state - clear_loadpackage_destination_state() - assert load_storage.new_packages.get_load_package_state("copy").get("destinations") == {} + clear_destination_state() + assert ( + load_storage.new_packages.get_load_package_state("copy").get("destination_state") + == None + ) def test_job_elapsed_time_seconds(load_storage: LoadStorage) -> None: diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py index f1979d52bd..080c3e3c48 100644 --- a/tests/load/sink/test_sink.py +++ b/tests/load/sink/test_sink.py @@ -234,8 +234,11 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: load_id = p.run([items(), items2()]).loads_ids[0] assert_items_in_range(calls["items"], 0, 100) assert_items_in_range(calls["items2"], 0, 100) - # destination state should be cleared after load - assert p.get_load_package_state(load_id)["destinations"] == {} + + # destination state should have all items + destination_state = p.get_load_package_state(load_id)["destination_state"] + values = {k.split(".")[0]: v for k, v in destination_state.items()} + assert values == {"_dlt_pipeline_state": 1, "items": 100, "items2": 100} # provoke errors calls = {} @@ -246,14 +249,10 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: # we should have data for one load id saved here load_id = p.list_normalized_load_packages()[0] - load_package_state = p.get_load_package_state(load_id)["destinations"] - - # should be only one destination: sink - assert len(load_package_state) == 1 - assert "sink" in load_package_state + destination_state = p.get_load_package_state(load_id)["destination_state"] # get saved indexes mapped to table (this test will only work for one job per table) - values = {k.split(".")[0]: v for k, v in list(load_package_state.values())[0].items()} + values = {k.split(".")[0]: v for k, v in destination_state.items()} # partly loaded, pointers in state should be right if batch_size == 1: @@ -277,9 +276,11 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None: provoke_error = {} calls = {} p.load() - # state should be cleared again - load_package_state = p.get_load_package_state(load_id)["destinations"] - assert load_package_state == {} + + # destination state should have all items + destination_state = p.get_load_package_state(load_id)["destination_state"] + values = {k.split(".")[0]: v for k, v in destination_state.items()} + assert values == {"_dlt_pipeline_state": 1, "items": 100, "items2": 100} # both calls combined should have every item called just once assert_items_in_range(calls["items"] + first_calls["items"], 0, 100) From 105569a53c9daf3d2430799193138446059f11f8 Mon Sep 17 00:00:00 2001 From: Dave Date: Thu, 8 Feb 2024 12:04:45 +0100 Subject: [PATCH 23/35] more pr fixes --- dlt/__init__.py | 4 +-- dlt/common/data_types/type_helpers.py | 2 +- dlt/common/destination/capabilities.py | 2 +- dlt/common/storages/load_package.py | 6 ++-- dlt/common/storages/load_storage.py | 7 +++- dlt/destinations/impl/sink/__init__.py | 3 +- dlt/destinations/impl/sink/configuration.py | 4 +-- dlt/destinations/impl/sink/factory.py | 2 +- dlt/destinations/impl/sink/sink.py | 30 +++++------------ dlt/pipeline/pipeline.py | 1 + .../docs/dlt-ecosystem/destinations/sink.md | 25 +++++++------- tests/common/storages/test_load_package.py | 8 ++--- tests/load/sink/test_sink.py | 33 ++++++++++--------- 13 files changed, 62 insertions(+), 65 deletions(-) diff --git 
a/dlt/__init__.py b/dlt/__init__.py index 6b567b3398..c40416ba73 100644 --- a/dlt/__init__.py +++ b/dlt/__init__.py @@ -29,7 +29,7 @@ from dlt import sources from dlt.extract.decorators import source, resource, transformer, defer -from dlt.destinations.decorators import sink +from dlt.destinations.decorators import sink as destination from dlt.pipeline import ( pipeline as _pipeline, @@ -64,7 +64,7 @@ "resource", "transformer", "defer", - "sink", + "destination", "pipeline", "run", "attach", diff --git a/dlt/common/data_types/type_helpers.py b/dlt/common/data_types/type_helpers.py index 800fa8a680..29a8084a28 100644 --- a/dlt/common/data_types/type_helpers.py +++ b/dlt/common/data_types/type_helpers.py @@ -111,7 +111,7 @@ def coerce_value(to_type: TDataType, from_type: TDataType, value: Any) -> Any: try: return json.loads(value) except Exception: - raise ValueError("Cannot load text as json for complex type") + raise ValueError(value) if to_type == "text": if from_type == "complex": diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index a78a31fdf3..0f2500c2cd 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -19,7 +19,7 @@ ] ALL_SUPPORTED_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) # file formats used internally by dlt -INTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = {"puae-jsonl", "sql", "reference", "arrow"} +INTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = {"sql", "reference", "arrow"} # file formats that may be chosen by the user EXTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = ( set(get_args(TLoaderFileFormat)) - INTERNAL_LOADER_FILE_FORMATS diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index d6da5c9eaf..118654445e 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -49,7 +49,7 @@ class TLoadPackageState(TVersionedState, total=False): - created: int + created_at: str """Timestamp when the loadpackage was created""" """A section of state that does not participate in change merging and version control""" @@ -403,8 +403,8 @@ def create_package(self, load_id: str) -> None: self.storage.create_folder(os.path.join(load_id, PackageStorage.STARTED_JOBS_FOLDER)) # ensure created timestamp is set in state when load package is created state = self.get_load_package_state(load_id) - if not state.get("created"): - state["created"] = pendulum.now().timestamp() + if not state.get("created_at"): + state["created_at"] = pendulum.now().to_iso8601_string() self.save_load_package_state(load_id, state) def complete_loading_package(self, load_id: str, load_state: TLoadPackageStatus) -> str: diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index 080a30b5a6..ffd55e7f29 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -40,6 +40,11 @@ def __init__( supported_file_formats: Iterable[TLoaderFileFormat], config: LoadStorageConfiguration = config.value, ) -> None: + # puae-jsonl jobs have the extension .jsonl, so cater for this here + if supported_file_formats and "puae-jsonl" in supported_file_formats: + supported_file_formats = list(supported_file_formats) + supported_file_formats.append("jsonl") + if not LoadStorage.ALL_SUPPORTED_FILE_FORMATS.issuperset(supported_file_formats): raise TerminalValueError(supported_file_formats) if preferred_file_format and preferred_file_format not in supported_file_formats: @@ 
-81,7 +86,7 @@ def _get_data_item_path_template(self, load_id: str, _: str, table_name: str) -> def list_new_jobs(self, load_id: str) -> Sequence[str]: """Lists all jobs in new jobs folder of normalized package storage and checks if file formats are supported""" new_jobs = self.normalized_packages.list_new_jobs(load_id) - # # make sure all jobs have supported writers + # make sure all jobs have supported writers wrong_job = next( ( j diff --git a/dlt/destinations/impl/sink/__init__.py b/dlt/destinations/impl/sink/__init__.py index 2902fb8b03..72e93622c2 100644 --- a/dlt/destinations/impl/sink/__init__.py +++ b/dlt/destinations/impl/sink/__init__.py @@ -3,9 +3,10 @@ def capabilities( - preferred_loader_file_format: TLoaderFileFormat = "parquet", + preferred_loader_file_format: TLoaderFileFormat = "puae-jsonl", ) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext.generic_capabilities(preferred_loader_file_format) + caps.supported_loader_file_formats = ["puae-jsonl", "parquet"] caps.supports_ddl_transactions = False caps.supports_transactions = False return caps diff --git a/dlt/destinations/impl/sink/configuration.py b/dlt/destinations/impl/sink/configuration.py index bb9caab294..9a96aea98d 100644 --- a/dlt/destinations/impl/sink/configuration.py +++ b/dlt/destinations/impl/sink/configuration.py @@ -62,7 +62,7 @@ def on_resolved(self) -> None: class SinkClientConfiguration(DestinationClientConfiguration): destination_type: Final[str] = "sink" # type: ignore credentials: SinkClientCredentials = None - loader_file_format: TLoaderFileFormat = "jsonl" + loader_file_format: TLoaderFileFormat = "puae-jsonl" batch_size: int = 10 if TYPE_CHECKING: @@ -71,6 +71,6 @@ def __init__( self, *, credentials: Union[SinkClientCredentials, TSinkCallable, str] = None, - loader_file_format: TLoaderFileFormat = "jsonl", + loader_file_format: TLoaderFileFormat = "puae-jsonl", batch_size: int = 10, ) -> None: ... 
diff --git a/dlt/destinations/impl/sink/factory.py b/dlt/destinations/impl/sink/factory.py index f51c3386ab..e65185cb8b 100644 --- a/dlt/destinations/impl/sink/factory.py +++ b/dlt/destinations/impl/sink/factory.py @@ -18,7 +18,7 @@ class sink(Destination[SinkClientConfiguration, "SinkClient"]): spec = SinkClientConfiguration def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities(self.config_params.get("loader_file_format", "parquet")) + return capabilities(self.config_params.get("loader_file_format", "puae-jsonl")) @property def client_class(self) -> t.Type["SinkClient"]: diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py index 3f577e0dca..2ebfefe516 100644 --- a/dlt/destinations/impl/sink/sink.py +++ b/dlt/destinations/impl/sink/sink.py @@ -8,7 +8,6 @@ from dlt.pipeline.current import ( destination_state, commit_load_package_state, - clear_destination_state, ) from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -32,7 +31,7 @@ def __init__( file_path: str, config: SinkClientConfiguration, schema: Schema, - load_package_state: Dict[str, int], + destination_state: Dict[str, int], ) -> None: super().__init__(FileStorage.get_file_name_from_file_path(file_path)) self._file_path = file_path @@ -43,11 +42,11 @@ def __init__( self._state: TLoadJobState = "running" self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}" try: - current_index = load_package_state.get(self._storage_id, 0) + current_index = destination_state.get(self._storage_id, 0) for batch in self.run(current_index): self.call_callable_with_items(batch) current_index += len(batch) - load_package_state[self._storage_id] = current_index + destination_state[self._storage_id] = current_index self._state = "completed" except Exception as e: @@ -64,20 +63,8 @@ def run(self, start_index: int) -> Iterable[TDataItems]: def call_callable_with_items(self, items: TDataItems) -> None: if not items: return - - # coerce items into correct format specified by schema - coerced_items: TDataItems = [] - for item in items: - coerced_item, table_update = self._schema.coerce_row(self._table["name"], None, item) - assert not table_update - coerced_items.append(coerced_item) - - # send single item on batch size 1 - if self._config.batch_size == 1: - coerced_items = coerced_items[0] - # call callable - self._config.credentials.resolved_callable(coerced_items, self._table) + self._config.credentials.resolved_callable(items, self._table) def state(self) -> TLoadJobState: return self._state @@ -102,8 +89,7 @@ def run(self, start_index: int) -> Iterable[TDataItems]: if start_batch > 0: start_batch -= 1 continue - batch = record_batch.to_pylist() - yield batch + yield record_batch class SinkJsonlLoadJob(SinkLoadJob): @@ -112,12 +98,14 @@ def run(self, start_index: int) -> Iterable[TDataItems]: # stream items with FileStorage.open_zipsafe_ro(self._file_path) as f: - for line in f: + encoded_json = json.typed_loads(f.read()) + + for item in encoded_json: # find correct start position if start_index > 0: start_index -= 1 continue - current_batch.append(json.loads(line)) + current_batch.append(item) if len(current_batch) == self._config.batch_size: yield current_batch current_batch = [] diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 53b043dcc2..a6516ea0b3 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -427,6 +427,7 @@ def normalize( """Normalizes the data prepared with `extract` method, infers the schema and 
creates load packages for the `load` method. Requires `destination` to be known.""" if is_interactive(): workers = 1 + if loader_file_format and loader_file_format in INTERNAL_LOADER_FILE_FORMATS: raise ValueError(f"{loader_file_format} is one of internal dlt file formats.") # check if any schema is present, if not then no data was extracted diff --git a/docs/website/docs/dlt-ecosystem/destinations/sink.md b/docs/website/docs/dlt-ecosystem/destinations/sink.md index 161132b902..87e2abb502 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/sink.md +++ b/docs/website/docs/dlt-ecosystem/destinations/sink.md @@ -1,7 +1,7 @@ --- -title: Sink / Reverse ETL +title: Destination Decorator / Reverse ETL description: Sink function `dlt` destination for reverse ETL -keywords: [reverse etl, sink, function] +keywords: [reverse etl, sink, function, decorator, destination] --- # Sink function / Reverse ETL @@ -18,28 +18,28 @@ pip install dlt Let's start by initializing a new dlt project as follows: ```bash -dlt init chess sink +dlt init chess decorator ``` -> 💡 This command will initialize your pipeline with chess as the source and Sink as the destination. +> 💡 This command will initialize your pipeline with chess as the source and decorator as the destination. The above command generates several files and directories, including `.dlt/secrets.toml`. -### 2. Set up a sink function for your pipeline -The sink destination differs from other destinations in that you do not need to provide connection credentials, but rather you provide a function which +### 2. Set up a destination function for your pipeline +The destination decorator differs from other destinations in that you do not need to provide connection credentials, but rather you provide a function which gets called for all items loaded during a pipeline run or load operation. For the chess example, you can add the following lines at the top of the file. -With the @dlt.sink decorator you can convert any function that takes two arguments into a dlt destination. +With the @dlt.destination decorator you can convert any function that takes two arguments into a dlt destination. ```python from dlt.common.typing import TDataItems from dlt.common.schema import TTableSchema -@dlt.sink(batch_size=10) +@dlt.destination(batch_size=10) def sink(items: TDataItems, table: TTableSchema) -> None: print(table["name"]) print(items) ``` -To enable this sink destination in your chess example, replace the line `destination='sink'` with `destination=sink` (without the quotes) to directly reference +To enable this destination decorator in your chess example, replace the line `destination='sink'` with `destination=sink` (without the quotes) to directly reference the sink from your pipeline constructor. Now you can run your pipeline and see the output of all the items coming from the chess pipeline to your console. :::tip @@ -52,14 +52,13 @@ the sink from your pipeline constructor. Now you can run your pipeline and see t The full signature of the sink decorator and a function is ```python -@dlt.sink(batch_size=10, loader_file_format="jsonl", name="my_sink") +@dlt.destination(batch_size=10, loader_file_format="jsonl", name="my_sink") def sink(items: TDataItems, table: TTableSchema) -> None: ... ``` #### Decorator -* The `batch_size` parameter on the sink decorator defines how many items per function call are batched together and sent as an array. If batch_size is set to one, -there will be one item without an array per call. 
+* The `batch_size` parameter on the sink decorator defines how many items per function call are batched together and sent as an array. * The `loader_file_format` parameter on the sink decorator defines in which format files are stored in the load package before being sent to the sink function, this can be `jsonl` or `parquet`. * The `name` parameter on the sink decorator defines the name of the destination that get's created by the sink decorator. @@ -85,7 +84,7 @@ There are multiple ways to reference the sink function you want to use. These ar ```python # file my_pipeline.py -@dlt.sink(batch_size=10) +@dlt.destination(batch_size=10) def local_sink_func(items: TDataItems, table: TTableSchema) -> None: ... diff --git a/tests/common/storages/test_load_package.py b/tests/common/storages/test_load_package.py index 3baf3e2fe1..a33613923e 100644 --- a/tests/common/storages/test_load_package.py +++ b/tests/common/storages/test_load_package.py @@ -71,7 +71,7 @@ def test_create_and_update_loadpackage_state(load_storage: LoadStorage) -> None: state = load_storage.new_packages.get_load_package_state("copy") assert state["_state_version"] == 0 assert state["_version_hash"] is not None - assert state["created"] is not None + assert state["created_at"] is not None old_state = state.copy() state["new_key"] = "new_value" # type: ignore @@ -82,10 +82,10 @@ def test_create_and_update_loadpackage_state(load_storage: LoadStorage) -> None: assert state["_state_version"] == 1 assert state["_version_hash"] != old_state["_version_hash"] # created timestamp should be conserved - assert state["created"] == old_state["created"] + assert state["created_at"] == old_state["created_at"] # check timestamp - time = pendulum.from_timestamp(state["created"]) + time = pendulum.parse(state["created_at"]) now = pendulum.now() assert (now - time).in_seconds() < 2 @@ -139,7 +139,7 @@ def test_loadpackage_state_injectable_context(load_storage: LoadStorage) -> None clear_destination_state() assert ( load_storage.new_packages.get_load_package_state("copy").get("destination_state") - == None + is None ) diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py index 080c3e3c48..c1ae8153c2 100644 --- a/tests/load/sink/test_sink.py +++ b/tests/load/sink/test_sink.py @@ -19,7 +19,7 @@ delete_dataset, ) -SUPPORTED_LOADER_FORMATS = ["parquet", "jsonl"] +SUPPORTED_LOADER_FORMATS = ["parquet", "puae-jsonl"] def _run_through_sink( @@ -34,11 +34,14 @@ def _run_through_sink( """ calls: List[Tuple[TDataItems, TTableSchema]] = [] - @dlt.sink(loader_file_format=loader_file_format, batch_size=batch_size) + @dlt.destination(loader_file_format=loader_file_format, batch_size=batch_size) def test_sink(items: TDataItems, table: TTableSchema) -> None: nonlocal calls if table["name"].startswith("_dlt") and filter_dlt_tables: return + # convert pyarrow table to dict list here to make tests more simple downstream + if loader_file_format == "parquet": + items = items.to_pylist() # type: ignore calls.append((items, table)) @dlt.resource(columns=columns, table_name="items") @@ -67,16 +70,14 @@ def test_all_datatypes(loader_file_format: TLoaderFileFormat) -> None: # inspect result assert len(sink_calls) == 3 - item = sink_calls[0][0] + item = sink_calls[0][0][0] + # filter out _dlt columns - item = {k: v for k, v in item.items() if not k.startswith("_dlt")} # type: ignore + item = {k: v for k, v in item.items() if not k.startswith("_dlt")} # null values are not emitted data_types = {k: v for k, v in data_types.items() if v is not None} - # 
check keys are the same - assert set(item.keys()) == set(data_types.keys()) - assert_all_data_types_row(item, expect_filtered_null_columns=True) @@ -90,7 +91,7 @@ def test_batch_size(loader_file_format: TLoaderFileFormat, batch_size: int) -> N if batch_size == 1: assert len(sink_calls) == 100 # one item per call - assert sink_calls[0][0].items() > {"id": 0, "value": "0"}.items() # type: ignore + assert sink_calls[0][0][0].items() > {"id": 0, "value": "0"}.items() elif batch_size == 10: assert len(sink_calls) == 10 # ten items in first call @@ -106,8 +107,6 @@ def test_batch_size(loader_file_format: TLoaderFileFormat, batch_size: int) -> N all_items = set() for call in sink_calls: item = call[0] - if batch_size == 1: - item = [item] for entry in item: all_items.add(entry["value"]) @@ -137,7 +136,7 @@ def local_sink_func(items: TDataItems, table: TTableSchema) -> None: # test decorator calls = [] - p = dlt.pipeline("sink_test", destination=dlt.sink()(local_sink_func), full_refresh=True) + p = dlt.pipeline("sink_test", destination=dlt.destination()(local_sink_func), full_refresh=True) p.run([1, 2, 3], table_name="items") assert len(calls) == 1 @@ -184,27 +183,31 @@ def local_sink_func(items: TDataItems, table: TTableSchema) -> None: p.run([1, 2, 3], table_name="items") -@pytest.mark.parametrize("loader_file_format", ["jsonl"]) +@pytest.mark.parametrize("loader_file_format", SUPPORTED_LOADER_FORMATS) @pytest.mark.parametrize("batch_size", [1, 10, 23]) def test_batched_transactions(loader_file_format: TLoaderFileFormat, batch_size: int) -> None: calls: Dict[str, List[TDataItems]] = {} # provoke errors on resources provoke_error: Dict[str, int] = {} - @dlt.sink(loader_file_format=loader_file_format, batch_size=batch_size) + @dlt.destination(loader_file_format=loader_file_format, batch_size=batch_size) def test_sink(items: TDataItems, table: TTableSchema) -> None: nonlocal calls table_name = table["name"] if table_name.startswith("_dlt"): return + # convert pyarrow table to dict list here to make tests more simple downstream + if loader_file_format == "parquet": + items = items.to_pylist() # type: ignore + # provoke error if configured if table_name in provoke_error: - for item in items if batch_size > 1 else [items]: + for item in items: if provoke_error[table_name] == item["id"]: raise AssertionError("Oh no!") - calls.setdefault(table_name, []).append(items if batch_size > 1 else [items]) + calls.setdefault(table_name, []).append(items) @dlt.resource() def items() -> TDataItems: From 27b8b2cf6dd031e2a9bd5b9d3157fd7b0c4cd223 Mon Sep 17 00:00:00 2001 From: Dave Date: Thu, 8 Feb 2024 13:13:45 +0100 Subject: [PATCH 24/35] small readme changes --- docs/website/docs/dlt-ecosystem/destinations/sink.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/website/docs/dlt-ecosystem/destinations/sink.md b/docs/website/docs/dlt-ecosystem/destinations/sink.md index 87e2abb502..ccd4a1a59f 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/sink.md +++ b/docs/website/docs/dlt-ecosystem/destinations/sink.md @@ -18,14 +18,14 @@ pip install dlt Let's start by initializing a new dlt project as follows: ```bash -dlt init chess decorator +dlt init chess sink ``` -> 💡 This command will initialize your pipeline with chess as the source and decorator as the destination. +> 💡 This command will initialize your pipeline with chess as the source and sink as the destination. The above command generates several files and directories, including `.dlt/secrets.toml`. ### 2. 
Set up a destination function for your pipeline -The destination decorator differs from other destinations in that you do not need to provide connection credentials, but rather you provide a function which +The sink destination differs from other destinations in that you do not need to provide connection credentials, but rather you provide a function which gets called for all items loaded during a pipeline run or load operation. For the chess example, you can add the following lines at the top of the file. With the @dlt.destination decorator you can convert any function that takes two arguments into a dlt destination. From 3229745c872c503bc91707b002a3f9a74583402d Mon Sep 17 00:00:00 2001 From: Dave Date: Mon, 4 Mar 2024 17:35:31 +0100 Subject: [PATCH 25/35] add load id to loadpackage info in current --- dlt/common/storages/load_package.py | 20 ++++++++++++++------ dlt/destinations/impl/sink/__init__.py | 2 ++ dlt/extract/incremental/__init__.py | 1 - dlt/pipeline/current.py | 2 +- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index d035ddd550..f946f33113 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -22,6 +22,7 @@ Any, Tuple, TYPE_CHECKING, + TypedDict, ) from dlt.common import pendulum, json @@ -58,6 +59,13 @@ class TLoadPackageState(TVersionedState, total=False): """private space for destinations to store state relevant only to the load package""" +class TLoadPackage(TypedDict, total=False): + load_id: str + """Load id""" + state: TLoadPackageState + """State of the load package""" + + # allows to upgrade state when restored with a new version of state logic/schema LOADPACKAGE_STATE_ENGINE_VERSION = 1 @@ -632,7 +640,7 @@ def on_resolved(self) -> None: def __init__(self, load_id: str, storage: PackageStorage) -> None: ... -def load_package_state() -> TLoadPackageState: +def load_package() -> TLoadPackage: """Get full load package state present in current context. Across all threads this will be the same in memory dict.""" container = Container() # get injected state if present. injected load package state is typically "managed" so changes will be persisted @@ -641,7 +649,7 @@ def load_package_state() -> TLoadPackageState: state_ctx = container[LoadPackageStateInjectableContext] except ContextDefaultCannotBeCreated: raise Exception("Load package state not available") - return state_ctx.state + return TLoadPackage(state=state_ctx.state, load_id=state_ctx.load_id) def commit_load_package_state() -> None: @@ -656,13 +664,13 @@ def commit_load_package_state() -> None: def destination_state() -> DictStrAny: """Get segment of load package state that is specific to the current destination.""" - lp_state = load_package_state() - return lp_state.setdefault("destination_state", {}) + lp = load_package() + return lp["state"].setdefault("destination_state", {}) def clear_destination_state(commit: bool = True) -> None: """Clear segment of load package state that is specific to the current destination. 
Optionally commit to load package.""" - lp_state = load_package_state() - lp_state.pop("destination_state", None) + lp = load_package() + lp["state"].pop("destination_state", None) if commit: commit_load_package_state() diff --git a/dlt/destinations/impl/sink/__init__.py b/dlt/destinations/impl/sink/__init__.py index 72e93622c2..fbad2d570f 100644 --- a/dlt/destinations/impl/sink/__init__.py +++ b/dlt/destinations/impl/sink/__init__.py @@ -4,9 +4,11 @@ def capabilities( preferred_loader_file_format: TLoaderFileFormat = "puae-jsonl", + naming_convention: str = "direct", ) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext.generic_capabilities(preferred_loader_file_format) caps.supported_loader_file_formats = ["puae-jsonl", "parquet"] caps.supports_ddl_transactions = False caps.supports_transactions = False + caps.naming_convention = naming_convention return caps diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py index d1a5a05c34..24495ccb19 100644 --- a/dlt/extract/incremental/__init__.py +++ b/dlt/extract/incremental/__init__.py @@ -7,7 +7,6 @@ from functools import wraps - import dlt from dlt.common.exceptions import MissingDependencyException from dlt.common import pendulum, logger diff --git a/dlt/pipeline/current.py b/dlt/pipeline/current.py index f00f7c1ec4..25fd398623 100644 --- a/dlt/pipeline/current.py +++ b/dlt/pipeline/current.py @@ -4,7 +4,7 @@ from dlt.pipeline import pipeline as _pipeline from dlt.extract.decorators import get_source_schema from dlt.common.storages.load_package import ( - load_package_state, + load_package, commit_load_package_state, destination_state, clear_destination_state, From dbbbe7c9f9432955bfaec8e00dc8f1831521a424 Mon Sep 17 00:00:00 2001 From: Dave Date: Mon, 4 Mar 2024 17:46:15 +0100 Subject: [PATCH 26/35] add support for directly passing through the naming convention to the sink --- dlt/destinations/decorators.py | 11 ++++++++-- dlt/destinations/impl/sink/factory.py | 7 ++++++- tests/load/sink/test_sink.py | 30 +++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/decorators.py b/dlt/destinations/decorators.py index a21e8eaca8..cbeeff8975 100644 --- a/dlt/destinations/decorators.py +++ b/dlt/destinations/decorators.py @@ -7,14 +7,21 @@ def sink( - loader_file_format: TLoaderFileFormat = None, batch_size: int = 10, name: str = None + loader_file_format: TLoaderFileFormat = None, + batch_size: int = 10, + name: str = None, + naming_convention: str = "direct", ) -> Any: def decorator(f: TSinkCallable) -> TDestinationReferenceArg: nonlocal name if name is None: name = get_callable_name(f) return _sink( - credentials=f, loader_file_format=loader_file_format, batch_size=batch_size, name=name + credentials=f, + loader_file_format=loader_file_format, + batch_size=batch_size, + name=name, + naming_convention=naming_convention, ) return decorator diff --git a/dlt/destinations/impl/sink/factory.py b/dlt/destinations/impl/sink/factory.py index e65185cb8b..6b2e98271e 100644 --- a/dlt/destinations/impl/sink/factory.py +++ b/dlt/destinations/impl/sink/factory.py @@ -18,7 +18,10 @@ class sink(Destination[SinkClientConfiguration, "SinkClient"]): spec = SinkClientConfiguration def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities(self.config_params.get("loader_file_format", "puae-jsonl")) + return capabilities( + self.config_params.get("loader_file_format", "puae-jsonl"), + self.config_params.get("naming_convention", "direct"), + ) 
     @property
     def client_class(self) -> t.Type["SinkClient"]:
@@ -33,6 +36,7 @@ def __init__(
         environment: t.Optional[str] = None,
         loader_file_format: TLoaderFileFormat = None,
         batch_size: int = 10,
+        naming_convention: str = "direct",
         **kwargs: t.Any,
     ) -> None:
         super().__init__(
@@ -41,5 +45,6 @@ def __init__(
             environment=environment,
             loader_file_format=loader_file_format,
             batch_size=batch_size,
+            naming_convention=naming_convention,
             **kwargs,
         )
diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py
index c1ae8153c2..72dcfd5b1e 100644
--- a/tests/load/sink/test_sink.py
+++ b/tests/load/sink/test_sink.py
@@ -288,3 +288,33 @@ def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None:
     # both calls combined should have every item called just once
     assert_items_in_range(calls["items"] + first_calls["items"], 0, 100)
     assert_items_in_range(calls["items2"] + first_calls["items2"], 0, 100)
+
+
+def test_naming_convention() -> None:
+    @dlt.resource(table_name="PErson")
+    def resource():
+        yield [{"UpperCase": 1, "snake_case": 1, "camelCase": 1}]
+
+    # check snake case
+    @dlt.destination(naming_convention="snake_case")
+    def snake_sink(items, table):
+        if table["name"].startswith("_dlt"):
+            return
+        assert table["name"] == "p_erson"
+        assert table["columns"]["upper_case"]["name"] == "upper_case"
+        assert table["columns"]["snake_case"]["name"] == "snake_case"
+        assert table["columns"]["camel_case"]["name"] == "camel_case"
+
+    dlt.pipeline("sink_test", destination=snake_sink, full_refresh=True).run(resource())
+
+    # check default (which is direct)
+    @dlt.destination()
+    def direct_sink(items, table):
+        if table["name"].startswith("_dlt"):
+            return
+        assert table["name"] == "PErson"
+        assert table["columns"]["UpperCase"]["name"] == "UpperCase"
+        assert table["columns"]["snake_case"]["name"] == "snake_case"
+        assert table["columns"]["camelCase"]["name"] == "camelCase"
+
+    dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run(resource())

From db9b48889989cf9cf04a103f5dcb3b997c40f236 Mon Sep 17 00:00:00 2001
From: Dave
Date: Mon, 4 Mar 2024 18:02:23 +0100
Subject: [PATCH 27/35] add support for batch size zero (filepath passthrough)

---
 dlt/destinations/impl/sink/configuration.py | 2 +-
 dlt/destinations/impl/sink/sink.py | 14 +++++++----
 tests/load/sink/test_sink.py | 27 +++++++++++++++++++++
 3 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/dlt/destinations/impl/sink/configuration.py b/dlt/destinations/impl/sink/configuration.py
index 9a96aea98d..8d9289ff8b 100644
--- a/dlt/destinations/impl/sink/configuration.py
+++ b/dlt/destinations/impl/sink/configuration.py
@@ -12,7 +12,7 @@
 from dlt.common.configuration.exceptions import ConfigurationValueError
 
 
-TSinkCallable = Callable[[TDataItems, TTableSchema], None]
+TSinkCallable = Callable[[Union[TDataItems, str], TTableSchema], None]
 
 
 @configspec
diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py
index 2ebfefe516..816eece079 100644
--- a/dlt/destinations/impl/sink/sink.py
+++ b/dlt/destinations/impl/sink/sink.py
@@ -42,11 +42,15 @@ def __init__(
         self._state: TLoadJobState = "running"
         self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}"
         try:
-            current_index = destination_state.get(self._storage_id, 0)
-            for batch in self.run(current_index):
-                self.call_callable_with_items(batch)
-                current_index += len(batch)
-            destination_state[self._storage_id] = current_index
+            if self._config.batch_size == 0:
+                # on batch size zero we only call the callable with the filename
+                self.call_callable_with_items(self._file_path)
+            else:
+                current_index = destination_state.get(self._storage_id, 0)
+                for batch in self.run(current_index):
+                    self.call_callable_with_items(batch)
+                    current_index += len(batch)
+                destination_state[self._storage_id] = current_index
 
             self._state = "completed"
         except Exception as e:
diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py
index 72dcfd5b1e..f5cf318ee8 100644
--- a/tests/load/sink/test_sink.py
+++ b/tests/load/sink/test_sink.py
@@ -318,3 +318,30 @@ def direct_sink(items, table):
         assert table["columns"]["camelCase"]["name"] == "camelCase"
 
     dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run(resource())
+
+
+def test_file_batch() -> None:
+    @dlt.resource(table_name="person")
+    def resource1():
+        for i in range(100):
+            yield [{"id": i, "name": f"Name {i}"}]
+
+    @dlt.resource(table_name="address")
+    def resource2():
+        for i in range(50):
+            yield [{"id": i, "city": f"City {i}"}]
+
+    @dlt.destination(batch_size=0, loader_file_format="parquet")
+    def direct_sink(file_path, table):
+        if table["name"].startswith("_dlt"):
+            return
+        from dlt.common.libs.pyarrow import pyarrow
+
+        assert table["name"] in ["person", "address"]
+
+        with pyarrow.parquet.ParquetFile(file_path) as reader:
+            assert reader.metadata.num_rows == (100 if table["name"] == "person" else 50)
+
+    dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run(
+        [resource1(), resource2()]
+    )

From 3c39f418fa89c90a2ce5bd3353efa50b31597ada Mon Sep 17 00:00:00 2001
From: Dave
Date: Tue, 5 Mar 2024 16:20:31 +0100
Subject: [PATCH 28/35] use patched version of flake8 encoding

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 0eb33d885d..5da0b17400 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -134,7 +134,7 @@ types-simplejson = ">=3.17.0"
 types-requests = ">=2.25.6"
 types-python-dateutil = ">=2.8.15"
 flake8-tidy-imports = ">=4.8.0"
-flake8-encodings = "^0.5.0"
+flake8-encodings = { git = "git@github.com:dlt-hub/flake8-encodings.git", branch = "disable_jedi_support" }
 flake8-builtins = "^1.5.3"
 boto3-stubs = "^1.28.28"
 types-tqdm = "^4.66.0.2"

From 3dfcf39a4213f4c6398b57763a75ca23991cc9de Mon Sep 17 00:00:00 2001
From: Dave
Date: Tue, 5 Mar 2024 16:24:54 +0100
Subject: [PATCH 29/35] fix tests

---
 docs/examples/connector_x_arrow/load_arrow.py | 2 ++
 docs/examples/google_sheets/google_sheets.py | 5 ++++-
 docs/examples/incremental_loading/zendesk.py | 8 ++++----
 docs/examples/nested_data/nested_data.py | 2 ++
 docs/examples/pdf_to_weaviate/pdf_to_weaviate.py | 5 ++++-
 docs/examples/qdrant_zendesk/qdrant.py | 9 +++++----
 docs/examples/transformers/pokemon.py | 4 +++-
 tests/common/storages/test_load_package.py | 14 +++++++-------
 8 files changed, 31 insertions(+), 18 deletions(-)

diff --git a/docs/examples/connector_x_arrow/load_arrow.py b/docs/examples/connector_x_arrow/load_arrow.py
index 06ca4e17b3..b3c654cef9 100644
--- a/docs/examples/connector_x_arrow/load_arrow.py
+++ b/docs/examples/connector_x_arrow/load_arrow.py
@@ -3,6 +3,7 @@
 import dlt
 from dlt.sources.credentials import ConnectionStringCredentials
 
+
 def read_sql_x(
     conn_str: ConnectionStringCredentials = dlt.secrets.value,
     query: str = dlt.config.value,
@@ -14,6 +15,7 @@ def read_sql_x(
         protocol="binary",
     )
 
+
 def genome_resource():
     # create genome resource with merge on `upid` primary key
     genome = dlt.resource(
diff --git
a/docs/examples/google_sheets/google_sheets.py b/docs/examples/google_sheets/google_sheets.py index 8a93df9970..1ba330e4ca 100644 --- a/docs/examples/google_sheets/google_sheets.py +++ b/docs/examples/google_sheets/google_sheets.py @@ -9,6 +9,7 @@ ) from dlt.common.typing import DictStrAny, StrAny + def _initialize_sheets( credentials: Union[GcpOAuthCredentials, GcpServiceAccountCredentials] ) -> Any: @@ -16,6 +17,7 @@ def _initialize_sheets( service = build("sheets", "v4", credentials=credentials.to_native_credentials()) return service + @dlt.source def google_spreadsheet( spreadsheet_id: str, @@ -55,6 +57,7 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]: for name in sheet_names ] + if __name__ == "__main__": pipeline = dlt.pipeline(destination="duckdb") # see example.secrets.toml to where to put credentials @@ -67,4 +70,4 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]: sheet_names=range_names, ) ) - print(info) \ No newline at end of file + print(info) diff --git a/docs/examples/incremental_loading/zendesk.py b/docs/examples/incremental_loading/zendesk.py index 4b8597886a..6113f98793 100644 --- a/docs/examples/incremental_loading/zendesk.py +++ b/docs/examples/incremental_loading/zendesk.py @@ -6,12 +6,11 @@ from dlt.common.typing import TAnyDateTime from dlt.sources.helpers.requests import client + @dlt.source(max_table_nesting=2) def zendesk_support( credentials: Dict[str, str] = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = pendulum.datetime( # noqa: B008 - year=2000, month=1, day=1 - ), + start_date: Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008 end_date: Optional[TAnyDateTime] = None, ): """ @@ -113,6 +112,7 @@ def get_pages( if not response_json["end_of_stream"]: get_url = response_json["next_page"] + if __name__ == "__main__": # create dlt pipeline pipeline = dlt.pipeline( @@ -120,4 +120,4 @@ def get_pages( ) load_info = pipeline.run(zendesk_support()) - print(load_info) \ No newline at end of file + print(load_info) diff --git a/docs/examples/nested_data/nested_data.py b/docs/examples/nested_data/nested_data.py index 3464448de6..7f85f0522e 100644 --- a/docs/examples/nested_data/nested_data.py +++ b/docs/examples/nested_data/nested_data.py @@ -13,6 +13,7 @@ CHUNK_SIZE = 10000 + # You can limit how deep dlt goes when generating child tables. # By default, the library will descend and generate child tables # for all nested lists, without a limit. 
@@ -81,6 +82,7 @@ def load_documents(self) -> Iterator[TDataItem]: while docs_slice := list(islice(cursor, CHUNK_SIZE)): yield map_nested_in_place(convert_mongo_objs, docs_slice) + def convert_mongo_objs(value: Any) -> Any: if isinstance(value, (ObjectId, Decimal128)): return str(value) diff --git a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py index 8f7833e7d7..e7f57853ed 100644 --- a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py +++ b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py @@ -4,6 +4,7 @@ from dlt.destinations.impl.weaviate import weaviate_adapter from PyPDF2 import PdfReader + @dlt.resource(selected=False) def list_files(folder_path: str): folder_path = os.path.abspath(folder_path) @@ -15,6 +16,7 @@ def list_files(folder_path: str): "mtime": os.path.getmtime(file_path), } + @dlt.transformer(primary_key="page_id", write_disposition="merge") def pdf_to_text(file_item, separate_pages: bool = False): if not separate_pages: @@ -28,6 +30,7 @@ def pdf_to_text(file_item, separate_pages: bool = False): page_item["page_id"] = file_item["file_name"] + "_" + str(page_no) yield page_item + pipeline = dlt.pipeline(pipeline_name="pdf_to_text", destination="weaviate") # this constructs a simple pipeline that: (1) reads files from "invoices" folder (2) filters only those ending with ".pdf" @@ -51,4 +54,4 @@ def pdf_to_text(file_item, separate_pages: bool = False): client = weaviate.Client("http://localhost:8080") # get text of all the invoices in InvoiceText class we just created above -print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do()) \ No newline at end of file +print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do()) diff --git a/docs/examples/qdrant_zendesk/qdrant.py b/docs/examples/qdrant_zendesk/qdrant.py index 300d8dc6ad..bd0cbafc99 100644 --- a/docs/examples/qdrant_zendesk/qdrant.py +++ b/docs/examples/qdrant_zendesk/qdrant.py @@ -10,13 +10,12 @@ from dlt.common.configuration.inject import with_config + # function from: https://github.com/dlt-hub/verified-sources/tree/master/sources/zendesk @dlt.source(max_table_nesting=2) def zendesk_support( credentials: Dict[str, str] = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = pendulum.datetime( # noqa: B008 - year=2000, month=1, day=1 - ), + start_date: Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008 end_date: Optional[TAnyDateTime] = None, ): """ @@ -80,6 +79,7 @@ def _parse_date_or_none(value: Optional[str]) -> Optional[pendulum.DateTime]: return None return ensure_pendulum_datetime(value) + # modify dates to return datetime objects instead def _fix_date(ticket): ticket["updated_at"] = _parse_date_or_none(ticket["updated_at"]) @@ -87,6 +87,7 @@ def _fix_date(ticket): ticket["due_at"] = _parse_date_or_none(ticket["due_at"]) return ticket + # function from: https://github.com/dlt-hub/verified-sources/tree/master/sources/zendesk def get_pages( url: str, @@ -127,6 +128,7 @@ def get_pages( if not response_json["end_of_stream"]: get_url = response_json["next_page"] + if __name__ == "__main__": # create a pipeline with an appropriate name pipeline = dlt.pipeline( @@ -146,7 +148,6 @@ def get_pages( print(load_info) - # running the Qdrant client to connect to your Qdrant database @with_config(sections=("destination", "qdrant", "credentials")) diff --git a/docs/examples/transformers/pokemon.py b/docs/examples/transformers/pokemon.py index c17beff6a8..97b9a98b11 100644 --- 
a/docs/examples/transformers/pokemon.py +++ b/docs/examples/transformers/pokemon.py @@ -1,6 +1,7 @@ import dlt from dlt.sources.helpers import requests + @dlt.source(max_table_nesting=2) def source(pokemon_api_url: str): """""" @@ -46,6 +47,7 @@ def species(pokemon_details): return (pokemon_list | pokemon, pokemon_list | pokemon | species) + if __name__ == "__main__": # build duck db pipeline pipeline = dlt.pipeline( @@ -54,4 +56,4 @@ def species(pokemon_details): # the pokemon_list resource does not need to be loaded load_info = pipeline.run(source("https://pokeapi.co/api/v2/pokemon")) - print(load_info) \ No newline at end of file + print(load_info) diff --git a/tests/common/storages/test_load_package.py b/tests/common/storages/test_load_package.py index a33613923e..8d47c22615 100644 --- a/tests/common/storages/test_load_package.py +++ b/tests/common/storages/test_load_package.py @@ -14,7 +14,7 @@ from dlt.common.storages.load_package import ( LoadPackageStateInjectableContext, destination_state, - load_package_state, + load_package, commit_load_package_state, clear_destination_state, ) @@ -87,7 +87,7 @@ def test_create_and_update_loadpackage_state(load_storage: LoadStorage) -> None: # check timestamp time = pendulum.parse(state["created_at"]) now = pendulum.now() - assert (now - time).in_seconds() < 2 + assert (now - time).in_seconds() < 2 # type: ignore def test_loadpackage_state_injectable_context(load_storage: LoadStorage) -> None: @@ -101,9 +101,9 @@ def test_loadpackage_state_injectable_context(load_storage: LoadStorage) -> None ) ): # test general load package state - injected_state = load_package_state() - assert injected_state["_state_version"] == 0 - injected_state["new_key"] = "new_value" # type: ignore + injected_state = load_package() + assert injected_state["state"]["_state_version"] == 0 + injected_state["state"]["new_key"] = "new_value" # type: ignore # not persisted yet assert load_storage.new_packages.get_load_package_state("copy").get("new_key") is None @@ -117,7 +117,7 @@ def test_loadpackage_state_injectable_context(load_storage: LoadStorage) -> None assert load_storage.new_packages.get_load_package_state("copy").get("_state_version") == 1 # check that second injection is the same as first - second_injected_instance = load_package_state() + second_injected_instance = load_package() assert second_injected_instance == injected_state # check scoped destination states @@ -133,7 +133,7 @@ def test_loadpackage_state_injectable_context(load_storage: LoadStorage) -> None ) == {"new_key": "new_value"} # this also shows up on the previously injected state - assert injected_state["destination_state"]["new_key"] == "new_value" + assert injected_state["state"]["destination_state"]["new_key"] == "new_value" # clear destination state clear_destination_state() From bc44618189f6b606dc1f7613a93d04df7026d653 Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 5 Mar 2024 17:28:09 +0100 Subject: [PATCH 30/35] add support for secrets and config in sink --- dlt/destinations/decorators.py | 2 +- dlt/destinations/impl/sink/configuration.py | 5 --- dlt/destinations/impl/sink/sink.py | 19 +++++++-- tests/load/sink/test_sink.py | 44 ++++++++++++++++++++- 4 files changed, 60 insertions(+), 10 deletions(-) diff --git a/dlt/destinations/decorators.py b/dlt/destinations/decorators.py index cbeeff8975..3280bf7378 100644 --- a/dlt/destinations/decorators.py +++ b/dlt/destinations/decorators.py @@ -20,7 +20,7 @@ def decorator(f: TSinkCallable) -> TDestinationReferenceArg: credentials=f, 
loader_file_format=loader_file_format, batch_size=batch_size, - name=name, + destination_name=name, naming_convention=naming_convention, ) diff --git a/dlt/destinations/impl/sink/configuration.py b/dlt/destinations/impl/sink/configuration.py index 8d9289ff8b..1b1ed8f893 100644 --- a/dlt/destinations/impl/sink/configuration.py +++ b/dlt/destinations/impl/sink/configuration.py @@ -18,8 +18,6 @@ @configspec class SinkClientCredentials(CredentialsConfiguration): callable: Optional[str] = None # noqa: A003 - # name provides namespace for callable state saving - name: Optional[str] = None def parse_native_representation(self, native_value: Any) -> None: # a callable was passed in @@ -54,9 +52,6 @@ def on_resolved(self) -> None: if not callable(self.resolved_callable): raise ConfigurationValueError("Resolved Sink destination callable is not a callable.") - if not self.name: - self.name = self.resolved_callable.__name__ - @configspec class SinkClientConfiguration(DestinationClientConfiguration): diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py index 816eece079..a13f55551b 100644 --- a/dlt/destinations/impl/sink/sink.py +++ b/dlt/destinations/impl/sink/sink.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from types import TracebackType from typing import ClassVar, Dict, Optional, Type, Iterable, Iterable +from dlt.common.configuration import with_config, known_sections from dlt.destinations.job_impl import EmptyLoadJob from dlt.common.typing import TDataItems @@ -32,12 +33,14 @@ def __init__( config: SinkClientConfiguration, schema: Schema, destination_state: Dict[str, int], + resolved_callable: TSinkCallable, ) -> None: super().__init__(FileStorage.get_file_name_from_file_path(file_path)) self._file_path = file_path self._config = config self._table = table self._schema = schema + self._resolved_callable = resolved_callable self._state: TLoadJobState = "running" self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}" @@ -68,7 +71,7 @@ def call_callable_with_items(self, items: TDataItems) -> None: if not items: return # call callable - self._config.credentials.resolved_callable(items, self._table) + self._resolved_callable(items, self._table) def state(self) -> TLoadJobState: return self._state @@ -125,6 +128,12 @@ def __init__(self, schema: Schema, config: SinkClientConfiguration) -> None: super().__init__(schema, config) self.config: SinkClientConfiguration = config + # inject config values + self.resolved_callable = with_config( + self.config.credentials.resolved_callable, + sections=(known_sections.DESTINATION, self.config.destination_name), + ) + def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: pass @@ -143,9 +152,13 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> # save our state in destination name scope load_state = destination_state() if file_path.endswith("parquet"): - return SinkParquetLoadJob(table, file_path, self.config, self.schema, load_state) + return SinkParquetLoadJob( + table, file_path, self.config, self.schema, load_state, self.resolved_callable + ) if file_path.endswith("jsonl"): - return SinkJsonlLoadJob(table, file_path, self.config, self.schema, load_state) + return SinkJsonlLoadJob( + table, file_path, self.config, self.schema, load_state, self.resolved_callable + ) return None def restore_file_load(self, file_path: str) -> LoadJob: diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py index f5cf318ee8..d5d6085d60 
100644 --- a/tests/load/sink/test_sink.py +++ b/tests/load/sink/test_sink.py @@ -3,6 +3,7 @@ import dlt import pytest import pytest +import os from copy import deepcopy from dlt.common.typing import TDataItems @@ -16,7 +17,6 @@ TABLE_ROW_ALL_DATA_TYPES, TABLE_UPDATE_COLUMNS_SCHEMA, assert_all_data_types_row, - delete_dataset, ) SUPPORTED_LOADER_FORMATS = ["parquet", "puae-jsonl"] @@ -345,3 +345,45 @@ def direct_sink(file_path, table): dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run( [resource1(), resource2()] ) + + +def test_config_spec() -> None: + @dlt.destination() + def my_sink(file_path, table, my_val=dlt.config.value): + assert my_val == "something" + + # if no value is present, it should raise + with pytest.raises(PipelineStepFailed) as exc: + dlt.pipeline("sink_test", destination=my_sink, full_refresh=True).run( + [1, 2, 3], table_name="items" + ) + + # right value will pass + os.environ["DESTINATION__MY_SINK__MY_VAL"] = "something" + dlt.pipeline("sink_test", destination=my_sink, full_refresh=True).run( + [1, 2, 3], table_name="items" + ) + + # wrong value will raise + os.environ["DESTINATION__MY_SINK__MY_VAL"] = "wrong" + with pytest.raises(PipelineStepFailed) as exc: + dlt.pipeline("sink_test", destination=my_sink, full_refresh=True).run( + [1, 2, 3], table_name="items" + ) + + # will respect given name + @dlt.destination(name="some_name") + def other_sink(file_path, table, my_val=dlt.config.value): + assert my_val == "something" + + # if no value is present, it should raise + with pytest.raises(PipelineStepFailed) as exc: + dlt.pipeline("sink_test", destination=other_sink, full_refresh=True).run( + [1, 2, 3], table_name="items" + ) + + # right value will pass + os.environ["DESTINATION__SOME_NAME__MY_VAL"] = "something" + dlt.pipeline("sink_test", destination=other_sink, full_refresh=True).run( + [1, 2, 3], table_name="items" + ) From db8d0ed15787958c9200c447e45d8c5349cf918d Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 5 Mar 2024 17:28:34 +0100 Subject: [PATCH 31/35] update sink docs --- .../docs/dlt-ecosystem/destinations/sink.md | 56 ++++++++++++++++--- 1 file changed, 47 insertions(+), 9 deletions(-) diff --git a/docs/website/docs/dlt-ecosystem/destinations/sink.md b/docs/website/docs/dlt-ecosystem/destinations/sink.md index ccd4a1a59f..577d79ee24 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/sink.md +++ b/docs/website/docs/dlt-ecosystem/destinations/sink.md @@ -4,7 +4,13 @@ description: Sink function `dlt` destination for reverse ETL keywords: [reverse etl, sink, function, decorator, destination] --- -# Sink function / Reverse ETL +# Destination decorator / Reverse ETL + +The dlt destination decorator allows you to receive all data passing through your pipeline in a simple function. This can be extremely useful for +reverse ETL, where you are pushing data back to an api. You can also use this for sending data to a queue or a simple database destination that is not +yet supported by dlt; be aware that you will have to handle your own migrations manually in this case. It will also allow you to simply get a path +to the files of your normalized data, so if you need direct access to parquet or jsonl files to copy them somewhere or push them to a database, +you can do this here too. ## Install dlt for Sink / reverse ETL ** To install the DLT without additional dependencies ** @@ -27,16 +33,22 @@ The above command generates several files and directories, including `.dlt/secre ### 2.
Set up a destination function for your pipeline The sink destination differs from other destinations in that you do not need to provide connection credentials, but rather you provide a function which gets called for all items loaded during a pipeline run or load operation. For the chess example, you can add the following lines at the top of the file. -With the @dlt.destination decorator you can convert any function that takes two arguments into a dlt destination. +With the `@dlt.destination` decorator you can convert any function that takes two arguments into a dlt destination. + +A very simple dlt pipeline that pushes a list of items into a sink function might look like this: ```python from dlt.common.typing import TDataItems from dlt.common.schema import TTableSchema @dlt.destination(batch_size=10) -def sink(items: TDataItems, table: TTableSchema) -> None: +def my_sink(items: TDataItems, table: TTableSchema) -> None: print(table["name"]) print(items) + +pipe = dlt.pipeline("sink_pipeline", destination=my_sink) +pipe.run([1, 2, 3], table_name="items") + ``` To enable this destination decorator in your chess example, replace the line `destination='sink'` with `destination=sink` (without the quotes) to directly reference the sink from your pipeline constructor. Now you can run your pipeline and see t 2. There are a few other ways for declaring sink functions for your pipeline described below. ::: -## Sink decorator function and signature +## Destination decorator function and signature -The full signature of the sink decorator and a function is +The full signature of the destination decorator plus its function is the following: ```python -@dlt.destination(batch_size=10, loader_file_format="jsonl", name="my_sink") +@dlt.destination(batch_size=10, loader_file_format="jsonl", name="my_sink", naming_convention="direct") def sink(items: TDataItems, table: TTableSchema) -> None: ... ``` #### Decorator -* The `batch_size` parameter on the sink decorator defines how many items per function call are batched together and sent as an array. -* The `loader_file_format` parameter on the sink decorator defines in which format files are stored in the load package before being sent to the sink function, +* The `batch_size` parameter on the destination decorator defines how many items per function call are batched together and sent as an array. If you set a batch size of `0`, +instead of passing in actual data items, you will receive one call per load job with the path of the file as the items argument. You can then open and process that file +in any way you like. +* The `loader_file_format` parameter on the destination decorator defines in which format files are stored in the load package before being sent to the sink function, this can be `jsonl` or `parquet`. -* The `name` parameter on the sink decorator defines the name of the destination that get's created by the sink decorator. +* The `name` parameter on the destination decorator defines the name of the destination that gets created by the destination decorator. +* The `naming_convention` parameter on the destination decorator controls +how table and column names are normalized. The default is `direct`, which keeps all names the same. #### Sink function * The `items` parameter on the sink function contains the items being sent into the sink function. * The `table` parameter contains the schema table the current call belongs to including all table hints and columns.
For example the table name can be accessed with table["name"]. Keep in mind that dlt also creates special tables prefixed with `_dlt` which you may want to ignore when processing data. +* You can also add config values and secrets to the function arguments; see below! + + +## Adding config variables and secrets +The destination decorator supports config values and secrets. If you, for example, plan to connect to a service that requires an api secret or a login, you can do the following: + +```python +@dlt.destination(batch_size=10, loader_file_format="jsonl", name="my_sink") +def my_sink(items: TDataItems, table: TTableSchema, api_key: str = dlt.secrets.value) -> None: + ... +``` + ## Sink destination state The sink destination keeps a local record of how many DataItems were processed, so if you, for example, use the sink destination to push DataItems to a remote api, and this @@ -73,6 +101,16 @@ api becomes unavailable during the load resulting in a failed dlt pipeline run, where it left off. For this reason it makes sense to choose a batch size that you can process in one transaction (say one api request or one database transaction) so that if this request or transaction fails repeatedly you can repeat it at the next run without pushing duplicate data to your remote location. + + +And add the api key to your toml file: + +```toml +[destination.my_sink] +api_key="some secrets" +``` + + ## Concurrency Calls to the sink function by default will be executed on multiple threads, so you need to make sure you are not using any non-thread-safe nonlocal or global variables from outside your sink function. If, for whichever reason, you need to have all calls be executed from the same thread, you can set the `workers` config variable of the load step to 1. For performance From d8719c1f229ca482bd0d16f94bff4b19e7ccdfc8 Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 5 Mar 2024 18:01:22 +0100 Subject: [PATCH 32/35] revert encodings branch --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5da0b17400..0eb33d885d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,7 +134,7 @@ types-simplejson = ">=3.17.0" types-requests = ">=2.25.6" types-python-dateutil = ">=2.8.15" flake8-tidy-imports = ">=4.8.0" -flake8-encodings = { git = "git@github.com:dlt-hub/flake8-encodings.git", branch = "disable_jedi_support" } +flake8-encodings = "^0.5.0" flake8-builtins = "^1.5.3" boto3-stubs = "^1.28.28" types-tqdm = "^4.66.0.2" From d7eb19d90d921a9de252b828e63c3583c0c41c6b Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 5 Mar 2024 18:08:04 +0100 Subject: [PATCH 33/35] fix small linting problem --- tests/load/sink/test_sink.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py index d5d6085d60..2a602f1c2e 100644 --- a/tests/load/sink/test_sink.py +++ b/tests/load/sink/test_sink.py @@ -377,7 +377,7 @@ def other_sink(file_path, table, my_val=dlt.config.value): assert my_val == "something" # if no value is present, it should raise - with pytest.raises(PipelineStepFailed) as exc: + with pytest.raises(PipelineStepFailed): dlt.pipeline("sink_test", destination=other_sink, full_refresh=True).run( [1, 2, 3], table_name="items" ) From ef35502622e77c4d58003663389b20aaf8b1773b Mon Sep 17 00:00:00 2001 From: Dave Date: Wed, 6 Mar 2024 15:48:04 +0100 Subject: [PATCH 34/35] add support for config specs --- dlt/__init__.py | 2 +- dlt/common/configuration/container.py | 12 +-
dlt/common/configuration/inject.py | 5 +- dlt/common/reflection/spec.py | 7 +- dlt/common/storages/load_package.py | 1 + dlt/destinations/__init__.py | 4 +- dlt/destinations/decorators.py | 21 ++-- .../impl/{sink => destination}/__init__.py | 0 .../impl/destination/configuration.py | 31 +++++ dlt/destinations/impl/destination/factory.py | 111 ++++++++++++++++++ .../impl/{sink => destination}/sink.py | 33 +++--- dlt/destinations/impl/sink/configuration.py | 71 ----------- dlt/destinations/impl/sink/factory.py | 50 -------- tests/load/sink/test_sink.py | 41 +++---- tests/utils.py | 4 +- 15 files changed, 210 insertions(+), 183 deletions(-) rename dlt/destinations/impl/{sink => destination}/__init__.py (100%) create mode 100644 dlt/destinations/impl/destination/configuration.py create mode 100644 dlt/destinations/impl/destination/factory.py rename dlt/destinations/impl/{sink => destination}/sink.py (87%) delete mode 100644 dlt/destinations/impl/sink/configuration.py delete mode 100644 dlt/destinations/impl/sink/factory.py diff --git a/dlt/__init__.py b/dlt/__init__.py index c40416ba73..eee105e47e 100644 --- a/dlt/__init__.py +++ b/dlt/__init__.py @@ -29,7 +29,7 @@ from dlt import sources from dlt.extract.decorators import source, resource, transformer, defer -from dlt.destinations.decorators import sink as destination +from dlt.destinations.decorators import destination from dlt.pipeline import ( pipeline as _pipeline, diff --git a/dlt/common/configuration/container.py b/dlt/common/configuration/container.py index ad20765489..95a42e0087 100644 --- a/dlt/common/configuration/container.py +++ b/dlt/common/configuration/container.py @@ -87,10 +87,14 @@ def _thread_context( context = self.main_context else: # thread pool names used in dlt contain originating thread id. use this id over pool id - if m := re.match(r"dlt-pool-(\d+)-", threading.currentThread().getName()): - thread_id = int(m.group(1)) - else: - thread_id = threading.get_ident() + # print(threading.currentThread().getName()) + # if m := re.match(r"dlt-pool-(\d+)", threading.currentThread().getName()): + # thread_id = int(m.group(1)) + # print("MATCH") + # else: + + thread_id = threading.get_ident() + # return main context for main thread if thread_id == Container._MAIN_THREAD_ID: return self.main_context diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py index a22f299ae8..9d862f8795 100644 --- a/dlt/common/configuration/inject.py +++ b/dlt/common/configuration/inject.py @@ -58,6 +58,7 @@ def with_config( include_defaults: bool = True, accept_partial: bool = False, initial_config: Optional[BaseConfiguration] = None, + base: Type[BaseConfiguration] = BaseConfiguration, ) -> Callable[[TFun], TFun]: """Injects values into decorated function arguments following the specification in `spec` or by deriving one from function's signature. 
@@ -75,6 +76,7 @@ def with_config( Returns: Callable[[TFun], TFun]: A decorated function """ + section_f: Callable[[StrAny], str] = None # section may be a function from function arguments to section if callable(sections): @@ -88,9 +90,8 @@ def decorator(f: TFun) -> TFun: ) spec_arg: Parameter = None pipeline_name_arg: Parameter = None - if spec is None: - SPEC = spec_from_signature(f, sig, include_defaults) + SPEC = spec_from_signature(f, sig, include_defaults, base=base) else: SPEC = spec diff --git a/dlt/common/reflection/spec.py b/dlt/common/reflection/spec.py index 0a486088c8..ffc12e908c 100644 --- a/dlt/common/reflection/spec.py +++ b/dlt/common/reflection/spec.py @@ -26,7 +26,10 @@ def _first_up(s: str) -> str: def spec_from_signature( - f: AnyFun, sig: Signature, include_defaults: bool = True + f: AnyFun, + sig: Signature, + include_defaults: bool = True, + base: Type[BaseConfiguration] = BaseConfiguration, ) -> Type[BaseConfiguration]: name = _get_spec_name_from_f(f) module = inspect.getmodule(f) @@ -109,7 +112,7 @@ def dlt_config_literal_to_type(arg_name: str) -> AnyType: # set annotations so they are present in __dict__ fields["__annotations__"] = annotations # synthesize type - T: Type[BaseConfiguration] = type(name, (BaseConfiguration,), fields) + T: Type[BaseConfiguration] = type(name, (base,), fields) SPEC = configspec()(T) # add to the module setattr(module, spec_id, SPEC) diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index f946f33113..c59366d715 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -626,6 +626,7 @@ class LoadPackageStateInjectableContext(ContainerInjectableContext): storage: PackageStorage load_id: str can_create_default: ClassVar[bool] = False + global_affinity: ClassVar[bool] = True def commit(self) -> None: with self.state_save_lock: diff --git a/dlt/destinations/__init__.py b/dlt/destinations/__init__.py index 4502b362d0..4a10deffc0 100644 --- a/dlt/destinations/__init__.py +++ b/dlt/destinations/__init__.py @@ -10,7 +10,7 @@ from dlt.destinations.impl.qdrant.factory import qdrant from dlt.destinations.impl.motherduck.factory import motherduck from dlt.destinations.impl.weaviate.factory import weaviate -from dlt.destinations.impl.sink.factory import sink +from dlt.destinations.impl.destination.factory import destination from dlt.destinations.impl.synapse.factory import synapse from dlt.destinations.impl.databricks.factory import databricks @@ -30,5 +30,5 @@ "weaviate", "synapse", "databricks", - "sink", + "destination", ] diff --git a/dlt/destinations/decorators.py b/dlt/destinations/decorators.py index 3280bf7378..e6df02d077 100644 --- a/dlt/destinations/decorators.py +++ b/dlt/destinations/decorators.py @@ -1,23 +1,22 @@ -from typing import Any, Callable -from dlt.destinations.impl.sink.factory import sink as _sink -from dlt.destinations.impl.sink.configuration import SinkClientConfiguration, TSinkCallable +from typing import Any, Type +from dlt.destinations.impl.destination.factory import destination as _destination +from dlt.destinations.impl.destination.configuration import SinkClientConfiguration, TSinkCallable from dlt.common.destination import TDestinationReferenceArg from dlt.common.destination import TLoaderFileFormat -from dlt.common.utils import get_callable_name -def sink( +def destination( loader_file_format: TLoaderFileFormat = None, batch_size: int = 10, name: str = None, naming_convention: str = "direct", + spec: Type[SinkClientConfiguration] = 
SinkClientConfiguration, ) -> Any: - def decorator(f: TSinkCallable) -> TDestinationReferenceArg: - nonlocal name - if name is None: - name = get_callable_name(f) - return _sink( - credentials=f, + def decorator(destination_callable: TSinkCallable) -> TDestinationReferenceArg: + # return destination instance + return _destination( + spec=spec, + destination_callable=destination_callable, loader_file_format=loader_file_format, batch_size=batch_size, destination_name=name, diff --git a/dlt/destinations/impl/sink/__init__.py b/dlt/destinations/impl/destination/__init__.py similarity index 100% rename from dlt/destinations/impl/sink/__init__.py rename to dlt/destinations/impl/destination/__init__.py diff --git a/dlt/destinations/impl/destination/configuration.py b/dlt/destinations/impl/destination/configuration.py new file mode 100644 index 0000000000..f3875b9cf5 --- /dev/null +++ b/dlt/destinations/impl/destination/configuration.py @@ -0,0 +1,31 @@ +from typing import TYPE_CHECKING, Optional, Final, Callable, Union, Any + +from dlt.common.configuration import configspec +from dlt.common.destination import TLoaderFileFormat +from dlt.common.destination.reference import ( + DestinationClientConfiguration, + CredentialsConfiguration, +) +from dlt.common.typing import TDataItems +from dlt.common.schema import TTableSchema + + +TSinkCallable = Callable[[Union[TDataItems, str], TTableSchema], None] + + +@configspec +class SinkClientConfiguration(DestinationClientConfiguration): + destination_type: Final[str] = "sink" # type: ignore + destination_callable: Optional[str] = None # noqa: A003 + loader_file_format: TLoaderFileFormat = "puae-jsonl" + batch_size: int = 10 + + if TYPE_CHECKING: + + def __init__( + self, + *, + loader_file_format: TLoaderFileFormat = "puae-jsonl", + batch_size: int = 10, + destination_callable: Union[TSinkCallable, str] = None, + ) -> None: ... 
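To illustrate the `spec` argument the decorator gains above, here is a minimal sketch of a user-defined configuration extending `SinkClientConfiguration` (this is not part of the patch; the class name, the extra fields, and the exact way they combine with arguments declared in the sink signature are assumptions based on the factory code that follows):

```python
import dlt
from dlt.common.configuration import configspec
from dlt.common.typing import TDataItems
from dlt.common.schema import TTableSchema
from dlt.destinations.impl.destination.configuration import SinkClientConfiguration


@configspec
class MySinkConfiguration(SinkClientConfiguration):
    # hypothetical extra fields, expected to resolve from the
    # `destination.<name>` config section (toml files or env vars)
    api_url: str = None
    api_key: str = None


@dlt.destination(batch_size=10, spec=MySinkConfiguration)
def my_sink(
    items: TDataItems,
    table: TTableSchema,
    api_url: str = dlt.config.value,
    api_key: str = dlt.secrets.value,
) -> None:
    # push the batch for this table to the (hypothetical) API
    print(f"sending {len(items)} rows of {table['name']} to {api_url}")
```

Per the factory below, the spec synthesized from the sink signature subclasses the given base, so the resolved values would be injected into the extra keyword arguments when the sink is called.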
diff --git a/dlt/destinations/impl/destination/factory.py b/dlt/destinations/impl/destination/factory.py new file mode 100644 index 0000000000..b72c21c370 --- /dev/null +++ b/dlt/destinations/impl/destination/factory.py @@ -0,0 +1,111 @@ +import typing as t +import inspect +from importlib import import_module + +from types import ModuleType +from dlt.common.typing import AnyFun + +from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.configuration import known_sections, with_config, get_fun_spec +from dlt.common.configuration.exceptions import ConfigurationValueError + +from dlt.destinations.impl.destination.configuration import ( + SinkClientConfiguration, + TSinkCallable, +) +from dlt.destinations.impl.destination import capabilities +from dlt.common.data_writers import TLoaderFileFormat +from dlt.common.utils import get_callable_name + +if t.TYPE_CHECKING: + from dlt.destinations.impl.destination.sink import SinkClient + + +class DestinationInfo(t.NamedTuple): + """Runtime information on a discovered destination""" + + SPEC: t.Type[SinkClientConfiguration] + f: AnyFun + module: ModuleType + + +_DESTINATIONS: t.Dict[str, DestinationInfo] = {} +"""A registry of all the decorated destinations""" + + +class destination(Destination[SinkClientConfiguration, "SinkClient"]): + def capabilities(self) -> DestinationCapabilitiesContext: + return capabilities( + self.config_params.get("loader_file_format", "puae-jsonl"), + self.config_params.get("naming_convention", "direct"), + ) + + @property + def spec(self) -> t.Type[SinkClientConfiguration]: + """A spec of destination configuration resolved from the sink function signature""" + return self._spec + + @property + def client_class(self) -> t.Type["SinkClient"]: + from dlt.destinations.impl.destination.sink import SinkClient + + return SinkClient + + def __init__( + self, + destination_callable: t.Union[TSinkCallable, str] = None, # noqa: A003 + destination_name: t.Optional[str] = None, + environment: t.Optional[str] = None, + loader_file_format: TLoaderFileFormat = None, + batch_size: int = 10, + naming_convention: str = "direct", + spec: t.Type[SinkClientConfiguration] = SinkClientConfiguration, + **kwargs: t.Any, + ) -> None: + # resolve callable + if callable(destination_callable): + pass + elif destination_callable: + try: + module_path, attr_name = destination_callable.rsplit(".", 1) + dest_module = import_module(module_path) + except ModuleNotFoundError as e: + raise ConfigurationValueError( + f"Could not find callable module at {module_path}" + ) from e + try: + destination_callable = getattr(dest_module, attr_name) + except AttributeError as e: + raise ConfigurationValueError( + f"Could not find callable function at {destination_callable}" + ) from e + + if not callable(destination_callable): + raise ConfigurationValueError("Resolved Sink destination callable is not a callable.") + + # resolve destination name + if destination_name is None: + destination_name = get_callable_name(destination_callable) + func_module = inspect.getmodule(destination_callable) + + # build destination spec + destination_sections = (known_sections.DESTINATION, destination_name) + conf_callable = with_config( + destination_callable, sections=destination_sections, include_defaults=True, base=spec + ) + + # save destination in registry + resolved_spec = get_fun_spec(conf_callable) + _DESTINATIONS[callable.__qualname__] = DestinationInfo(resolved_spec, callable, func_module) + + # remember spec + self._spec = resolved_spec 
or spec + super().__init__( + destination_name=destination_name, + environment=environment, + loader_file_format=loader_file_format, + batch_size=batch_size, + naming_convention=naming_convention, + destination_callable=conf_callable, + **kwargs, + ) diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/destination/sink.py similarity index 87% rename from dlt/destinations/impl/sink/sink.py rename to dlt/destinations/impl/destination/sink.py index a13f55551b..630ce19794 100644 --- a/dlt/destinations/impl/sink/sink.py +++ b/dlt/destinations/impl/destination/sink.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from types import TracebackType -from typing import ClassVar, Dict, Optional, Type, Iterable, Iterable -from dlt.common.configuration import with_config, known_sections +from typing import ClassVar, Dict, Optional, Type, Iterable, Iterable, NamedTuple, Dict +import threading from dlt.destinations.job_impl import EmptyLoadJob from dlt.common.typing import TDataItems @@ -21,11 +21,11 @@ JobClientBase, ) -from dlt.destinations.impl.sink import capabilities -from dlt.destinations.impl.sink.configuration import SinkClientConfiguration, TSinkCallable +from dlt.destinations.impl.destination import capabilities +from dlt.destinations.impl.destination.configuration import SinkClientConfiguration, TSinkCallable -class SinkLoadJob(LoadJob, ABC): +class DestinationLoadJob(LoadJob, ABC): def __init__( self, table: TTableSchema, @@ -33,14 +33,14 @@ def __init__( config: SinkClientConfiguration, schema: Schema, destination_state: Dict[str, int], - resolved_callable: TSinkCallable, + destination_callable: TSinkCallable, ) -> None: super().__init__(FileStorage.get_file_name_from_file_path(file_path)) self._file_path = file_path self._config = config self._table = table self._schema = schema - self._resolved_callable = resolved_callable + self._callable = destination_callable self._state: TLoadJobState = "running" self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}" @@ -71,7 +71,9 @@ def call_callable_with_items(self, items: TDataItems) -> None: if not items: return # call callable - self._resolved_callable(items, self._table) + print("Start callable " + threading.currentThread().getName()) + self._callable(items, self._table) + print("end callable ") def state(self) -> TLoadJobState: return self._state @@ -80,7 +82,7 @@ def exception(self) -> str: raise NotImplementedError() -class SinkParquetLoadJob(SinkLoadJob): +class SinkParquetLoadJob(DestinationLoadJob): def run(self, start_index: int) -> Iterable[TDataItems]: # stream items from dlt.common.libs.pyarrow import pyarrow @@ -99,7 +101,7 @@ def run(self, start_index: int) -> Iterable[TDataItems]: yield record_batch -class SinkJsonlLoadJob(SinkLoadJob): +class SinkJsonlLoadJob(DestinationLoadJob): def run(self, start_index: int) -> Iterable[TDataItems]: current_batch: TDataItems = [] @@ -127,12 +129,7 @@ class SinkClient(JobClientBase): def __init__(self, schema: Schema, config: SinkClientConfiguration) -> None: super().__init__(schema, config) self.config: SinkClientConfiguration = config - - # inject config values - self.resolved_callable = with_config( - self.config.credentials.resolved_callable, - sections=(known_sections.DESTINATION, self.config.destination_name), - ) + self.destination_callable = self.config.destination_callable def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: pass @@ -153,11 +150,11 @@ def start_file_load(self, table: TTableSchema, 
file_path: str, load_id: str) -> load_state = destination_state() if file_path.endswith("parquet"): return SinkParquetLoadJob( - table, file_path, self.config, self.schema, load_state, self.resolved_callable + table, file_path, self.config, self.schema, load_state, self.destination_callable ) if file_path.endswith("jsonl"): return SinkJsonlLoadJob( - table, file_path, self.config, self.schema, load_state, self.resolved_callable + table, file_path, self.config, self.schema, load_state, self.destination_callable ) return None diff --git a/dlt/destinations/impl/sink/configuration.py b/dlt/destinations/impl/sink/configuration.py deleted file mode 100644 index 1b1ed8f893..0000000000 --- a/dlt/destinations/impl/sink/configuration.py +++ /dev/null @@ -1,71 +0,0 @@ -from typing import TYPE_CHECKING, Optional, Final, Callable, Union, Any -from importlib import import_module - -from dlt.common.configuration import configspec -from dlt.common.destination import TLoaderFileFormat -from dlt.common.destination.reference import ( - DestinationClientConfiguration, - CredentialsConfiguration, -) -from dlt.common.typing import TDataItems -from dlt.common.schema import TTableSchema -from dlt.common.configuration.exceptions import ConfigurationValueError - - -TSinkCallable = Callable[[Union[TDataItems, str], TTableSchema], None] - - -@configspec -class SinkClientCredentials(CredentialsConfiguration): - callable: Optional[str] = None # noqa: A003 - - def parse_native_representation(self, native_value: Any) -> None: - # a callable was passed in - if callable(native_value): - self.resolved_callable: TSinkCallable = native_value - # a path to a callable was passed in - if isinstance(native_value, str): - self.callable = native_value - - def to_native_representation(self) -> Any: - return self.resolved_callable - - def on_resolved(self) -> None: - if self.callable: - try: - module_path, attr_name = self.callable.rsplit(".", 1) - dest_module = import_module(module_path) - except ModuleNotFoundError as e: - raise ConfigurationValueError( - f"Could not find callable module at {module_path}" - ) from e - try: - self.resolved_callable = getattr(dest_module, attr_name) - except AttributeError as e: - raise ConfigurationValueError( - f"Could not find callable function at {self.callable}" - ) from e - - if not hasattr(self, "resolved_callable"): - raise ConfigurationValueError("Please specify callable for sink destination.") - - if not callable(self.resolved_callable): - raise ConfigurationValueError("Resolved Sink destination callable is not a callable.") - - -@configspec -class SinkClientConfiguration(DestinationClientConfiguration): - destination_type: Final[str] = "sink" # type: ignore - credentials: SinkClientCredentials = None - loader_file_format: TLoaderFileFormat = "puae-jsonl" - batch_size: int = 10 - - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Union[SinkClientCredentials, TSinkCallable, str] = None, - loader_file_format: TLoaderFileFormat = "puae-jsonl", - batch_size: int = 10, - ) -> None: ... 
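The factory above accepts the destination callable either as a Python callable or as a dotted string path that it resolves with `importlib`. A standalone sketch of that lookup pattern (the helper name is hypothetical):

```python
from importlib import import_module
from typing import Any, Callable


def resolve_dotted_callable(path: str) -> Callable[..., Any]:
    # split "package.module.attribute" into module path and attribute name,
    # import the module, then fetch and validate the attribute
    module_path, attr_name = path.rsplit(".", 1)
    module = import_module(module_path)
    func = getattr(module, attr_name)
    if not callable(func):
        raise TypeError(f"{path} does not resolve to a callable")
    return func


# e.g. resolve_dotted_callable("json.dumps") returns the json.dumps function
```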
diff --git a/dlt/destinations/impl/sink/factory.py b/dlt/destinations/impl/sink/factory.py deleted file mode 100644 index 6b2e98271e..0000000000 --- a/dlt/destinations/impl/sink/factory.py +++ /dev/null @@ -1,50 +0,0 @@ -import typing as t - -from dlt.common.destination import Destination, DestinationCapabilitiesContext - -from dlt.destinations.impl.sink.configuration import ( - SinkClientConfiguration, - SinkClientCredentials, - TSinkCallable, -) -from dlt.destinations.impl.sink import capabilities -from dlt.common.data_writers import TLoaderFileFormat - -if t.TYPE_CHECKING: - from dlt.destinations.impl.sink.sink import SinkClient - - -class sink(Destination[SinkClientConfiguration, "SinkClient"]): - spec = SinkClientConfiguration - - def capabilities(self) -> DestinationCapabilitiesContext: - return capabilities( - self.config_params.get("loader_file_format", "puae-jsonl"), - self.config_params.get("naming_convention", "direct"), - ) - - @property - def client_class(self) -> t.Type["SinkClient"]: - from dlt.destinations.impl.sink.sink import SinkClient - - return SinkClient - - def __init__( - self, - credentials: t.Union[SinkClientCredentials, TSinkCallable] = None, - destination_name: t.Optional[str] = None, - environment: t.Optional[str] = None, - loader_file_format: TLoaderFileFormat = None, - batch_size: int = 10, - naming_convention: str = "direct", - **kwargs: t.Any, - ) -> None: - super().__init__( - credentials=credentials, - destination_name=destination_name, - environment=environment, - loader_file_format=loader_file_format, - batch_size=batch_size, - naming_convention=naming_convention, - **kwargs, - ) diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py index 2a602f1c2e..de8e648c6d 100644 --- a/tests/load/sink/test_sink.py +++ b/tests/load/sink/test_sink.py @@ -10,8 +10,10 @@ from dlt.common.schema import TTableSchema from dlt.common.data_writers.writers import TLoaderFileFormat from dlt.common.destination.reference import Destination -from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.pipeline.exceptions import PipelineStepFailed +from dlt.common.utils import uniq_id +from dlt.common.exceptions import InvalidDestinationReference +from dlt.common.configuration.exceptions import ConfigFieldMissingException from tests.load.utils import ( TABLE_ROW_ALL_DATA_TYPES, @@ -140,19 +142,11 @@ def local_sink_func(items: TDataItems, table: TTableSchema) -> None: p.run([1, 2, 3], table_name="items") assert len(calls) == 1 - # test passing via credentials - calls = [] - p = dlt.pipeline( - "sink_test", destination="sink", credentials=local_sink_func, full_refresh=True - ) - p.run([1, 2, 3], table_name="items") - assert len(calls) == 1 - # test passing via from_reference calls = [] p = dlt.pipeline( "sink_test", - destination=Destination.from_reference("sink", credentials=local_sink_func), # type: ignore + destination=Destination.from_reference("destination", destination_callable=local_sink_func), # type: ignore full_refresh=True, ) p.run([1, 2, 3], table_name="items") @@ -163,23 +157,30 @@ def local_sink_func(items: TDataItems, table: TTableSchema) -> None: global_calls = [] p = dlt.pipeline( "sink_test", - destination="sink", - credentials="tests.load.sink.test_sink.global_sink_func", + destination=Destination.from_reference("destination", destination_callable="tests.load.sink.test_sink.global_sink_func"), # type: ignore full_refresh=True, ) p.run([1, 2, 3], table_name="items") assert len(global_calls) == 1 # pass None credentials 
reference - p = dlt.pipeline("sink_test", destination="sink", credentials=None, full_refresh=True) - with pytest.raises(ConfigurationValueError): + with pytest.raises(InvalidDestinationReference): + p = dlt.pipeline( + "sink_test", + destination=Destination.from_reference("destination", destination_callable=None), + full_refresh=True, + ) p.run([1, 2, 3], table_name="items") # pass invalid credentials module - p = dlt.pipeline( - "sink_test", destination="sink", credentials="does.not.exist.callable", full_refresh=True - ) - with pytest.raises(ConfigurationValueError): + with pytest.raises(InvalidDestinationReference): + p = dlt.pipeline( + "sink_test", + destination=Destination.from_reference( + "destination", destination_callable="does.not.exist" + ), + full_refresh=True, + ) p.run([1, 2, 3], table_name="items") @@ -353,7 +354,7 @@ def my_sink(file_path, table, my_val=dlt.config.value): assert my_val == "something" # if no value is present, it should raise - with pytest.raises(PipelineStepFailed) as exc: + with pytest.raises(ConfigFieldMissingException) as exc: dlt.pipeline("sink_test", destination=my_sink, full_refresh=True).run( [1, 2, 3], table_name="items" ) @@ -377,7 +378,7 @@ def other_sink(file_path, table, my_val=dlt.config.value): assert my_val == "something" # if no value is present, it should raise - with pytest.raises(PipelineStepFailed): + with pytest.raises(ConfigFieldMissingException): dlt.pipeline("sink_test", destination=other_sink, full_refresh=True).run( [1, 2, 3], table_name="items" ) diff --git a/tests/utils.py b/tests/utils.py index c86ae92a2b..c40e5fe56a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,11 +45,11 @@ "motherduck", "mssql", "qdrant", - "sink", + "destination", "synapse", "databricks", } -NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck", "qdrant", "sink"} +NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck", "qdrant", "destination"} SQL_DESTINATIONS = IMPLEMENTED_DESTINATIONS - NON_SQL_DESTINATIONS # exclude destination configs (for now used for athena and athena iceberg separation) From 2db3430a858dea9bd054152d42947f501ff89528 Mon Sep 17 00:00:00 2001 From: Dave Date: Thu, 7 Mar 2024 11:42:43 +0100 Subject: [PATCH 35/35] add possibility to create a resolved partial --- dlt/common/configuration/container.py | 11 +- dlt/common/configuration/inject.py | 120 ++++++++++++------- dlt/destinations/impl/destination/factory.py | 5 +- dlt/destinations/impl/destination/sink.py | 6 +- 4 files changed, 91 insertions(+), 51 deletions(-) diff --git a/dlt/common/configuration/container.py b/dlt/common/configuration/container.py index 95a42e0087..b5c6e52894 100644 --- a/dlt/common/configuration/container.py +++ b/dlt/common/configuration/container.py @@ -87,13 +87,10 @@ def _thread_context( context = self.main_context else: # thread pool names used in dlt contain originating thread id. 
use this id over pool id - # print(threading.currentThread().getName()) - # if m := re.match(r"dlt-pool-(\d+)", threading.currentThread().getName()): - # thread_id = int(m.group(1)) - # print("MATCH") - # else: - - thread_id = threading.get_ident() + if m := re.match(r"dlt-pool-(\d+)-", threading.currentThread().getName()): + thread_id = int(m.group(1)) + else: + thread_id = threading.get_ident() # return main context for main thread if thread_id == Container._MAIN_THREAD_ID: diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py index 9d862f8795..c38ccb0c23 100644 --- a/dlt/common/configuration/inject.py +++ b/dlt/common/configuration/inject.py @@ -1,4 +1,6 @@ import inspect +import threading + from functools import wraps from typing import Callable, Dict, Type, Any, Optional, Tuple, TypeVar, overload from inspect import Signature, Parameter @@ -72,7 +74,7 @@ def with_config( prefer_existing_sections: (bool, optional): When joining existing section context, the existing context will be preferred to the one in `sections`. Default: False auto_pipeline_section (bool, optional): If True, a top level pipeline section will be added if `pipeline_name` argument is present . Defaults to False. include_defaults (bool, optional): If True then arguments with default values will be included in synthesized spec. If False only the required arguments marked with `dlt.secrets.value` and `dlt.config.value` are included - + base (Type[BaseConfiguration], optional): A base class for synthesized spec. Defaults to BaseConfiguration. Returns: Callable[[TFun], TFun]: A decorated function """ @@ -110,51 +112,53 @@ def decorator(f: TFun) -> TFun: pipeline_name_arg = p pipeline_name_arg_default = None if p.default == Parameter.empty else p.default - @wraps(f) - def _wrap(*args: Any, **kwargs: Any) -> Any: + def resolve_config(bound_args: inspect.BoundArguments) -> BaseConfiguration: + """Resolve arguments using the provided spec""" # bind parameters to signature - bound_args = sig.bind(*args, **kwargs) # for calls containing resolved spec in the kwargs, we do not need to resolve again config: BaseConfiguration = None - if _LAST_DLT_CONFIG in kwargs: - config = last_config(**kwargs) + + # if section derivation function was provided then call it + if section_f: + curr_sections: Tuple[str, ...] = (section_f(bound_args.arguments),) + # sections may be a string + elif isinstance(sections, str): + curr_sections = (sections,) else: - # if section derivation function was provided then call it - if section_f: - curr_sections: Tuple[str, ...] 
= (section_f(bound_args.arguments),) - # sections may be a string - elif isinstance(sections, str): - curr_sections = (sections,) - else: - curr_sections = sections - - # if one of arguments is spec the use it as initial value - if initial_config: - config = initial_config - elif spec_arg: - config = bound_args.arguments.get(spec_arg.name, None) - # resolve SPEC, also provide section_context with pipeline_name - if pipeline_name_arg: - curr_pipeline_name = bound_args.arguments.get( - pipeline_name_arg.name, pipeline_name_arg_default - ) - else: - curr_pipeline_name = None - section_context = ConfigSectionContext( - pipeline_name=curr_pipeline_name, - sections=curr_sections, - merge_style=sections_merge_style, + curr_sections = sections + + # if one of arguments is spec the use it as initial value + if initial_config: + config = initial_config + elif spec_arg: + config = bound_args.arguments.get(spec_arg.name, None) + # resolve SPEC, also provide section_context with pipeline_name + if pipeline_name_arg: + curr_pipeline_name = bound_args.arguments.get( + pipeline_name_arg.name, pipeline_name_arg_default ) - # this may be called from many threads so section_context is thread affine - with inject_section(section_context): - # print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}") - config = resolve_configuration( - config or SPEC(), - explicit_value=bound_args.arguments, - accept_partial=accept_partial, - ) - resolved_params = dict(config) + else: + curr_pipeline_name = None + section_context = ConfigSectionContext( + pipeline_name=curr_pipeline_name, + sections=curr_sections, + merge_style=sections_merge_style, + ) + + # this may be called from many threads so section_context is thread affine + with inject_section(section_context): + # print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}") + return resolve_configuration( + config or SPEC(), + explicit_value=bound_args.arguments, + accept_partial=accept_partial, + ) + + def update_bound_args( + bound_args: inspect.BoundArguments, config: BaseConfiguration, *args, **kwargs + ) -> None: # overwrite or add resolved params + resolved_params = dict(config) for p in sig.parameters.values(): if p.name in resolved_params: bound_args.arguments[p.name] = resolved_params.pop(p.name) @@ -168,12 +172,48 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: bound_args.arguments[kwargs_arg.name].update(resolved_params) bound_args.arguments[kwargs_arg.name][_LAST_DLT_CONFIG] = config bound_args.arguments[kwargs_arg.name][_ORIGINAL_ARGS] = (args, kwargs) + + def create_resolved_partial() -> Any: + # creates a pre-resolved partial of the decorated function + empty_bound_args = sig.bind_partial() + config = resolve_config(empty_bound_args) + + # TODO: do some checks, for example fail if there is a spec arg + + def creator(*args: Any, **kwargs: Any) -> Any: + nonlocal config + + # we can still overwrite the config + if _LAST_DLT_CONFIG in kwargs: + config = last_config(**kwargs) + + # call the function with the pre-resolved config + bound_args = sig.bind(*args, **kwargs) + update_bound_args(bound_args, config, args, kwargs) + return f(*bound_args.args, **bound_args.kwargs) + + return creator + + @wraps(f) + def _wrap(*args: Any, **kwargs: Any) -> Any: + # Resolve config + config: BaseConfiguration = None + bound_args = sig.bind(*args, **kwargs) + if _LAST_DLT_CONFIG in kwargs: + config = last_config(**kwargs) + else: + config = resolve_config(bound_args) + # call the function with resolved config + 
update_bound_args(bound_args, config, args, kwargs) return f(*bound_args.args, **bound_args.kwargs) # register the spec for a wrapped function _FUNC_SPECS[id(_wrap)] = SPEC + # add a method to create a pre-resolved partial + _wrap.create_resolved_partial = create_resolved_partial + return _wrap # type: ignore # See if we're being called as @with_config or @with_config(). diff --git a/dlt/destinations/impl/destination/factory.py b/dlt/destinations/impl/destination/factory.py index b72c21c370..78de651404 100644 --- a/dlt/destinations/impl/destination/factory.py +++ b/dlt/destinations/impl/destination/factory.py @@ -91,7 +91,10 @@ def __init__( # build destination spec destination_sections = (known_sections.DESTINATION, destination_name) conf_callable = with_config( - destination_callable, sections=destination_sections, include_defaults=True, base=spec + destination_callable, + sections=destination_sections, + include_defaults=True, + base=spec, ) # save destination in registry diff --git a/dlt/destinations/impl/destination/sink.py b/dlt/destinations/impl/destination/sink.py index 630ce19794..6aeb4dd65d 100644 --- a/dlt/destinations/impl/destination/sink.py +++ b/dlt/destinations/impl/destination/sink.py @@ -71,9 +71,7 @@ def call_callable_with_items(self, items: TDataItems) -> None: if not items: return # call callable - print("Start callable " + threading.currentThread().getName()) self._callable(items, self._table) - print("end callable ") def state(self) -> TLoadJobState: return self._state @@ -129,7 +127,9 @@ class SinkClient(JobClientBase): def __init__(self, schema: Schema, config: SinkClientConfiguration) -> None: super().__init__(schema, config) self.config: SinkClientConfiguration = config - self.destination_callable = self.config.destination_callable + + # we create pre_resolved callable here + self.destination_callable = self.config.destination_callable.create_resolved_partial() def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: pass
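The `create_resolved_partial` helper added in this commit resolves configuration once and hands the load jobs a callable that reuses those values on every invocation. A simplified, standalone sketch of that idea (not dlt's actual implementation; the names and the plain-dict "config" are assumptions):

```python
from typing import Any, Callable, Dict


def make_resolved_partial(fn: Callable[..., Any], resolved_config: Dict[str, Any]) -> Callable[..., Any]:
    # configuration is "resolved" once up front; every later call, possibly
    # from worker threads, reuses the same values without re-resolving
    def partial(*args: Any, **kwargs: Any) -> Any:
        merged = {**resolved_config, **kwargs}  # explicit kwargs win over config values
        return fn(*args, **merged)

    return partial


def push_batch(items: Any, table_name: str, api_key: str = None) -> None:
    # hypothetical sink callable that receives an injected secret
    print(f"pushing {len(items)} items to {table_name} (api_key set: {api_key is not None})")


sender = make_resolved_partial(push_batch, {"api_key": "value-from-secrets"})
sender([1, 2, 3], "items")
```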