diff --git a/dlt/__init__.py b/dlt/__init__.py index e2a6b1a3a7..eee105e47e 100644 --- a/dlt/__init__.py +++ b/dlt/__init__.py @@ -29,6 +29,8 @@ from dlt import sources from dlt.extract.decorators import source, resource, transformer, defer +from dlt.destinations.decorators import destination + from dlt.pipeline import ( pipeline as _pipeline, run, @@ -62,6 +64,7 @@ "resource", "transformer", "defer", + "destination", "pipeline", "run", "attach", diff --git a/dlt/common/configuration/container.py b/dlt/common/configuration/container.py index ad20765489..b5c6e52894 100644 --- a/dlt/common/configuration/container.py +++ b/dlt/common/configuration/container.py @@ -91,6 +91,7 @@ def _thread_context( thread_id = int(m.group(1)) else: thread_id = threading.get_ident() + # return main context for main thread if thread_id == Container._MAIN_THREAD_ID: return self.main_context diff --git a/dlt/common/configuration/inject.py b/dlt/common/configuration/inject.py index a22f299ae8..c38ccb0c23 100644 --- a/dlt/common/configuration/inject.py +++ b/dlt/common/configuration/inject.py @@ -1,4 +1,6 @@ import inspect +import threading + from functools import wraps from typing import Callable, Dict, Type, Any, Optional, Tuple, TypeVar, overload from inspect import Signature, Parameter @@ -58,6 +60,7 @@ def with_config( include_defaults: bool = True, accept_partial: bool = False, initial_config: Optional[BaseConfiguration] = None, + base: Type[BaseConfiguration] = BaseConfiguration, ) -> Callable[[TFun], TFun]: """Injects values into decorated function arguments following the specification in `spec` or by deriving one from function's signature. @@ -71,10 +74,11 @@ def with_config( prefer_existing_sections: (bool, optional): When joining existing section context, the existing context will be preferred to the one in `sections`. Default: False auto_pipeline_section (bool, optional): If True, a top level pipeline section will be added if `pipeline_name` argument is present . Defaults to False. include_defaults (bool, optional): If True then arguments with default values will be included in synthesized spec. If False only the required arguments marked with `dlt.secrets.value` and `dlt.config.value` are included - + base (Type[BaseConfiguration], optional): A base class for synthesized spec. Defaults to BaseConfiguration. Returns: Callable[[TFun], TFun]: A decorated function """ + section_f: Callable[[StrAny], str] = None # section may be a function from function arguments to section if callable(sections): @@ -88,9 +92,8 @@ def decorator(f: TFun) -> TFun: ) spec_arg: Parameter = None pipeline_name_arg: Parameter = None - if spec is None: - SPEC = spec_from_signature(f, sig, include_defaults) + SPEC = spec_from_signature(f, sig, include_defaults, base=base) else: SPEC = spec @@ -109,51 +112,53 @@ def decorator(f: TFun) -> TFun: pipeline_name_arg = p pipeline_name_arg_default = None if p.default == Parameter.empty else p.default - @wraps(f) - def _wrap(*args: Any, **kwargs: Any) -> Any: + def resolve_config(bound_args: inspect.BoundArguments) -> BaseConfiguration: + """Resolve arguments using the provided spec""" # bind parameters to signature - bound_args = sig.bind(*args, **kwargs) # for calls containing resolved spec in the kwargs, we do not need to resolve again config: BaseConfiguration = None - if _LAST_DLT_CONFIG in kwargs: - config = last_config(**kwargs) + + # if section derivation function was provided then call it + if section_f: + curr_sections: Tuple[str, ...] 
= (section_f(bound_args.arguments),) + # sections may be a string + elif isinstance(sections, str): + curr_sections = (sections,) else: - # if section derivation function was provided then call it - if section_f: - curr_sections: Tuple[str, ...] = (section_f(bound_args.arguments),) - # sections may be a string - elif isinstance(sections, str): - curr_sections = (sections,) - else: - curr_sections = sections - - # if one of arguments is spec the use it as initial value - if initial_config: - config = initial_config - elif spec_arg: - config = bound_args.arguments.get(spec_arg.name, None) - # resolve SPEC, also provide section_context with pipeline_name - if pipeline_name_arg: - curr_pipeline_name = bound_args.arguments.get( - pipeline_name_arg.name, pipeline_name_arg_default - ) - else: - curr_pipeline_name = None - section_context = ConfigSectionContext( - pipeline_name=curr_pipeline_name, - sections=curr_sections, - merge_style=sections_merge_style, + curr_sections = sections + + # if one of arguments is spec the use it as initial value + if initial_config: + config = initial_config + elif spec_arg: + config = bound_args.arguments.get(spec_arg.name, None) + # resolve SPEC, also provide section_context with pipeline_name + if pipeline_name_arg: + curr_pipeline_name = bound_args.arguments.get( + pipeline_name_arg.name, pipeline_name_arg_default ) - # this may be called from many threads so section_context is thread affine - with inject_section(section_context): - # print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}") - config = resolve_configuration( - config or SPEC(), - explicit_value=bound_args.arguments, - accept_partial=accept_partial, - ) - resolved_params = dict(config) + else: + curr_pipeline_name = None + section_context = ConfigSectionContext( + pipeline_name=curr_pipeline_name, + sections=curr_sections, + merge_style=sections_merge_style, + ) + + # this may be called from many threads so section_context is thread affine + with inject_section(section_context): + # print(f"RESOLVE CONF in inject: {f.__name__}: {section_context.sections} vs {sections}") + return resolve_configuration( + config or SPEC(), + explicit_value=bound_args.arguments, + accept_partial=accept_partial, + ) + + def update_bound_args( + bound_args: inspect.BoundArguments, config: BaseConfiguration, *args, **kwargs + ) -> None: # overwrite or add resolved params + resolved_params = dict(config) for p in sig.parameters.values(): if p.name in resolved_params: bound_args.arguments[p.name] = resolved_params.pop(p.name) @@ -167,12 +172,48 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: bound_args.arguments[kwargs_arg.name].update(resolved_params) bound_args.arguments[kwargs_arg.name][_LAST_DLT_CONFIG] = config bound_args.arguments[kwargs_arg.name][_ORIGINAL_ARGS] = (args, kwargs) + + def create_resolved_partial() -> Any: + # creates a pre-resolved partial of the decorated function + empty_bound_args = sig.bind_partial() + config = resolve_config(empty_bound_args) + + # TODO: do some checks, for example fail if there is a spec arg + + def creator(*args: Any, **kwargs: Any) -> Any: + nonlocal config + + # we can still overwrite the config + if _LAST_DLT_CONFIG in kwargs: + config = last_config(**kwargs) + + # call the function with the pre-resolved config + bound_args = sig.bind(*args, **kwargs) + update_bound_args(bound_args, config, args, kwargs) + return f(*bound_args.args, **bound_args.kwargs) + + return creator + + @wraps(f) + def _wrap(*args: Any, **kwargs: Any) -> Any: + # 
Resolve config + config: BaseConfiguration = None + bound_args = sig.bind(*args, **kwargs) + if _LAST_DLT_CONFIG in kwargs: + config = last_config(**kwargs) + else: + config = resolve_config(bound_args) + # call the function with resolved config + update_bound_args(bound_args, config, args, kwargs) return f(*bound_args.args, **bound_args.kwargs) # register the spec for a wrapped function _FUNC_SPECS[id(_wrap)] = SPEC + # add a method to create a pre-resolved partial + _wrap.create_resolved_partial = create_resolved_partial + return _wrap # type: ignore # See if we're being called as @with_config or @with_config(). diff --git a/dlt/common/data_types/type_helpers.py b/dlt/common/data_types/type_helpers.py index 659b4951df..61a0aa1dbf 100644 --- a/dlt/common/data_types/type_helpers.py +++ b/dlt/common/data_types/type_helpers.py @@ -7,7 +7,7 @@ from enum import Enum from dlt.common import pendulum, json, Decimal, Wei -from dlt.common.json import custom_pua_remove +from dlt.common.json import custom_pua_remove, json from dlt.common.json._simplejson import custom_encode as json_custom_encode from dlt.common.arithmetics import InvalidOperation from dlt.common.data_types.typing import TDataType @@ -105,6 +105,14 @@ def coerce_value(to_type: TDataType, from_type: TDataType, value: Any) -> Any: return int(value.value) return value + if to_type == "complex": + # try to coerce from text + if from_type == "text": + try: + return json.loads(value) + except Exception: + raise ValueError(value) + if to_type == "text": if from_type == "complex": return complex_to_str(value) diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index a78a31fdf3..0f2500c2cd 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -19,7 +19,7 @@ ] ALL_SUPPORTED_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) # file formats used internally by dlt -INTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = {"puae-jsonl", "sql", "reference", "arrow"} +INTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = {"sql", "reference", "arrow"} # file formats that may be chosen by the user EXTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = ( set(get_args(TLoaderFileFormat)) - INTERNAL_LOADER_FILE_FORMATS diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index df221ec703..3cbaafefbe 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -3,6 +3,7 @@ import datetime # noqa: 251 import humanize import contextlib + from typing import ( Any, Callable, @@ -40,11 +41,15 @@ from dlt.common.schema.typing import TColumnNames, TColumnSchema, TWriteDisposition, TSchemaContract from dlt.common.source import get_current_pipe_name from dlt.common.storages.load_storage import LoadPackageInfo +from dlt.common.storages.load_package import PackageStorage + from dlt.common.time import ensure_pendulum_datetime, precise_time from dlt.common.typing import DictStrAny, REPattern, StrAny, SupportsHumanize from dlt.common.jsonpath import delete_matches, TAnyJsonPath from dlt.common.data_writers.writers import DataWriterMetrics, TLoaderFileFormat from dlt.common.utils import RowCounts, merge_row_counts +from dlt.common.versioned_state import TVersionedState +from dlt.common.storages.load_package import TLoadPackageState class _StepInfo(NamedTuple): @@ -454,7 +459,7 @@ class TPipelineLocalState(TypedDict, total=False): """Hash of state that was recently synced with destination""" -class TPipelineState(TypedDict, total=False): +class 
TPipelineState(TVersionedState, total=False): """Schema for a pipeline state that is stored within the pipeline working directory""" pipeline_name: str @@ -469,9 +474,6 @@ class TPipelineState(TypedDict, total=False): staging_type: Optional[str] # properties starting with _ are not automatically applied to pipeline object when state is restored - _state_version: int - _version_hash: str - _state_engine_version: int _local: TPipelineLocalState """A section of state that is not synchronized with the destination and does not participate in change merging and version control""" diff --git a/dlt/common/reflection/spec.py b/dlt/common/reflection/spec.py index 0a486088c8..ffc12e908c 100644 --- a/dlt/common/reflection/spec.py +++ b/dlt/common/reflection/spec.py @@ -26,7 +26,10 @@ def _first_up(s: str) -> str: def spec_from_signature( - f: AnyFun, sig: Signature, include_defaults: bool = True + f: AnyFun, + sig: Signature, + include_defaults: bool = True, + base: Type[BaseConfiguration] = BaseConfiguration, ) -> Type[BaseConfiguration]: name = _get_spec_name_from_f(f) module = inspect.getmodule(f) @@ -109,7 +112,7 @@ def dlt_config_literal_to_type(arg_name: str) -> AnyType: # set annotations so they are present in __dict__ fields["__annotations__"] = annotations # synthesize type - T: Type[BaseConfiguration] = type(name, (BaseConfiguration,), fields) + T: Type[BaseConfiguration] = type(name, (base,), fields) SPEC = configspec()(T) # add to the module setattr(module, spec_id, SPEC) diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index 63409aa878..452a7a6443 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -1,6 +1,8 @@ import contextlib import os from copy import deepcopy +import threading + import datetime # noqa: 251 import humanize from pathlib import Path @@ -17,9 +19,19 @@ Set, get_args, cast, + Any, + Tuple, + TYPE_CHECKING, + TypedDict, ) from dlt.common import pendulum, json + +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import ContainerInjectableContext +from dlt.common.configuration.exceptions import ContextDefaultCannotBeCreated +from dlt.common.configuration.container import Container + from dlt.common.data_writers import DataWriter, new_file_id from dlt.common.destination import TLoaderFileFormat from dlt.common.exceptions import TerminalValueError @@ -27,13 +39,72 @@ from dlt.common.schema.typing import TStoredSchema, TTableSchemaColumns from dlt.common.storages import FileStorage from dlt.common.storages.exceptions import LoadPackageNotFound -from dlt.common.typing import DictStrAny, StrAny, SupportsHumanize +from dlt.common.typing import DictStrAny, SupportsHumanize from dlt.common.utils import flatten_list_or_items +from dlt.common.versioned_state import ( + generate_state_version_hash, + bump_state_version_if_modified, + TVersionedState, + default_versioned_state, +) +from typing_extensions import NotRequired + + +class TLoadPackageState(TVersionedState, total=False): + created_at: str + """Timestamp when the loadpackage was created""" + + """A section of state that does not participate in change merging and version control""" + destination_state: NotRequired[Dict[str, Any]] + """private space for destinations to store state relevant only to the load package""" + + +class TLoadPackage(TypedDict, total=False): + load_id: str + """Load id""" + state: TLoadPackageState + """State of the load package""" + + +# allows to upgrade state when restored with a new 
version of state logic/schema +LOADPACKAGE_STATE_ENGINE_VERSION = 1 + + +def generate_loadpackage_state_version_hash(state: TLoadPackageState) -> str: + return generate_state_version_hash(state) + + +def bump_loadpackage_state_version_if_modified(state: TLoadPackageState) -> Tuple[int, str, str]: + return bump_state_version_if_modified(state) + + +def migrate_loadpackage_state( + state: DictStrAny, from_engine: int, to_engine: int +) -> TLoadPackageState: + # TODO: if you start adding new versions, we need proper tests for these migrations! + # NOTE: do not touch destinations state, it is not versioned + if from_engine == to_engine: + return cast(TLoadPackageState, state) + + # check state engine + if from_engine != to_engine: + raise Exception("No upgrade path for loadpackage state") + + state["_state_engine_version"] = from_engine + return cast(TLoadPackageState, state) + + +def default_loadpackage_state() -> TLoadPackageState: + return { + **default_versioned_state(), + "_state_engine_version": LOADPACKAGE_STATE_ENGINE_VERSION, + } + # folders to manage load jobs in a single load package TJobState = Literal["new_jobs", "failed_jobs", "started_jobs", "completed_jobs"] WORKING_FOLDERS: Set[TJobState] = set(get_args(TJobState)) -TLoadPackageState = Literal["new", "extracted", "normalized", "loaded", "aborted"] +TLoadPackageStatus = Literal["new", "extracted", "normalized", "loaded", "aborted"] class ParsedLoadJobFileName(NamedTuple): @@ -125,7 +196,7 @@ def __str__(self) -> str: class _LoadPackageInfo(NamedTuple): load_id: str package_path: str - state: TLoadPackageState + state: TLoadPackageStatus schema: Schema schema_update: TSchemaTables completed_at: datetime.datetime @@ -201,8 +272,11 @@ class PackageStorage: PACKAGE_COMPLETED_FILE_NAME = ( # completed package marker file, currently only to store data with os.stat "package_completed.json" ) + LOAD_PACKAGE_STATE_FILE_NAME = ( # internal state of the load package, will not be synced to the destination + "load_package_state.json" + ) - def __init__(self, storage: FileStorage, initial_state: TLoadPackageState) -> None: + def __init__(self, storage: FileStorage, initial_state: TLoadPackageStatus) -> None: """Creates storage that manages load packages with root at `storage` and initial package state `initial_state`""" self.storage = storage self.initial_state = initial_state @@ -334,8 +408,13 @@ def create_package(self, load_id: str) -> None: self.storage.create_folder(os.path.join(load_id, PackageStorage.COMPLETED_JOBS_FOLDER)) self.storage.create_folder(os.path.join(load_id, PackageStorage.FAILED_JOBS_FOLDER)) self.storage.create_folder(os.path.join(load_id, PackageStorage.STARTED_JOBS_FOLDER)) + # ensure created timestamp is set in state when load package is created + state = self.get_load_package_state(load_id) + if not state.get("created_at"): + state["created_at"] = pendulum.now().to_iso8601_string() + self.save_load_package_state(load_id, state) - def complete_loading_package(self, load_id: str, load_state: TLoadPackageState) -> str: + def complete_loading_package(self, load_id: str, load_state: TLoadPackageStatus) -> str: """Completes loading the package by writing marker file with`package_state. 
Returns path to the completed package""" load_path = self.get_package_path(load_id) # save marker file @@ -381,6 +460,34 @@ def save_schema_updates(self, load_id: str, schema_update: TSchemaTables) -> Non ) as f: json.dump(schema_update, f) + # + # Loadpackage state + # + def get_load_package_state(self, load_id: str) -> TLoadPackageState: + package_path = self.get_package_path(load_id) + if not self.storage.has_folder(package_path): + raise LoadPackageNotFound(load_id) + try: + state_dump = self.storage.load( + os.path.join(package_path, PackageStorage.LOAD_PACKAGE_STATE_FILE_NAME) + ) + state = json.loads(state_dump) + return migrate_loadpackage_state( + state, state["_state_engine_version"], LOADPACKAGE_STATE_ENGINE_VERSION + ) + except FileNotFoundError: + return default_loadpackage_state() + + def save_load_package_state(self, load_id: str, state: TLoadPackageState) -> None: + package_path = self.get_package_path(load_id) + if not self.storage.has_folder(package_path): + raise LoadPackageNotFound(load_id) + bump_loadpackage_state_version_if_modified(state) + self.storage.save( + os.path.join(package_path, PackageStorage.LOAD_PACKAGE_STATE_FILE_NAME), + json.dumps(state), + ) + # # Get package info # @@ -514,3 +621,59 @@ def filter_jobs_for_table( all_jobs: Iterable[LoadJobInfo], table_name: str ) -> Sequence[LoadJobInfo]: return [job for job in all_jobs if job.job_file_info.table_name == table_name] + + +@configspec +class LoadPackageStateInjectableContext(ContainerInjectableContext): + storage: PackageStorage + load_id: str + can_create_default: ClassVar[bool] = False + global_affinity: ClassVar[bool] = True + + def commit(self) -> None: + with self.state_save_lock: + self.storage.save_load_package_state(self.load_id, self.state) + + def on_resolved(self) -> None: + self.state_save_lock = threading.Lock() + self.state = self.storage.get_load_package_state(self.load_id) + + if TYPE_CHECKING: + + def __init__(self, load_id: str, storage: PackageStorage) -> None: ... + + +def load_package() -> TLoadPackage: + """Get full load package state present in current context. Across all threads this will be the same in memory dict.""" + container = Container() + # get injected state if present. injected load package state is typically "managed" so changes will be persisted + # if you need to save the load package state during a load, you need to call commit_load_package_state + try: + state_ctx = container[LoadPackageStateInjectableContext] + except ContextDefaultCannotBeCreated: + raise Exception("Load package state not available") + return TLoadPackage(state=state_ctx.state, load_id=state_ctx.load_id) + + +def commit_load_package_state() -> None: + """Commit load package state present in current context. This is thread safe.""" + container = Container() + try: + state_ctx = container[LoadPackageStateInjectableContext] + except ContextDefaultCannotBeCreated: + raise Exception("Load package state not available") + state_ctx.commit() + + +def destination_state() -> DictStrAny: + """Get segment of load package state that is specific to the current destination.""" + lp = load_package() + return lp["state"].setdefault("destination_state", {}) + + +def clear_destination_state(commit: bool = True) -> None: + """Clear segment of load package state that is specific to the current destination. 
Optionally commit to load package.""" + lp = load_package() + lp["state"].pop("destination_state", None) + if commit: + commit_load_package_state() diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index a83502cb9b..ffd55e7f29 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -1,6 +1,7 @@ from os.path import join from typing import Iterable, Optional, Sequence +from dlt.common.typing import DictStrAny from dlt.common import json from dlt.common.configuration import known_sections from dlt.common.configuration.inject import with_config @@ -18,6 +19,7 @@ PackageStorage, ParsedLoadJobFileName, TJobState, + TLoadPackageState, ) from dlt.common.storages.exceptions import JobWithUnsupportedWriterException, LoadPackageNotFound @@ -38,6 +40,11 @@ def __init__( supported_file_formats: Iterable[TLoaderFileFormat], config: LoadStorageConfiguration = config.value, ) -> None: + # puae-jsonl jobs have the extension .jsonl, so cater for this here + if supported_file_formats and "puae-jsonl" in supported_file_formats: + supported_file_formats = list(supported_file_formats) + supported_file_formats.append("jsonl") + if not LoadStorage.ALL_SUPPORTED_FILE_FORMATS.issuperset(supported_file_formats): raise TerminalValueError(supported_file_formats) if preferred_file_format and preferred_file_format not in supported_file_formats: @@ -79,7 +86,7 @@ def _get_data_item_path_template(self, load_id: str, _: str, table_name: str) -> def list_new_jobs(self, load_id: str) -> Sequence[str]: """Lists all jobs in new jobs folder of normalized package storage and checks if file formats are supported""" new_jobs = self.normalized_packages.list_new_jobs(load_id) - # # make sure all jobs have supported writers + # make sure all jobs have supported writers wrong_job = next( ( j @@ -184,3 +191,10 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: return self.loaded_packages.get_load_package_info(load_id) except LoadPackageNotFound: return self.normalized_packages.get_load_package_info(load_id) + + def get_load_package_state(self, load_id: str) -> TLoadPackageState: + """Gets state of normlized or loaded package with given load_id, all jobs and their statuses.""" + try: + return self.loaded_packages.get_load_package_state(load_id) + except LoadPackageNotFound: + return self.normalized_packages.get_load_package_state(load_id) diff --git a/dlt/common/storages/normalize_storage.py b/dlt/common/storages/normalize_storage.py index 8a247c2021..2b90b7c088 100644 --- a/dlt/common/storages/normalize_storage.py +++ b/dlt/common/storages/normalize_storage.py @@ -51,7 +51,9 @@ def list_files_to_normalize_sorted(self) -> Sequence[str]: [ file for file in files - if not file.endswith(PackageStorage.SCHEMA_FILE_NAME) and os.path.isfile(file) + if not file.endswith(PackageStorage.SCHEMA_FILE_NAME) + and os.path.isfile(file) + and not file.endswith(PackageStorage.LOAD_PACKAGE_STATE_FILE_NAME) ] ) diff --git a/dlt/common/versioned_state.py b/dlt/common/versioned_state.py new file mode 100644 index 0000000000..a051a6660c --- /dev/null +++ b/dlt/common/versioned_state.py @@ -0,0 +1,45 @@ +import base64 +import hashlib +from copy import copy + +import datetime # noqa: 251 +from dlt.common import json +from typing import TypedDict, Dict, Any, List, Tuple, cast + + +class TVersionedState(TypedDict, total=False): + _state_version: int + _version_hash: str + _state_engine_version: int + + +def generate_state_version_hash(state: TVersionedState, 
exclude_attrs: List[str] = None) -> str: + # generates hash out of stored schema content, excluding hash itself, version and local state + state_copy = copy(state) + exclude_attrs = exclude_attrs or [] + exclude_attrs.extend(["_state_version", "_state_engine_version", "_version_hash"]) + for attr in exclude_attrs: + state_copy.pop(attr, None) # type: ignore + content = json.typed_dumpb(state_copy, sort_keys=True) # type: ignore + h = hashlib.sha3_256(content) + return base64.b64encode(h.digest()).decode("ascii") + + +def bump_state_version_if_modified( + state: TVersionedState, exclude_attrs: List[str] = None +) -> Tuple[int, str, str]: + """Bumps the `state` version and version hash if content modified, returns (new version, new hash, old hash) tuple""" + hash_ = generate_state_version_hash(state, exclude_attrs) + previous_hash = state.get("_version_hash") + if not previous_hash: + # if hash was not set, set it without bumping the version, that's the initial state + pass + elif hash_ != previous_hash: + state["_state_version"] += 1 + + state["_version_hash"] = hash_ + return state["_state_version"], hash_, previous_hash + + +def default_versioned_state() -> TVersionedState: + return {"_state_version": 0, "_state_engine_version": 1} diff --git a/dlt/destinations/__init__.py b/dlt/destinations/__init__.py index c0a0b419c1..4a10deffc0 100644 --- a/dlt/destinations/__init__.py +++ b/dlt/destinations/__init__.py @@ -10,6 +10,7 @@ from dlt.destinations.impl.qdrant.factory import qdrant from dlt.destinations.impl.motherduck.factory import motherduck from dlt.destinations.impl.weaviate.factory import weaviate +from dlt.destinations.impl.destination.factory import destination from dlt.destinations.impl.synapse.factory import synapse from dlt.destinations.impl.databricks.factory import databricks @@ -29,4 +30,5 @@ "weaviate", "synapse", "databricks", + "destination", ] diff --git a/dlt/destinations/decorators.py b/dlt/destinations/decorators.py new file mode 100644 index 0000000000..e6df02d077 --- /dev/null +++ b/dlt/destinations/decorators.py @@ -0,0 +1,26 @@ +from typing import Any, Type +from dlt.destinations.impl.destination.factory import destination as _destination +from dlt.destinations.impl.destination.configuration import SinkClientConfiguration, TSinkCallable +from dlt.common.destination import TDestinationReferenceArg +from dlt.common.destination import TLoaderFileFormat + + +def destination( + loader_file_format: TLoaderFileFormat = None, + batch_size: int = 10, + name: str = None, + naming_convention: str = "direct", + spec: Type[SinkClientConfiguration] = SinkClientConfiguration, +) -> Any: + def decorator(destination_callable: TSinkCallable) -> TDestinationReferenceArg: + # return destination instance + return _destination( + spec=spec, + destination_callable=destination_callable, + loader_file_format=loader_file_format, + batch_size=batch_size, + destination_name=name, + naming_convention=naming_convention, + ) + + return decorator diff --git a/dlt/destinations/impl/destination/__init__.py b/dlt/destinations/impl/destination/__init__.py new file mode 100644 index 0000000000..fbad2d570f --- /dev/null +++ b/dlt/destinations/impl/destination/__init__.py @@ -0,0 +1,14 @@ +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.data_writers import TLoaderFileFormat + + +def capabilities( + preferred_loader_file_format: TLoaderFileFormat = "puae-jsonl", + naming_convention: str = "direct", +) -> DestinationCapabilitiesContext: + caps = 
DestinationCapabilitiesContext.generic_capabilities(preferred_loader_file_format) + caps.supported_loader_file_formats = ["puae-jsonl", "parquet"] + caps.supports_ddl_transactions = False + caps.supports_transactions = False + caps.naming_convention = naming_convention + return caps diff --git a/dlt/destinations/impl/destination/configuration.py b/dlt/destinations/impl/destination/configuration.py new file mode 100644 index 0000000000..f3875b9cf5 --- /dev/null +++ b/dlt/destinations/impl/destination/configuration.py @@ -0,0 +1,31 @@ +from typing import TYPE_CHECKING, Optional, Final, Callable, Union, Any + +from dlt.common.configuration import configspec +from dlt.common.destination import TLoaderFileFormat +from dlt.common.destination.reference import ( + DestinationClientConfiguration, + CredentialsConfiguration, +) +from dlt.common.typing import TDataItems +from dlt.common.schema import TTableSchema + + +TSinkCallable = Callable[[Union[TDataItems, str], TTableSchema], None] + + +@configspec +class SinkClientConfiguration(DestinationClientConfiguration): + destination_type: Final[str] = "sink" # type: ignore + destination_callable: Optional[str] = None # noqa: A003 + loader_file_format: TLoaderFileFormat = "puae-jsonl" + batch_size: int = 10 + + if TYPE_CHECKING: + + def __init__( + self, + *, + loader_file_format: TLoaderFileFormat = "puae-jsonl", + batch_size: int = 10, + destination_callable: Union[TSinkCallable, str] = None, + ) -> None: ... diff --git a/dlt/destinations/impl/destination/factory.py b/dlt/destinations/impl/destination/factory.py new file mode 100644 index 0000000000..78de651404 --- /dev/null +++ b/dlt/destinations/impl/destination/factory.py @@ -0,0 +1,114 @@ +import typing as t +import inspect +from importlib import import_module + +from types import ModuleType +from dlt.common.typing import AnyFun + +from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.configuration import known_sections, with_config, get_fun_spec +from dlt.common.configuration.exceptions import ConfigurationValueError + +from dlt.destinations.impl.destination.configuration import ( + SinkClientConfiguration, + TSinkCallable, +) +from dlt.destinations.impl.destination import capabilities +from dlt.common.data_writers import TLoaderFileFormat +from dlt.common.utils import get_callable_name + +if t.TYPE_CHECKING: + from dlt.destinations.impl.destination.sink import SinkClient + + +class DestinationInfo(t.NamedTuple): + """Runtime information on a discovered destination""" + + SPEC: t.Type[SinkClientConfiguration] + f: AnyFun + module: ModuleType + + +_DESTINATIONS: t.Dict[str, DestinationInfo] = {} +"""A registry of all the decorated destinations""" + + +class destination(Destination[SinkClientConfiguration, "SinkClient"]): + def capabilities(self) -> DestinationCapabilitiesContext: + return capabilities( + self.config_params.get("loader_file_format", "puae-jsonl"), + self.config_params.get("naming_convention", "direct"), + ) + + @property + def spec(self) -> t.Type[SinkClientConfiguration]: + """A spec of destination configuration resolved from the sink function signature""" + return self._spec + + @property + def client_class(self) -> t.Type["SinkClient"]: + from dlt.destinations.impl.destination.sink import SinkClient + + return SinkClient + + def __init__( + self, + destination_callable: t.Union[TSinkCallable, str] = None, # noqa: A003 + destination_name: t.Optional[str] = None, + environment: t.Optional[str] = None, + loader_file_format: 
TLoaderFileFormat = None, + batch_size: int = 10, + naming_convention: str = "direct", + spec: t.Type[SinkClientConfiguration] = SinkClientConfiguration, + **kwargs: t.Any, + ) -> None: + # resolve callable + if callable(destination_callable): + pass + elif destination_callable: + try: + module_path, attr_name = destination_callable.rsplit(".", 1) + dest_module = import_module(module_path) + except ModuleNotFoundError as e: + raise ConfigurationValueError( + f"Could not find callable module at {module_path}" + ) from e + try: + destination_callable = getattr(dest_module, attr_name) + except AttributeError as e: + raise ConfigurationValueError( + f"Could not find callable function at {destination_callable}" + ) from e + + if not callable(destination_callable): + raise ConfigurationValueError("Resolved Sink destination callable is not a callable.") + + # resolve destination name + if destination_name is None: + destination_name = get_callable_name(destination_callable) + func_module = inspect.getmodule(destination_callable) + + # build destination spec + destination_sections = (known_sections.DESTINATION, destination_name) + conf_callable = with_config( + destination_callable, + sections=destination_sections, + include_defaults=True, + base=spec, + ) + + # save destination in registry + resolved_spec = get_fun_spec(conf_callable) + _DESTINATIONS[callable.__qualname__] = DestinationInfo(resolved_spec, callable, func_module) + + # remember spec + self._spec = resolved_spec or spec + super().__init__( + destination_name=destination_name, + environment=environment, + loader_file_format=loader_file_format, + batch_size=batch_size, + naming_convention=naming_convention, + destination_callable=conf_callable, + **kwargs, + ) diff --git a/dlt/destinations/impl/destination/sink.py b/dlt/destinations/impl/destination/sink.py new file mode 100644 index 0000000000..6aeb4dd65d --- /dev/null +++ b/dlt/destinations/impl/destination/sink.py @@ -0,0 +1,172 @@ +from abc import ABC, abstractmethod +from types import TracebackType +from typing import ClassVar, Dict, Optional, Type, Iterable, Iterable, NamedTuple, Dict +import threading + +from dlt.destinations.job_impl import EmptyLoadJob +from dlt.common.typing import TDataItems +from dlt.common import json +from dlt.pipeline.current import ( + destination_state, + commit_load_package_state, +) + +from dlt.common.schema import Schema, TTableSchema, TSchemaTables +from dlt.common.schema.typing import TTableSchema +from dlt.common.storages import FileStorage +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.reference import ( + TLoadJobState, + LoadJob, + JobClientBase, +) + +from dlt.destinations.impl.destination import capabilities +from dlt.destinations.impl.destination.configuration import SinkClientConfiguration, TSinkCallable + + +class DestinationLoadJob(LoadJob, ABC): + def __init__( + self, + table: TTableSchema, + file_path: str, + config: SinkClientConfiguration, + schema: Schema, + destination_state: Dict[str, int], + destination_callable: TSinkCallable, + ) -> None: + super().__init__(FileStorage.get_file_name_from_file_path(file_path)) + self._file_path = file_path + self._config = config + self._table = table + self._schema = schema + self._callable = destination_callable + + self._state: TLoadJobState = "running" + self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}" + try: + if self._config.batch_size == 0: + # on batch size zero we only call the callable with the 
filename + self.call_callable_with_items(self._file_path) + else: + current_index = destination_state.get(self._storage_id, 0) + for batch in self.run(current_index): + self.call_callable_with_items(batch) + current_index += len(batch) + destination_state[self._storage_id] = current_index + + self._state = "completed" + except Exception as e: + self._state = "retry" + raise e + finally: + # save progress + commit_load_package_state() + + @abstractmethod + def run(self, start_index: int) -> Iterable[TDataItems]: + pass + + def call_callable_with_items(self, items: TDataItems) -> None: + if not items: + return + # call callable + self._callable(items, self._table) + + def state(self) -> TLoadJobState: + return self._state + + def exception(self) -> str: + raise NotImplementedError() + + +class SinkParquetLoadJob(DestinationLoadJob): + def run(self, start_index: int) -> Iterable[TDataItems]: + # stream items + from dlt.common.libs.pyarrow import pyarrow + + # guard against changed batch size after restart of loadjob + assert ( + start_index % self._config.batch_size + ) == 0, "Batch size was changed during processing of one load package" + + start_batch = start_index / self._config.batch_size + with pyarrow.parquet.ParquetFile(self._file_path) as reader: + for record_batch in reader.iter_batches(batch_size=self._config.batch_size): + if start_batch > 0: + start_batch -= 1 + continue + yield record_batch + + +class SinkJsonlLoadJob(DestinationLoadJob): + def run(self, start_index: int) -> Iterable[TDataItems]: + current_batch: TDataItems = [] + + # stream items + with FileStorage.open_zipsafe_ro(self._file_path) as f: + encoded_json = json.typed_loads(f.read()) + + for item in encoded_json: + # find correct start position + if start_index > 0: + start_index -= 1 + continue + current_batch.append(item) + if len(current_batch) == self._config.batch_size: + yield current_batch + current_batch = [] + yield current_batch + + +class SinkClient(JobClientBase): + """Sink Client""" + + capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() + + def __init__(self, schema: Schema, config: SinkClientConfiguration) -> None: + super().__init__(schema, config) + self.config: SinkClientConfiguration = config + + # we create pre_resolved callable here + self.destination_callable = self.config.destination_callable.create_resolved_partial() + + def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: + pass + + def is_storage_initialized(self) -> bool: + return True + + def drop_storage(self) -> None: + pass + + def update_stored_schema( + self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None + ) -> Optional[TSchemaTables]: + return super().update_stored_schema(only_tables, expected_update) + + def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + # save our state in destination name scope + load_state = destination_state() + if file_path.endswith("parquet"): + return SinkParquetLoadJob( + table, file_path, self.config, self.schema, load_state, self.destination_callable + ) + if file_path.endswith("jsonl"): + return SinkJsonlLoadJob( + table, file_path, self.config, self.schema, load_state, self.destination_callable + ) + return None + + def restore_file_load(self, file_path: str) -> LoadJob: + return EmptyLoadJob.from_file_path(file_path, "completed") + + def complete_load(self, load_id: str) -> None: ... 
+ + def __enter__(self) -> "SinkClient": + return self + + def __exit__( + self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType + ) -> None: + pass diff --git a/dlt/helpers/streamlit_helper.py b/dlt/helpers/streamlit_helper.py index f6b2f3a62f..f9f318323b 100644 --- a/dlt/helpers/streamlit_helper.py +++ b/dlt/helpers/streamlit_helper.py @@ -12,7 +12,7 @@ from dlt.common.libs.pandas import pandas from dlt.pipeline import Pipeline from dlt.pipeline.exceptions import CannotRestorePipelineException, SqlClientNotAvailable -from dlt.pipeline.state_sync import load_state_from_destination +from dlt.pipeline.state_sync import load_pipeline_state_from_destination try: import streamlit as st @@ -190,7 +190,7 @@ def _query_data_live(query: str, schema_name: str = None) -> pandas.DataFrame: st.header("Pipeline state info") with pipeline.destination_client() as client: if isinstance(client, WithStateSync): - remote_state = load_state_from_destination(pipeline.pipeline_name, client) + remote_state = load_pipeline_state_from_destination(pipeline.pipeline_name, client) local_state = pipeline.state col1, col2 = st.columns(2) diff --git a/dlt/load/load.py b/dlt/load/load.py index 050e7bce67..a19ed1736e 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -1,4 +1,4 @@ -import contextlib +import contextlib, threading from functools import reduce import datetime # noqa: 251 from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable @@ -7,10 +7,17 @@ from dlt.common import sleep, logger from dlt.common.configuration import with_config, known_sections +from dlt.common.configuration.resolve import inject_section from dlt.common.configuration.accessors import config -from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo -from dlt.common.schema.utils import get_top_level_table +from dlt.common.pipeline import ( + LoadInfo, + LoadMetrics, + SupportsPipeline, + WithStepInfo, +) +from dlt.common.schema.utils import get_child_tables, get_top_level_table from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState +from dlt.common.storages.load_package import LoadPackageStateInjectableContext from dlt.common.runners import TRunMetrics, Runnable, workermethod, NullExecutor from dlt.common.runtime.collector import Collector, NULL_COLLECTOR from dlt.common.runtime.logger import pretty_format_exception @@ -19,7 +26,10 @@ DestinationTerminalException, DestinationTransientException, ) +from dlt.common.configuration.container import Container + from dlt.common.schema import Schema, TSchemaTables + from dlt.common.storages import LoadStorage from dlt.common.destination.reference import ( DestinationClientDwhConfiguration, @@ -34,6 +44,7 @@ SupportsStagingDestination, TDestination, ) +from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.destinations.job_impl import EmptyLoadJob @@ -414,7 +425,7 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: failed_job.job_file_info.job_id(), failed_job.failed_message, ) - # possibly raise on too many retires + # possibly raise on too many retries if self.config.raise_on_max_retries: for new_job in package_info.jobs["new_jobs"]: r_c = new_job.job_file_info.retry_count @@ -452,12 +463,19 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: schema = self.load_storage.normalized_packages.load_schema(load_id) logger.info(f"Loaded schema name {schema.name} and version {schema.stored_version}") + container = Container() 
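+        # the container injects the load package state for this load_id below, so load jobs can read and update it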
# get top load id and mark as being processed with self.collector(f"Load {schema.name} in {load_id}"): - # the same load id may be processed across multiple runs - if not self.current_load_id: - self._step_info_start_load_id(load_id) - self.load_single_package(load_id, schema) + with container.injectable_context( + LoadPackageStateInjectableContext( + storage=self.load_storage.normalized_packages, + load_id=load_id, + ) + ): + # the same load id may be processed across multiple runs + if not self.current_load_id: + self._step_info_start_load_id(load_id) + self.load_single_package(load_id, schema) return TRunMetrics(False, len(self.load_storage.list_normalized_packages())) diff --git a/dlt/pipeline/current.py b/dlt/pipeline/current.py index 7fdc0f095c..25fd398623 100644 --- a/dlt/pipeline/current.py +++ b/dlt/pipeline/current.py @@ -2,6 +2,13 @@ from dlt.common.pipeline import source_state as _state, resource_state, get_current_pipe_name from dlt.pipeline import pipeline as _pipeline +from dlt.extract.decorators import get_source_schema +from dlt.common.storages.load_package import ( + load_package, + commit_load_package_state, + destination_state, + clear_destination_state, +) from dlt.extract.decorators import get_source_schema, get_source pipeline = _pipeline diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 185a11962a..042a62e8fb 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -47,7 +47,7 @@ ) from dlt.common.schema.utils import normalize_schema_name from dlt.common.storages.exceptions import LoadPackageNotFound -from dlt.common.typing import DictStrStr, TFun, TSecretValue, is_optional_type +from dlt.common.typing import DictStrAny, TFun, TSecretValue, is_optional_type from dlt.common.runners import pool_runner as runner from dlt.common.storages import ( LiveSchemaStorage, @@ -126,15 +126,17 @@ ) from dlt.pipeline.typing import TPipelineStep from dlt.pipeline.state_sync import ( - STATE_ENGINE_VERSION, - bump_version_if_modified, - load_state_from_destination, - migrate_state, + PIPELINE_STATE_ENGINE_VERSION, + bump_pipeline_state_version_if_modified, + load_pipeline_state_from_destination, + migrate_pipeline_state, state_resource, json_encode_state, json_decode_state, + default_pipeline_state, ) from dlt.pipeline.warnings import credentials_argument_deprecated +from dlt.common.storages.load_package import TLoadPackageState def with_state_sync(may_extract_state: bool = False) -> Callable[[TFun], TFun]: @@ -143,6 +145,7 @@ def decorator(f: TFun) -> TFun: def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: # activate pipeline so right state is always provided self.activate() + # backup and restore state should_extract_state = may_extract_state and self.config.restore_from_destination with self.managed_state(extract_state=should_extract_state) as state: @@ -263,7 +266,14 @@ class Pipeline(SupportsPipeline): STATE_FILE: ClassVar[str] = "state.json" STATE_PROPS: ClassVar[List[str]] = list( set(get_type_hints(TPipelineState).keys()) - - {"sources", "destination_type", "destination_name", "staging_type", "staging_name"} + - { + "sources", + "destination_type", + "destination_name", + "staging_type", + "staging_name", + "destinations", + } ) LOCAL_STATE_PROPS: ClassVar[List[str]] = list(get_type_hints(TPipelineLocalState).keys()) DEFAULT_DATASET_SUFFIX: ClassVar[str] = "_dataset" @@ -438,6 +448,7 @@ def normalize( """Normalizes the data prepared with `extract` method, infers the schema and creates load packages for the `load` method. 
Requires `destination` to be known.""" if is_interactive(): workers = 1 + if loader_file_format and loader_file_format in INTERNAL_LOADER_FILE_FORMATS: raise ValueError(f"{loader_file_format} is one of internal dlt file formats.") # check if any schema is present, if not then no data was extracted @@ -745,7 +756,7 @@ def sync_destination( # write the state back self._props_to_state(state) - bump_version_if_modified(state) + bump_pipeline_state_version_if_modified(state) self._save_state(state) except Exception as ex: raise PipelineStepFailed(self, "sync", None, ex, None) from ex @@ -845,6 +856,10 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: except LoadPackageNotFound: return self._get_normalize_storage().extracted_packages.get_load_package_info(load_id) + def get_load_package_state(self, load_id: str) -> TLoadPackageState: + """Returns information on extracted/normalized/completed package with given load_id, all jobs and their statuses.""" + return self._get_load_storage().get_load_package_state(load_id) + def list_failed_jobs_in_package(self, load_id: str) -> Sequence[LoadJobInfo]: """List all failed jobs and associated error messages for a specified `load_id`""" return self._get_load_storage().get_load_package_info(load_id).jobs.get("failed_jobs", []) @@ -1365,16 +1380,15 @@ def _get_step_info(self, step: WithStepInfo[TStepMetrics, TStepInfo]) -> TStepIn def _get_state(self) -> TPipelineState: try: state = json_decode_state(self._pipeline_storage.load(Pipeline.STATE_FILE)) - return migrate_state( - self.pipeline_name, state, state["_state_engine_version"], STATE_ENGINE_VERSION + return migrate_pipeline_state( + self.pipeline_name, + state, + state["_state_engine_version"], + PIPELINE_STATE_ENGINE_VERSION, ) except FileNotFoundError: # do not set the state hash, this will happen on first merge - return { - "_state_version": 0, - "_state_engine_version": STATE_ENGINE_VERSION, - "_local": {"first_run": True}, - } + return default_pipeline_state() # state["_version_hash"] = generate_version_hash(state) # return state @@ -1404,7 +1418,7 @@ def _restore_state_from_destination(self) -> Optional[TPipelineState]: schema = Schema(schema_name) with self._get_destination_clients(schema)[0] as job_client: if isinstance(job_client, WithStateSync): - state = load_state_from_destination(self.pipeline_name, job_client) + state = load_pipeline_state_from_destination(self.pipeline_name, job_client) if state is None: logger.info( "The state was not found in the destination" @@ -1538,7 +1552,7 @@ def _bump_version_and_extract_state( Storage will be created on demand. In that case the extracted package will be immediately committed. 
""" - _, hash_, _ = bump_version_if_modified(self._props_to_state(state)) + _, hash_, _ = bump_pipeline_state_version_if_modified(self._props_to_state(state)) should_extract = hash_ != state["_local"].get("_last_extracted_hash") if should_extract and extract_state: data = state_resource(state) diff --git a/dlt/pipeline/state_sync.py b/dlt/pipeline/state_sync.py index fa3939969b..8c72a218a4 100644 --- a/dlt/pipeline/state_sync.py +++ b/dlt/pipeline/state_sync.py @@ -1,25 +1,28 @@ -import base64 import binascii from copy import copy -import hashlib -from typing import Tuple, cast +from typing import Tuple, cast, List import pendulum import dlt from dlt.common import json -from dlt.common.pipeline import TPipelineState from dlt.common.typing import DictStrAny from dlt.common.schema.typing import STATE_TABLE_NAME, TTableSchemaColumns from dlt.common.destination.reference import WithStateSync, Destination from dlt.common.utils import compressed_b64decode, compressed_b64encode +from dlt.common.versioned_state import ( + generate_state_version_hash, + bump_state_version_if_modified, + default_versioned_state, +) +from dlt.common.pipeline import TPipelineState from dlt.extract import DltResource -from dlt.pipeline.exceptions import PipelineStateEngineNoUpgradePathException +from dlt.pipeline.exceptions import ( + PipelineStateEngineNoUpgradePathException, +) - -# allows to upgrade state when restored with a new version of state logic/schema -STATE_ENGINE_VERSION = 4 +PIPELINE_STATE_ENGINE_VERSION = 4 # state table columns STATE_TABLE_COLUMNS: TTableSchemaColumns = { @@ -57,59 +60,15 @@ def decompress_state(state_str: str) -> DictStrAny: return json.typed_loadb(state_bytes) # type: ignore[no-any-return] -def generate_version_hash(state: TPipelineState) -> str: - # generates hash out of stored schema content, excluding hash itself, version and local state - state_copy = copy(state) - state_copy.pop("_state_version", None) - state_copy.pop("_state_engine_version", None) - state_copy.pop("_version_hash", None) - state_copy.pop("_local", None) - content = json.typed_dumpb(state_copy, sort_keys=True) - h = hashlib.sha3_256(content) - return base64.b64encode(h.digest()).decode("ascii") - +def generate_pipeline_state_version_hash(state: TPipelineState) -> str: + return generate_state_version_hash(state, exclude_attrs=["_local"]) -def bump_version_if_modified(state: TPipelineState) -> Tuple[int, str, str]: - """Bumps the `state` version and version hash if content modified, returns (new version, new hash, old hash) tuple""" - hash_ = generate_version_hash(state) - previous_hash = state.get("_version_hash") - if not previous_hash: - # if hash was not set, set it without bumping the version, that's initial schema - pass - elif hash_ != previous_hash: - state["_state_version"] += 1 - state["_version_hash"] = hash_ - return state["_state_version"], hash_, previous_hash +def bump_pipeline_state_version_if_modified(state: TPipelineState) -> Tuple[int, str, str]: + return bump_state_version_if_modified(state, exclude_attrs=["_local"]) -def state_resource(state: TPipelineState) -> DltResource: - state = copy(state) - state.pop("_local") - state_str = compress_state(state) - state_doc = { - "version": state["_state_version"], - "engine_version": state["_state_engine_version"], - "pipeline_name": state["pipeline_name"], - "state": state_str, - "created_at": pendulum.now(), - "version_hash": state["_version_hash"], - } - return dlt.resource( - [state_doc], name=STATE_TABLE_NAME, write_disposition="append", 
columns=STATE_TABLE_COLUMNS - ) - - -def load_state_from_destination(pipeline_name: str, client: WithStateSync) -> TPipelineState: - # NOTE: if dataset or table holding state does not exist, the sql_client will rise DestinationUndefinedEntity. caller must handle this - state = client.get_stored_state(pipeline_name) - if not state: - return None - s = decompress_state(state.state) - return migrate_state(pipeline_name, s, s["_state_engine_version"], STATE_ENGINE_VERSION) - - -def migrate_state( +def migrate_pipeline_state( pipeline_name: str, state: DictStrAny, from_engine: int, to_engine: int ) -> TPipelineState: if from_engine == to_engine: @@ -119,7 +78,7 @@ def migrate_state( from_engine = 2 if from_engine == 2 and to_engine > 2: # you may want to recompute hash - state["_version_hash"] = generate_version_hash(state) # type: ignore[arg-type] + state["_version_hash"] = generate_pipeline_state_version_hash(state) # type: ignore[arg-type] from_engine = 3 if from_engine == 3 and to_engine > 3: if state.get("destination"): @@ -139,3 +98,41 @@ def migrate_state( ) state["_state_engine_version"] = from_engine return cast(TPipelineState, state) + + +def state_resource(state: TPipelineState) -> DltResource: + state = copy(state) + state.pop("_local") + state_str = compress_state(state) + state_doc = { + "version": state["_state_version"], + "engine_version": state["_state_engine_version"], + "pipeline_name": state["pipeline_name"], + "state": state_str, + "created_at": pendulum.now(), + "version_hash": state["_version_hash"], + } + return dlt.resource( + [state_doc], name=STATE_TABLE_NAME, write_disposition="append", columns=STATE_TABLE_COLUMNS + ) + + +def load_pipeline_state_from_destination( + pipeline_name: str, client: WithStateSync +) -> TPipelineState: + # NOTE: if dataset or table holding state does not exist, the sql_client will rise DestinationUndefinedEntity. caller must handle this + state = client.get_stored_state(pipeline_name) + if not state: + return None + s = decompress_state(state.state) + return migrate_pipeline_state( + pipeline_name, s, s["_state_engine_version"], PIPELINE_STATE_ENGINE_VERSION + ) + + +def default_pipeline_state() -> TPipelineState: + return { + **default_versioned_state(), + "_state_engine_version": PIPELINE_STATE_ENGINE_VERSION, + "_local": {"first_run": True}, + } diff --git a/docs/website/docs/dlt-ecosystem/destinations/sink.md b/docs/website/docs/dlt-ecosystem/destinations/sink.md new file mode 100644 index 0000000000..577d79ee24 --- /dev/null +++ b/docs/website/docs/dlt-ecosystem/destinations/sink.md @@ -0,0 +1,146 @@ +--- +title: Destination Decorator / Reverse ETL +description: Sink function `dlt` destination for reverse ETL +keywords: [reverse etl, sink, function, decorator, destination] +--- + +# Destination decorator / Reverse ETL + +The dlt destination decorator allows you to receive all data passing through your pipeline in a simple function. This can be extremely useful for +reverse ETL, where you are pushing data back to an api. You can also use this for sending data to a queue or a simple database destination that is not +yet supported by dlt, be aware that you will have to manually handle your own migrations in this case. It will also allow you to simply get a path +to the files of your normalized data, so if you need direct access to parquet or jsonl files to copy them somewhere or push them to a database, +you can do this here too. 
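+
+For example, a minimal sketch of a reverse ETL sink that pushes each batch of rows to a (hypothetical) REST endpoint could look like this; the URL and payload handling are placeholders:
+
+```python
+import requests
+
+import dlt
+from dlt.common.typing import TDataItems
+from dlt.common.schema import TTableSchema
+
+@dlt.destination(batch_size=10)
+def api_sink(items: TDataItems, table: TTableSchema) -> None:
+    # push one batch of normalized rows to a table-specific endpoint
+    response = requests.post(f"https://example.com/api/{table['name']}", json=items)
+    response.raise_for_status()
+```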
+
+## Install dlt for Sink / reverse ETL
+**To install dlt without additional dependencies:**
+```
+pip install dlt
+```
+
+## Setup Guide
+### 1. Initialize the dlt project
+
+Let's start by initializing a new dlt project as follows:
+
+```bash
+dlt init chess sink
+```
+> 💡 This command will initialize your pipeline with chess as the source and sink as the destination.
+
+The above command generates several files and directories, including `.dlt/secrets.toml`.
+
+### 2. Set up a destination function for your pipeline
+The sink destination differs from other destinations in that you do not need to provide connection credentials. Instead, you provide a function which
+gets called for all items loaded during a pipeline run or load operation. For the chess example, you can add the following lines at the top of the file.
+With the `@dlt.destination` decorator you can convert any function that accepts items and a table schema into such a destination.
+
+A very simple dlt pipeline that pushes a list of items into a sink function might look like this:
+
+```python
+from dlt.common.typing import TDataItems
+from dlt.common.schema import TTableSchema
+
+@dlt.destination(batch_size=10)
+def my_sink(items: TDataItems, table: TTableSchema) -> None:
+    print(table["name"])
+    print(items)
+
+pipe = dlt.pipeline("sink_pipeline", destination=my_sink)
+pipe.run([1, 2, 3], table_name="items")
+
+```
+
+To enable this destination decorator in your chess example, replace the line `destination='sink'` with `destination=sink` (without the quotes) to directly reference
+the sink from your pipeline constructor. Now you can run your pipeline and see the output of all the items coming from the chess pipeline in your console.
+
+:::tip
+1. You can also remove the typing information (TDataItems and TTableSchema) from this example; the typing is generally useful for knowing the shape of the incoming objects, though.
+2. There are a few other ways of declaring sink functions for your pipeline, described below.
+:::
+
+## destination decorator function and signature
+
+The full signature of the destination decorator plus its function is the following:
+
+```python
+@dlt.destination(batch_size=10, loader_file_format="jsonl", name="my_sink", naming_convention="direct")
+def sink(items: TDataItems, table: TTableSchema) -> None:
+    ...
+```
+
+#### Decorator
+* The `batch_size` parameter on the destination decorator defines how many items per function call are batched together and sent as an array. If you set a batch size of `0`,
+instead of passing in actual data items, you will receive one call per load job with the path of the file as the items argument (see the sketch below). You can then open and process that file
+in any way you like.
+* The `loader_file_format` parameter on the destination decorator defines the format in which files are stored in the load package before being sent to the sink function;
+this can be `jsonl` or `parquet`.
+* The `name` parameter on the destination decorator defines the name of the destination that gets created by the destination decorator.
+* The `naming_convention` parameter on the destination decorator defines the naming convention applied by this destination. This controls
+how table and column names are normalized. The default is `direct`, which keeps all names the same.
+
+#### Sink function
+* The `items` parameter on the sink function contains the items being sent into the sink function.
+* The `table` parameter contains the schema table the current call belongs to, including all table hints and columns. For example, the table name can be accessed with `table["name"]`. Keep in mind that dlt also creates special tables prefixed with `_dlt` which you may want to ignore when processing data.
+* You can also add config values and secrets to the function arguments, see below!
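+
+If you set `batch_size=0`, the `items` argument is the path to a job file inside the load package rather than a list of data items. A minimal sketch of this mode is shown below; the target folder and the copy logic are just placeholders for whatever you want to do with the raw file:
+
+```python
+import os
+import shutil
+
+import dlt
+from dlt.common.typing import TDataItems
+from dlt.common.schema import TTableSchema
+
+@dlt.destination(batch_size=0, loader_file_format="parquet")
+def copy_files(items: TDataItems, table: TTableSchema) -> None:
+    # with batch_size=0, `items` is the path of the load job file, not a list of rows
+    target = os.path.join("/tmp/chess_files", f"{table['name']}_{os.path.basename(items)}")
+    os.makedirs(os.path.dirname(target), exist_ok=True)
+    shutil.copy(items, target)
+```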
+
+## Adding config variables and secrets
+The destination decorator supports config and secrets variables. If, for example, you plan to connect to a service that requires an api secret or a login, you can do the following:
+
+```python
+@dlt.destination(batch_size=10, loader_file_format="jsonl", name="my_sink")
+def my_sink(items: TDataItems, table: TTableSchema, api_key: dlt.secrets.value) -> None:
+    ...
+```
+
+Then add the api key to your toml:
+
+```toml
+[destination.my_sink]
+api_key="some secrets"
+```
+
+## Sink destination state
+The sink destination keeps a local record of how many data items were processed. If, for example, you use the sink destination to push data items to a remote api and this
+api becomes unavailable during the load, resulting in a failed dlt pipeline run, you can repeat the run of your pipeline at a later stage and the sink destination will continue
+where it left off. For this reason it makes sense to choose a batch size that you can process in one transaction (say one api request or one database transaction), so that if this
+request or transaction fails repeatedly you can repeat it at the next run without pushing duplicate data to your remote location.
+
+## Concurrency
+By default, calls to the sink function are executed on multiple threads, so you need to make sure you are not using any non-thread-safe nonlocal or global variables from outside
+your sink function. If, for whatever reason, you need all calls to be executed from the same thread, you can set the `workers` config variable of the load step to 1. For performance
+reasons we recommend keeping the multithreaded approach and making sure that you, for example, are using thread-safe connection pools to a remote database or queue.
+
+## Referencing the sink function
+There are multiple ways to reference the sink function you want to use:
+
+```python
+# file my_pipeline.py
+import dlt
+
+from dlt.common.destination.reference import Destination
+from dlt.common.typing import TDataItems
+from dlt.common.schema import TTableSchema
+
+@dlt.destination(batch_size=10)
+def local_sink_func(items: TDataItems, table: TTableSchema) -> None:
+    ...
+
+# reference the function directly
+p = dlt.pipeline(name="my_pipe", destination=local_sink_func)
+
+# fully qualified string to the function location (can be used from config.toml or env vars)
+p = dlt.pipeline(name="my_pipe", destination="sink", credentials="my_pipeline.local_sink_func")
+
+# via destination reference
+p = dlt.pipeline(name="my_pipe", destination=Destination.from_reference("sink", credentials=local_sink_func, environment="staging"))
+```
+
+## Write disposition
+
+The sink destination forwards all normalized data items encountered during a pipeline run to the sink function, so there is no notion of write dispositions for the sink.
+
+## Staging support
+
+The sink destination does not currently support staging files in remote locations before the sink function is called. If you need this feature, please let us know.
+
diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/pipedrive.md b/docs/website/docs/dlt-ecosystem/verified-sources/pipedrive.md
index 9d1a5a0a02..0a68a725ec 100644
--- a/docs/website/docs/dlt-ecosystem/verified-sources/pipedrive.md
+++ b/docs/website/docs/dlt-ecosystem/verified-sources/pipedrive.md
@@ -265,7 +265,7 @@ verified source.
```python load_data = pipedrive_source() # calls the source function - load_info = pipeline.run(load_info) #runs the pipeline with selected source configuration + load_info = pipeline.run(load_data) #runs the pipeline with selected source configuration print(load_info) ``` diff --git a/docs/website/sidebars.js b/docs/website/sidebars.js index 821a1affad..3765ae3d18 100644 --- a/docs/website/sidebars.js +++ b/docs/website/sidebars.js @@ -98,6 +98,7 @@ const sidebars = { 'dlt-ecosystem/destinations/motherduck', 'dlt-ecosystem/destinations/weaviate', 'dlt-ecosystem/destinations/qdrant', + 'dlt-ecosystem/destinations/sink', ] }, ], diff --git a/tests/cases.py b/tests/cases.py index 8653f999c6..85caec4b8d 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Any, Sequence, Tuple, Literal +from typing import Dict, List, Any, Sequence, Tuple, Literal, Union import base64 from hexbytes import HexBytes from copy import deepcopy @@ -7,7 +7,7 @@ from dlt.common import Decimal, pendulum, json from dlt.common.data_types import TDataType -from dlt.common.typing import StrAny +from dlt.common.typing import StrAny, TDataItems from dlt.common.wei import Wei from dlt.common.time import ( ensure_pendulum_datetime, @@ -161,18 +161,23 @@ def table_update_and_row( def assert_all_data_types_row( - db_row: List[Any], + db_row: Union[List[Any], TDataItems], parse_complex_strings: bool = False, allow_base64_binary: bool = False, timestamp_precision: int = 6, schema: TTableSchemaColumns = None, + expect_filtered_null_columns=False, ) -> None: # content must equal # print(db_row) schema = schema or TABLE_UPDATE_COLUMNS_SCHEMA # Include only columns requested in schema - db_mapping = {col_name: db_row[i] for i, col_name in enumerate(schema)} + if isinstance(db_row, dict): + db_mapping = db_row.copy() + else: + db_mapping = {col_name: db_row[i] for i, col_name in enumerate(schema)} + expected_rows = {key: value for key, value in TABLE_ROW_ALL_DATA_TYPES.items() if key in schema} # prepare date to be compared: convert into pendulum instance, adjust microsecond precision if "col4" in expected_rows: @@ -226,8 +231,16 @@ def assert_all_data_types_row( if "col11" in db_mapping: db_mapping["col11"] = db_mapping["col11"].isoformat() - for expected, actual in zip(expected_rows.values(), db_mapping.values()): - assert expected == actual + if expect_filtered_null_columns: + for key, expected in expected_rows.items(): + if expected is None: + assert db_mapping.get(key, None) is None + db_mapping[key] = None + + for key, expected in expected_rows.items(): + actual = db_mapping[key] + assert expected == actual, f"Expected {expected} but got {actual} for column {key}" + assert db_mapping == expected_rows diff --git a/tests/common/schema/test_coercion.py b/tests/common/schema/test_coercion.py index 922024a89b..34b62f9564 100644 --- a/tests/common/schema/test_coercion.py +++ b/tests/common/schema/test_coercion.py @@ -377,10 +377,16 @@ def test_coerce_type_complex() -> None: assert coerce_value("complex", "complex", v_list) == v_list assert coerce_value("text", "complex", v_dict) == json.dumps(v_dict) assert coerce_value("text", "complex", v_list) == json.dumps(v_list) + assert coerce_value("complex", "text", json.dumps(v_dict)) == v_dict + assert coerce_value("complex", "text", json.dumps(v_list)) == v_list + # all other coercions fail with pytest.raises(ValueError): coerce_value("binary", "complex", v_list) + with pytest.raises(ValueError): + coerce_value("complex", "text", "not a json 
string") + def test_coerce_type_complex_with_pua() -> None: v_dict = { @@ -395,6 +401,10 @@ def test_coerce_type_complex_with_pua() -> None: } assert coerce_value("complex", "complex", copy(v_dict)) == exp_v assert coerce_value("text", "complex", copy(v_dict)) == json.dumps(exp_v) + + # TODO: what to test for this case if at all? + # assert coerce_value("complex", "text", json.dumps(v_dict)) == exp_v + # also decode recursively custom_pua_decode_nested(v_dict) # restores datetime type diff --git a/tests/common/storages/test_load_package.py b/tests/common/storages/test_load_package.py index f671ddcf32..8d47c22615 100644 --- a/tests/common/storages/test_load_package.py +++ b/tests/common/storages/test_load_package.py @@ -9,6 +9,15 @@ from tests.common.storages.utils import start_loading_file, assert_package_info, load_storage from tests.utils import autouse_test_storage +from dlt.common.pendulum import pendulum +from dlt.common.configuration.container import Container +from dlt.common.storages.load_package import ( + LoadPackageStateInjectableContext, + destination_state, + load_package, + commit_load_package_state, + clear_destination_state, +) def test_is_partially_loaded(load_storage: LoadStorage) -> None: @@ -57,6 +66,83 @@ def test_save_load_schema(load_storage: LoadStorage) -> None: assert schema.stored_version == schema_copy.stored_version +def test_create_and_update_loadpackage_state(load_storage: LoadStorage) -> None: + load_storage.new_packages.create_package("copy") + state = load_storage.new_packages.get_load_package_state("copy") + assert state["_state_version"] == 0 + assert state["_version_hash"] is not None + assert state["created_at"] is not None + old_state = state.copy() + + state["new_key"] = "new_value" # type: ignore + load_storage.new_packages.save_load_package_state("copy", state) + + state = load_storage.new_packages.get_load_package_state("copy") + assert state["new_key"] == "new_value" # type: ignore + assert state["_state_version"] == 1 + assert state["_version_hash"] != old_state["_version_hash"] + # created timestamp should be conserved + assert state["created_at"] == old_state["created_at"] + + # check timestamp + time = pendulum.parse(state["created_at"]) + now = pendulum.now() + assert (now - time).in_seconds() < 2 # type: ignore + + +def test_loadpackage_state_injectable_context(load_storage: LoadStorage) -> None: + load_storage.new_packages.create_package("copy") + + container = Container() + with container.injectable_context( + LoadPackageStateInjectableContext( + storage=load_storage.new_packages, + load_id="copy", + ) + ): + # test general load package state + injected_state = load_package() + assert injected_state["state"]["_state_version"] == 0 + injected_state["state"]["new_key"] = "new_value" # type: ignore + + # not persisted yet + assert load_storage.new_packages.get_load_package_state("copy").get("new_key") is None + # commit + commit_load_package_state() + + # now it should be persisted + assert ( + load_storage.new_packages.get_load_package_state("copy").get("new_key") == "new_value" + ) + assert load_storage.new_packages.get_load_package_state("copy").get("_state_version") == 1 + + # check that second injection is the same as first + second_injected_instance = load_package() + assert second_injected_instance == injected_state + + # check scoped destination states + assert ( + load_storage.new_packages.get_load_package_state("copy").get("destination_state") + is None + ) + dstate = destination_state() + dstate["new_key"] = "new_value" + 
commit_load_package_state() + assert load_storage.new_packages.get_load_package_state("copy").get( + "destination_state" + ) == {"new_key": "new_value"} + + # this also shows up on the previously injected state + assert injected_state["state"]["destination_state"]["new_key"] == "new_value" + + # clear destination state + clear_destination_state() + assert ( + load_storage.new_packages.get_load_package_state("copy").get("destination_state") + is None + ) + + def test_job_elapsed_time_seconds(load_storage: LoadStorage) -> None: load_id, fn = start_loading_file(load_storage, "test file") # type: ignore[arg-type] fp = load_storage.normalized_packages.storage.make_full_path( diff --git a/tests/common/test_versioned_state.py b/tests/common/test_versioned_state.py new file mode 100644 index 0000000000..e1f31a8a92 --- /dev/null +++ b/tests/common/test_versioned_state.py @@ -0,0 +1,43 @@ +from dlt.common.versioned_state import ( + generate_state_version_hash, + bump_state_version_if_modified, + default_versioned_state, +) + + +def test_versioned_state() -> None: + state = default_versioned_state() + assert state["_state_version"] == 0 + assert state["_state_engine_version"] == 1 + + # first hash_ generation does not change version, attrs are not modified + version, hash_, previous_hash = bump_state_version_if_modified(state) + assert version == 0 + assert hash_ is not None + assert previous_hash is None + assert state["_version_hash"] == hash_ + + # change attr, but exclude while generating + state["foo"] = "bar" # type: ignore + version, hash_, previous_hash = bump_state_version_if_modified(state, exclude_attrs=["foo"]) + assert version == 0 + assert hash_ == previous_hash + + # now don't exclude (remember old hash_ to compare return vars) + old_hash = state["_version_hash"] + version, hash_, previous_hash = bump_state_version_if_modified(state) + assert version == 1 + assert hash_ != previous_hash + assert old_hash != hash_ + assert previous_hash == old_hash + + # messing with state engine version will not change hash_ + state["_state_engine_version"] = 5 + version, hash_, previous_hash = bump_state_version_if_modified(state) + assert version == 1 + assert hash_ == previous_hash + + # make sure state object is not modified while bumping with no effect + old_state = state.copy() + version, hash_, previous_hash = bump_state_version_if_modified(state) + assert old_state == state diff --git a/tests/load/pipeline/test_drop.py b/tests/load/pipeline/test_drop.py index cd18454d7c..8614af4734 100644 --- a/tests/load/pipeline/test_drop.py +++ b/tests/load/pipeline/test_drop.py @@ -106,7 +106,9 @@ def assert_destination_state_loaded(pipeline: Pipeline) -> None: """Verify stored destination state matches the local pipeline state""" client: SqlJobClientBase with pipeline.destination_client() as client: # type: ignore[assignment] - destination_state = state_sync.load_state_from_destination(pipeline.pipeline_name, client) + destination_state = state_sync.load_pipeline_state_from_destination( + pipeline.pipeline_name, client + ) pipeline_state = dict(pipeline.state) del pipeline_state["_local"] assert pipeline_state == destination_state diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index 5ef2206031..02da91cefe 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -13,7 +13,11 @@ from dlt.pipeline.exceptions import SqlClientNotAvailable from dlt.pipeline.pipeline import Pipeline -from dlt.pipeline.state_sync 
import STATE_TABLE_COLUMNS, load_state_from_destination, state_resource +from dlt.pipeline.state_sync import ( + STATE_TABLE_COLUMNS, + load_pipeline_state_from_destination, + state_resource, +) from dlt.destinations.job_client_impl import SqlJobClientBase from tests.utils import TEST_STORAGE_ROOT @@ -54,14 +58,14 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - job_client: SqlJobClientBase with p.destination_client(p.default_schema.name) as job_client: # type: ignore[assignment] with pytest.raises(DestinationUndefinedEntity): - load_state_from_destination(p.pipeline_name, job_client) + load_pipeline_state_from_destination(p.pipeline_name, job_client) # sync the schema p.sync_schema() exists, _ = job_client.get_storage_table(schema.version_table_name) assert exists is True # dataset exists, still no table with pytest.raises(DestinationUndefinedEntity): - load_state_from_destination(p.pipeline_name, job_client) + load_pipeline_state_from_destination(p.pipeline_name, job_client) initial_state = p._get_state() # now add table to schema and sync initial_state["_local"]["_last_extracted_at"] = pendulum.now() @@ -84,14 +88,14 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - exists, _ = job_client.get_storage_table(schema.state_table_name) assert exists is True # table is there but no state - assert load_state_from_destination(p.pipeline_name, job_client) is None + assert load_pipeline_state_from_destination(p.pipeline_name, job_client) is None # extract state with p.managed_state(extract_state=True): pass # just run the existing extract p.normalize(loader_file_format=destination_config.file_format) p.load() - stored_state = load_state_from_destination(p.pipeline_name, job_client) + stored_state = load_pipeline_state_from_destination(p.pipeline_name, job_client) local_state = p._get_state() local_state.pop("_local") assert stored_state == local_state @@ -101,7 +105,7 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - managed_state["sources"] = {"source": dict(JSON_TYPED_DICT_DECODED)} p.normalize(loader_file_format=destination_config.file_format) p.load() - stored_state = load_state_from_destination(p.pipeline_name, job_client) + stored_state = load_pipeline_state_from_destination(p.pipeline_name, job_client) assert stored_state["sources"] == {"source": JSON_TYPED_DICT_DECODED} local_state = p._get_state() local_state.pop("_local") @@ -116,7 +120,7 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - p.normalize(loader_file_format=destination_config.file_format) info = p.load() assert len(info.loads_ids) == 0 - new_stored_state = load_state_from_destination(p.pipeline_name, job_client) + new_stored_state = load_pipeline_state_from_destination(p.pipeline_name, job_client) # new state should not be stored assert new_stored_state == stored_state @@ -147,7 +151,7 @@ def test_restore_state_utils(destination_config: DestinationTestConfiguration) - p.normalize(loader_file_format=destination_config.file_format) info = p.load() assert len(info.loads_ids) == 1 - new_stored_state_2 = load_state_from_destination(p.pipeline_name, job_client) + new_stored_state_2 = load_pipeline_state_from_destination(p.pipeline_name, job_client) # the stored state changed to next version assert new_stored_state != new_stored_state_2 assert new_stored_state["_state_version"] + 1 == new_stored_state_2["_state_version"] @@ -405,7 +409,7 @@ def complete_package_mock(self, load_id: str, schema: 
Schema, aborted: bool = Fa
     job_client: SqlJobClientBase
     with p._get_destination_clients(p.default_schema)[0] as job_client:  # type: ignore[assignment]
         # state without completed load id is not visible
-        state = load_state_from_destination(pipeline_name, job_client)
+        state = load_pipeline_state_from_destination(pipeline_name, job_client)
         assert state is None
 
diff --git a/tests/load/sink/__init__.py b/tests/load/sink/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py
new file mode 100644
index 0000000000..de8e648c6d
--- /dev/null
+++ b/tests/load/sink/test_sink.py
@@ -0,0 +1,390 @@
+from typing import List, Tuple, Dict
+
+import dlt
+import pytest
+import os
+
+from copy import deepcopy
+from dlt.common.typing import TDataItems
+from dlt.common.schema import TTableSchema
+from dlt.common.data_writers.writers import TLoaderFileFormat
+from dlt.common.destination.reference import Destination
+from dlt.pipeline.exceptions import PipelineStepFailed
+from dlt.common.utils import uniq_id
+from dlt.common.exceptions import InvalidDestinationReference
+from dlt.common.configuration.exceptions import ConfigFieldMissingException
+
+from tests.load.utils import (
+    TABLE_ROW_ALL_DATA_TYPES,
+    TABLE_UPDATE_COLUMNS_SCHEMA,
+    assert_all_data_types_row,
+)
+
+SUPPORTED_LOADER_FORMATS = ["parquet", "puae-jsonl"]
+
+
+def _run_through_sink(
+    items: TDataItems,
+    loader_file_format: TLoaderFileFormat,
+    columns=None,
+    filter_dlt_tables: bool = True,
+    batch_size: int = 10,
+) -> List[Tuple[TDataItems, TTableSchema]]:
+    """
+    runs a list of items through the sink destination and returns collected calls
+    """
+    calls: List[Tuple[TDataItems, TTableSchema]] = []
+
+    @dlt.destination(loader_file_format=loader_file_format, batch_size=batch_size)
+    def test_sink(items: TDataItems, table: TTableSchema) -> None:
+        nonlocal calls
+        if table["name"].startswith("_dlt") and filter_dlt_tables:
+            return
+        # convert pyarrow table to dict list here to make tests simpler downstream
+        if loader_file_format == "parquet":
+            items = items.to_pylist()  # type: ignore
+        calls.append((items, table))
+
+    @dlt.resource(columns=columns, table_name="items")
+    def items_resource() -> TDataItems:
+        nonlocal items
+        yield items
+
+    p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True)
+    p.run([items_resource()])
+
+    return calls
+
+
+@pytest.mark.parametrize("loader_file_format", SUPPORTED_LOADER_FORMATS)
+def test_all_datatypes(loader_file_format: TLoaderFileFormat) -> None:
+    data_types = deepcopy(TABLE_ROW_ALL_DATA_TYPES)
+    column_schemas = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA)
+
+    sink_calls = _run_through_sink(
+        [data_types, data_types, data_types],
+        loader_file_format,
+        columns=column_schemas,
+        batch_size=1,
+    )
+
+    # inspect result
+    assert len(sink_calls) == 3
+
+    item = sink_calls[0][0][0]
+
+    # filter out _dlt columns
+    item = {k: v for k, v in item.items() if not k.startswith("_dlt")}
+
+    # null values are not emitted
+    data_types = {k: v for k, v in data_types.items() if v is not None}
+
+    assert_all_data_types_row(item, expect_filtered_null_columns=True)
+
+
+@pytest.mark.parametrize("loader_file_format", SUPPORTED_LOADER_FORMATS)
+@pytest.mark.parametrize("batch_size", [1, 10, 23])
+def test_batch_size(loader_file_format: TLoaderFileFormat, batch_size: int) -> None:
+    items = [{"id": i, "value": str(i)} for i in range(100)]
+
+    sink_calls = _run_through_sink(items, loader_file_format,
batch_size=batch_size) + + if batch_size == 1: + assert len(sink_calls) == 100 + # one item per call + assert sink_calls[0][0][0].items() > {"id": 0, "value": "0"}.items() + elif batch_size == 10: + assert len(sink_calls) == 10 + # ten items in first call + assert len(sink_calls[0][0]) == 10 + assert sink_calls[0][0][0].items() > {"id": 0, "value": "0"}.items() + elif batch_size == 23: + assert len(sink_calls) == 5 + # 23 items in first call + assert len(sink_calls[0][0]) == 23 + assert sink_calls[0][0][0].items() > {"id": 0, "value": "0"}.items() + + # check all items are present + all_items = set() + for call in sink_calls: + item = call[0] + for entry in item: + all_items.add(entry["value"]) + + assert len(all_items) == 100 + for i in range(100): + assert str(i) in all_items + + +global_calls: List[Tuple[TDataItems, TTableSchema]] = [] + + +def global_sink_func(items: TDataItems, table: TTableSchema) -> None: + global global_calls + if table["name"].startswith("_dlt"): + return + global_calls.append((items, table)) + + +def test_instantiation() -> None: + calls: List[Tuple[TDataItems, TTableSchema]] = [] + + def local_sink_func(items: TDataItems, table: TTableSchema) -> None: + nonlocal calls + if table["name"].startswith("_dlt"): + return + calls.append((items, table)) + + # test decorator + calls = [] + p = dlt.pipeline("sink_test", destination=dlt.destination()(local_sink_func), full_refresh=True) + p.run([1, 2, 3], table_name="items") + assert len(calls) == 1 + + # test passing via from_reference + calls = [] + p = dlt.pipeline( + "sink_test", + destination=Destination.from_reference("destination", destination_callable=local_sink_func), # type: ignore + full_refresh=True, + ) + p.run([1, 2, 3], table_name="items") + assert len(calls) == 1 + + # test passing string reference + global global_calls + global_calls = [] + p = dlt.pipeline( + "sink_test", + destination=Destination.from_reference("destination", destination_callable="tests.load.sink.test_sink.global_sink_func"), # type: ignore + full_refresh=True, + ) + p.run([1, 2, 3], table_name="items") + assert len(global_calls) == 1 + + # pass None credentials reference + with pytest.raises(InvalidDestinationReference): + p = dlt.pipeline( + "sink_test", + destination=Destination.from_reference("destination", destination_callable=None), + full_refresh=True, + ) + p.run([1, 2, 3], table_name="items") + + # pass invalid credentials module + with pytest.raises(InvalidDestinationReference): + p = dlt.pipeline( + "sink_test", + destination=Destination.from_reference( + "destination", destination_callable="does.not.exist" + ), + full_refresh=True, + ) + p.run([1, 2, 3], table_name="items") + + +@pytest.mark.parametrize("loader_file_format", SUPPORTED_LOADER_FORMATS) +@pytest.mark.parametrize("batch_size", [1, 10, 23]) +def test_batched_transactions(loader_file_format: TLoaderFileFormat, batch_size: int) -> None: + calls: Dict[str, List[TDataItems]] = {} + # provoke errors on resources + provoke_error: Dict[str, int] = {} + + @dlt.destination(loader_file_format=loader_file_format, batch_size=batch_size) + def test_sink(items: TDataItems, table: TTableSchema) -> None: + nonlocal calls + table_name = table["name"] + if table_name.startswith("_dlt"): + return + + # convert pyarrow table to dict list here to make tests more simple downstream + if loader_file_format == "parquet": + items = items.to_pylist() # type: ignore + + # provoke error if configured + if table_name in provoke_error: + for item in items: + if provoke_error[table_name] == 
item["id"]:
+                    raise AssertionError("Oh no!")
+
+        calls.setdefault(table_name, []).append(items)
+
+    @dlt.resource()
+    def items() -> TDataItems:
+        for i in range(100):
+            yield {"id": i, "value": str(i)}
+
+    @dlt.resource()
+    def items2() -> TDataItems:
+        for i in range(100):
+            yield {"id": i, "value": str(i)}
+
+    def assert_items_in_range(c: List[TDataItems], start: int, end: int) -> None:
+        """
+        Ensure all items were received and no duplicates are present
+        """
+        collected_items = set()
+        for call in c:
+            for item in call:
+                assert item["value"] not in collected_items
+                collected_items.add(item["value"])
+        assert len(collected_items) == end - start
+        for i in range(start, end):
+            assert str(i) in collected_items
+
+    # no errors are set, all items should be processed
+    p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True)
+    load_id = p.run([items(), items2()]).loads_ids[0]
+    assert_items_in_range(calls["items"], 0, 100)
+    assert_items_in_range(calls["items2"], 0, 100)
+
+    # destination state should have all items
+    destination_state = p.get_load_package_state(load_id)["destination_state"]
+    values = {k.split(".")[0]: v for k, v in destination_state.items()}
+    assert values == {"_dlt_pipeline_state": 1, "items": 100, "items2": 100}
+
+    # provoke errors
+    calls = {}
+    provoke_error = {"items": 25, "items2": 45}
+    p = dlt.pipeline("sink_test", destination=test_sink, full_refresh=True)
+    with pytest.raises(PipelineStepFailed):
+        p.run([items(), items2()])
+
+    # we should have data for one load id saved here
+    load_id = p.list_normalized_load_packages()[0]
+    destination_state = p.get_load_package_state(load_id)["destination_state"]
+
+    # get saved indexes mapped to table (this test will only work for one job per table)
+    values = {k.split(".")[0]: v for k, v in destination_state.items()}
+
+    # partly loaded, pointers in state should be right
+    if batch_size == 1:
+        assert_items_in_range(calls["items"], 0, 25)
+        assert_items_in_range(calls["items2"], 0, 45)
+        # one pointer for state, one for items, one for items2...
+ assert values == {"_dlt_pipeline_state": 1, "items": 25, "items2": 45} + elif batch_size == 10: + assert_items_in_range(calls["items"], 0, 20) + assert_items_in_range(calls["items2"], 0, 40) + assert values == {"_dlt_pipeline_state": 1, "items": 20, "items2": 40} + elif batch_size == 23: + assert_items_in_range(calls["items"], 0, 23) + assert_items_in_range(calls["items2"], 0, 23) + assert values == {"_dlt_pipeline_state": 1, "items": 23, "items2": 23} + else: + raise AssertionError("Unknown batch size") + + # load the rest + first_calls = deepcopy(calls) + provoke_error = {} + calls = {} + p.load() + + # destination state should have all items + destination_state = p.get_load_package_state(load_id)["destination_state"] + values = {k.split(".")[0]: v for k, v in destination_state.items()} + assert values == {"_dlt_pipeline_state": 1, "items": 100, "items2": 100} + + # both calls combined should have every item called just once + assert_items_in_range(calls["items"] + first_calls["items"], 0, 100) + assert_items_in_range(calls["items2"] + first_calls["items2"], 0, 100) + + +def test_naming_convention() -> None: + @dlt.resource(table_name="PErson") + def resource(): + yield [{"UpperCase": 1, "snake_case": 1, "camelCase": 1}] + + # check snake case + @dlt.destination(naming_convention="snake_case") + def snake_sink(items, table): + if table["name"].startswith("_dlt"): + return + assert table["name"] == "p_erson" + assert table["columns"]["upper_case"]["name"] == "upper_case" + assert table["columns"]["snake_case"]["name"] == "snake_case" + assert table["columns"]["camel_case"]["name"] == "camel_case" + + dlt.pipeline("sink_test", destination=snake_sink, full_refresh=True).run(resource()) + + # check default (which is direct) + @dlt.destination() + def direct_sink(items, table): + if table["name"].startswith("_dlt"): + return + assert table["name"] == "PErson" + assert table["columns"]["UpperCase"]["name"] == "UpperCase" + assert table["columns"]["snake_case"]["name"] == "snake_case" + assert table["columns"]["camelCase"]["name"] == "camelCase" + + dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run(resource()) + + +def test_file_batch() -> None: + @dlt.resource(table_name="person") + def resource1(): + for i in range(100): + yield [{"id": i, "name": f"Name {i}"}] + + @dlt.resource(table_name="address") + def resource2(): + for i in range(50): + yield [{"id": i, "city": f"City {i}"}] + + @dlt.destination(batch_size=0, loader_file_format="parquet") + def direct_sink(file_path, table): + if table["name"].startswith("_dlt"): + return + from dlt.common.libs.pyarrow import pyarrow + + assert table["name"] in ["person", "address"] + + with pyarrow.parquet.ParquetFile(file_path) as reader: + assert reader.metadata.num_rows == (100 if table["name"] == "person" else 50) + + dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run( + [resource1(), resource2()] + ) + + +def test_config_spec() -> None: + @dlt.destination() + def my_sink(file_path, table, my_val=dlt.config.value): + assert my_val == "something" + + # if no value is present, it should raise + with pytest.raises(ConfigFieldMissingException) as exc: + dlt.pipeline("sink_test", destination=my_sink, full_refresh=True).run( + [1, 2, 3], table_name="items" + ) + + # right value will pass + os.environ["DESTINATION__MY_SINK__MY_VAL"] = "something" + dlt.pipeline("sink_test", destination=my_sink, full_refresh=True).run( + [1, 2, 3], table_name="items" + ) + + # wrong value will raise + 
os.environ["DESTINATION__MY_SINK__MY_VAL"] = "wrong" + with pytest.raises(PipelineStepFailed) as exc: + dlt.pipeline("sink_test", destination=my_sink, full_refresh=True).run( + [1, 2, 3], table_name="items" + ) + + # will respect given name + @dlt.destination(name="some_name") + def other_sink(file_path, table, my_val=dlt.config.value): + assert my_val == "something" + + # if no value is present, it should raise + with pytest.raises(ConfigFieldMissingException): + dlt.pipeline("sink_test", destination=other_sink, full_refresh=True).run( + [1, 2, 3], table_name="items" + ) + + # right value will pass + os.environ["DESTINATION__SOME_NAME__MY_VAL"] = "something" + dlt.pipeline("sink_test", destination=other_sink, full_refresh=True).run( + [1, 2, 3], table_name="items" + ) diff --git a/tests/load/weaviate/utils.py b/tests/load/weaviate/utils.py index ed378191e6..1b2a74fcb8 100644 --- a/tests/load/weaviate/utils.py +++ b/tests/load/weaviate/utils.py @@ -79,6 +79,8 @@ def delete_classes(p, class_list): def drop_active_pipeline_data() -> None: def schema_has_classes(client): + if not hasattr(client, "db_client"): + return None schema = client.db_client.schema.get() return schema["classes"] diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 0cebeb2ff7..272f57d966 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -1366,11 +1366,11 @@ def test_resource_state_name_not_normalized() -> None: pipeline.load() # get state from destination - from dlt.pipeline.state_sync import load_state_from_destination + from dlt.pipeline.state_sync import load_pipeline_state_from_destination client: WithStateSync with pipeline.destination_client() as client: # type: ignore[assignment] - state = load_state_from_destination(pipeline.pipeline_name, client) + state = load_pipeline_state_from_destination(pipeline.pipeline_name, client) assert "airtable_emojis" in state["sources"] assert state["sources"]["airtable_emojis"]["resources"] == {"🦚Peacock": {"🦚🦚🦚": "🦚"}} diff --git a/tests/pipeline/test_pipeline_state.py b/tests/pipeline/test_pipeline_state.py index ee788367e1..f0bcda2717 100644 --- a/tests/pipeline/test_pipeline_state.py +++ b/tests/pipeline/test_pipeline_state.py @@ -14,7 +14,11 @@ from dlt.pipeline.exceptions import PipelineStateEngineNoUpgradePathException, PipelineStepFailed from dlt.pipeline.pipeline import Pipeline -from dlt.pipeline.state_sync import generate_version_hash, migrate_state, STATE_ENGINE_VERSION +from dlt.pipeline.state_sync import ( + generate_pipeline_state_version_hash, + migrate_pipeline_state, + PIPELINE_STATE_ENGINE_VERSION, +) from tests.utils import test_storage from tests.pipeline.utils import json_case_path, load_json_case @@ -482,21 +486,21 @@ def transform(item): ) -def test_migrate_state(test_storage: FileStorage) -> None: +def test_migrate_pipeline_state(test_storage: FileStorage) -> None: # test generation of version hash on migration to v3 state_v1 = load_json_case("state/state.v1") - state = migrate_state("test_pipeline", state_v1, state_v1["_state_engine_version"], 3) + state = migrate_pipeline_state("test_pipeline", state_v1, state_v1["_state_engine_version"], 3) assert state["_state_engine_version"] == 3 assert "_local" in state assert "_version_hash" in state - assert state["_version_hash"] == generate_version_hash(state) + assert state["_version_hash"] == generate_pipeline_state_version_hash(state) # full migration state_v1 = load_json_case("state/state.v1") - state = migrate_state( - "test_pipeline", 
state_v1, state_v1["_state_engine_version"], STATE_ENGINE_VERSION + state = migrate_pipeline_state( + "test_pipeline", state_v1, state_v1["_state_engine_version"], PIPELINE_STATE_ENGINE_VERSION ) - assert state["_state_engine_version"] == STATE_ENGINE_VERSION + assert state["_state_engine_version"] == PIPELINE_STATE_ENGINE_VERSION # check destination migration assert state["destination_name"] == "postgres" @@ -505,12 +509,15 @@ def test_migrate_state(test_storage: FileStorage) -> None: with pytest.raises(PipelineStateEngineNoUpgradePathException) as py_ex: state_v1 = load_json_case("state/state.v1") - migrate_state( - "test_pipeline", state_v1, state_v1["_state_engine_version"], STATE_ENGINE_VERSION + 1 + migrate_pipeline_state( + "test_pipeline", + state_v1, + state_v1["_state_engine_version"], + PIPELINE_STATE_ENGINE_VERSION + 1, ) assert py_ex.value.init_engine == state_v1["_state_engine_version"] - assert py_ex.value.from_engine == STATE_ENGINE_VERSION - assert py_ex.value.to_engine == STATE_ENGINE_VERSION + 1 + assert py_ex.value.from_engine == PIPELINE_STATE_ENGINE_VERSION + assert py_ex.value.to_engine == PIPELINE_STATE_ENGINE_VERSION + 1 # also test pipeline init where state is old test_storage.create_folder("debug_pipeline") @@ -522,7 +529,7 @@ def test_migrate_state(test_storage: FileStorage) -> None: assert p.dataset_name == "debug_pipeline_data" assert p.default_schema_name == "example_source" state = p.state - assert state["_version_hash"] == generate_version_hash(state) + assert state["_version_hash"] == generate_pipeline_state_version_hash(state) # specifically check destination v3 to v4 migration state_v3 = { @@ -530,8 +537,8 @@ def test_migrate_state(test_storage: FileStorage) -> None: "staging": "dlt.destinations.filesystem", "_state_engine_version": 3, } - migrate_state( - "test_pipeline", state_v3, state_v3["_state_engine_version"], STATE_ENGINE_VERSION # type: ignore + migrate_pipeline_state( + "test_pipeline", state_v3, state_v3["_state_engine_version"], PIPELINE_STATE_ENGINE_VERSION # type: ignore ) assert state_v3["destination_name"] == "redshift" assert state_v3["destination_type"] == "dlt.destinations.redshift" @@ -544,8 +551,8 @@ def test_migrate_state(test_storage: FileStorage) -> None: "destination": "dlt.destinations.redshift", "_state_engine_version": 3, } - migrate_state( - "test_pipeline", state_v3, state_v3["_state_engine_version"], STATE_ENGINE_VERSION # type: ignore + migrate_pipeline_state( + "test_pipeline", state_v3, state_v3["_state_engine_version"], PIPELINE_STATE_ENGINE_VERSION # type: ignore ) assert state_v3["destination_name"] == "redshift" assert state_v3["destination_type"] == "dlt.destinations.redshift" @@ -554,8 +561,8 @@ def test_migrate_state(test_storage: FileStorage) -> None: assert "staging_type" not in state_v3 state_v3 = {"destination": None, "staging": None, "_state_engine_version": 3} - migrate_state( - "test_pipeline", state_v3, state_v3["_state_engine_version"], STATE_ENGINE_VERSION # type: ignore + migrate_pipeline_state( + "test_pipeline", state_v3, state_v3["_state_engine_version"], PIPELINE_STATE_ENGINE_VERSION # type: ignore ) assert "destination_name" not in state_v3 assert "destination_type" not in state_v3 @@ -563,8 +570,8 @@ def test_migrate_state(test_storage: FileStorage) -> None: assert "staging_type" not in state_v3 state_v3 = {"_state_engine_version": 2} - migrate_state( - "test_pipeline", state_v3, state_v3["_state_engine_version"], STATE_ENGINE_VERSION # type: ignore + migrate_pipeline_state( + "test_pipeline", 
state_v3, state_v3["_state_engine_version"], PIPELINE_STATE_ENGINE_VERSION # type: ignore ) assert "destination_name" not in state_v3 assert "destination_type" not in state_v3 diff --git a/tests/utils.py b/tests/utils.py index dd03279def..924f44de73 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,10 +45,11 @@ "motherduck", "mssql", "qdrant", + "destination", "synapse", "databricks", } -NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck", "qdrant"} +NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck", "qdrant", "destination"} SQL_DESTINATIONS = IMPLEMENTED_DESTINATIONS - NON_SQL_DESTINATIONS # exclude destination configs (for now used for athena and athena iceberg separation)