From a52b45f57b78929bd7d2e1df00b317c5875715cd Mon Sep 17 00:00:00 2001
From: Steinthor Palsson
Date: Sat, 6 Apr 2024 17:11:10 +0530
Subject: [PATCH] Use drop schema in init_client (TODO: error)

---
 dlt/extract/extract.py   |  16 ++++-
 dlt/load/load.py         | 122 ++++++++++++++++++++-------------
 dlt/load/utils.py        |  26 ++++++++-
 dlt/pipeline/drop.py     |  16 +++++
 dlt/pipeline/pipeline.py |   2 +
 5 files changed, 120 insertions(+), 62 deletions(-)

diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py
index 18b2d4c32a..24eef3db2f 100644
--- a/dlt/extract/extract.py
+++ b/dlt/extract/extract.py
@@ -394,10 +394,22 @@ def extract(
                 state_paths="*" if self.refresh == "drop_dataset" else [],
             )
             _state.update(new_state)
+            drop_schema = source.schema.clone()
+            if drop_info["tables"]:
+                drop_tables = {
+                    key: table
+                    for key, table in source.schema.tables.items()
+                    if table["name"] in drop_info["tables"]
+                    or table["name"] in drop_schema.dlt_table_names()
+                }
+
+                drop_schema.tables.clear()
+                drop_schema.tables.update(drop_tables)
+            load_package.state["drop_schema"] = drop_schema.to_dict()
             source.schema.tables.clear()
             source.schema.tables.update(new_schema.tables)
-            dropped_tables = load_package.state.setdefault("dropped_tables", [])
-            dropped_tables.extend(drop_info["tables"])
+            # dropped_tables = load_package.state.setdefault("dropped_tables", [])
+            # dropped_tables.extend(drop_info["tables"])

         # reset resource states, the `extracted` list contains all the explicit resources and all their parents
         for resource in source.resources.extracted.values():
diff --git a/dlt/load/load.py b/dlt/load/load.py
index 156df57712..1f4efcc912 100644
--- a/dlt/load/load.py
+++ b/dlt/load/load.py
@@ -68,6 +68,7 @@ def __init__(
         config: LoaderConfiguration = config.value,
         initial_client_config: DestinationClientConfiguration = config.value,
         initial_staging_client_config: DestinationClientConfiguration = config.value,
+        refresh: Optional[TRefreshMode] = None,
     ) -> None:
         self.config = config
         self.collector = collector
@@ -79,7 +80,7 @@ def __init__(
         self.pool = NullExecutor()
         self.load_storage: LoadStorage = self.create_storage(is_storage_owner)
         self._loaded_packages: List[LoadPackageInfo] = []
-        self._refreshed_tables: Set[str] = set()
+        self.refresh = refresh
         super().__init__()

     def create_storage(self, is_storage_owner: bool) -> LoadStorage:
@@ -344,71 +345,75 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False)
             f"All jobs completed, archiving package {load_id} with aborted set to {aborted}"
         )

-    def _refresh(self, dropped_tables: Sequence[str], schema: Schema) -> Tuple[Set[str], Set[str]]:
-        """When using refresh mode, drop tables if possible.
-        Returns a set of tables for main destination and staging destination
-        that could not be dropped and should be truncated instead
-        """
-        # Exclude tables already dropped in the same load
-        drop_tables = set(dropped_tables) - self._refreshed_tables
-        if not drop_tables:
-            return set(), set()
-        # Clone schema and remove tables from it
-        dropped_schema = deepcopy(schema)
-        for table_name in drop_tables:
-            # pop not del: The table may not actually be in the schema if it's not being loaded again
-            dropped_schema.tables.pop(table_name, None)
-        dropped_schema._bump_version()
-        trunc_dest: Set[str] = set()
-        trunc_staging: Set[str] = set()
-        # Drop from destination and replace stored schema so tables will be re-created before load
-        with self.get_destination_client(dropped_schema) as job_client:
-            # TODO: SupportsSql mixin
-            if hasattr(job_client, "drop_tables"):
-                job_client.drop_tables(*drop_tables, replace_schema=True)
-            else:
-                # Tables need to be truncated instead of dropped
-                trunc_dest = drop_tables
-
-        if self.staging_destination:
-            with self.get_staging_destination_client(dropped_schema) as staging_client:
-                if hasattr(staging_client, "drop_tables"):
-                    staging_client.drop_tables(*drop_tables, replace_schema=True)
-                else:
-                    trunc_staging = drop_tables
-        self._refreshed_tables.update(drop_tables)  # Don't drop table again in same load
-        return trunc_dest, trunc_staging
+    # def _refresh(self, dropped_tables: Sequence[str], schema: Schema) -> Tuple[Set[str], Set[str]]:
+    #     """When using refresh mode, drop tables if possible.
+    #     Returns a set of tables for main destination and staging destination
+    #     that could not be dropped and should be truncated instead
+    #     """
+    #     # Exclude tables already dropped in the same load
+    #     drop_tables = set(dropped_tables) - self._refreshed_tables
+    #     if not drop_tables:
+    #         return set(), set()
+    #     # Clone schema and remove tables from it
+    #     dropped_schema = deepcopy(schema)
+    #     for table_name in drop_tables:
+    #         # pop not del: The table may not actually be in the schema if it's not being loaded again
+    #         dropped_schema.tables.pop(table_name, None)
+    #     dropped_schema._bump_version()
+    #     trunc_dest: Set[str] = set()
+    #     trunc_staging: Set[str] = set()
+    #     # Drop from destination and replace stored schema so tables will be re-created before load
+    #     with self.get_destination_client(dropped_schema) as job_client:
+    #         # TODO: SupportsSql mixin
+    #         if hasattr(job_client, "drop_tables"):
+    #             job_client.drop_tables(*drop_tables, replace_schema=True)
+    #         else:
+    #             # Tables need to be truncated instead of dropped
+    #             trunc_dest = drop_tables
+
+    #     if self.staging_destination:
+    #         with self.get_staging_destination_client(dropped_schema) as staging_client:
+    #             if hasattr(staging_client, "drop_tables"):
+    #                 staging_client.drop_tables(*drop_tables, replace_schema=True)
+    #             else:
+    #                 trunc_staging = drop_tables
+    #     self._refreshed_tables.update(drop_tables)  # Don't drop table again in same load
+    #     return trunc_dest, trunc_staging

     def load_single_package(self, load_id: str, schema: Schema) -> None:
         new_jobs = self.get_new_jobs_info(load_id)
-        dropped_tables = current_load_package()["state"].get("dropped_tables", [])
+        # dropped_tables = current_load_package()["state"].get("dropped_tables", [])
         # Drop tables before loading if refresh mode is set
-        truncate_dest, truncate_staging = self._refresh(dropped_tables, schema)
+        # truncate_dest, truncate_staging = self._refresh(dropped_tables, schema)
+        drop_schema_dict = current_load_package()["state"].get("drop_schema")
+        drop_schema = Schema.from_dict(drop_schema_dict) if drop_schema_dict else None
+        init_schema = drop_schema if drop_schema else schema
         # initialize analytical storage ie. create dataset required by passed schema
-        with self.get_destination_client(schema) as job_client:
+        with self.get_destination_client(init_schema) as job_client:
             if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None:
                 # init job client
-                def should_truncate(table: TTableSchema) -> bool:
-                    # When destination doesn't support dropping refreshed tables (i.e. not SQL based) they should be truncated
-                    return (
-                        job_client.should_truncate_table_before_load(table)
-                        or table["name"] in truncate_dest
-                    )
+                # def should_truncate(table: TTableSchema) -> bool:
+                #     # When destination doesn't support dropping refreshed tables (i.e. not SQL based) they should be truncated
+                #     return (
+                #         job_client.should_truncate_table_before_load(table)
+                #         or table["name"] in truncate_dest
+                #     )
                 applied_update = init_client(
                     job_client,
-                    schema,
+                    init_schema,
                     new_jobs,
                     expected_update,
-                    # job_client.should_truncate_table_before_load,
-                    should_truncate,
+                    job_client.should_truncate_table_before_load,
+                    # should_truncate,
                     (
                         job_client.should_load_data_to_staging_dataset
                         if isinstance(job_client, WithStagingDataset)
                         else None
                     ),
+                    refresh=self.refresh,
                 )

                 # init staging client
@@ -418,23 +423,24 @@ def should_truncate(table: TTableSchema) -> bool:
                         " implement SupportsStagingDestination"
                     )

-                    def should_truncate_staging(table: TTableSchema) -> bool:
-                        return (
-                            job_client.should_truncate_table_before_load_on_staging_destination(
-                                table
-                            )
-                            or table["name"] in truncate_staging
-                        )
+                    # def should_truncate_staging(table: TTableSchema) -> bool:
+                    #     return (
+                    #         job_client.should_truncate_table_before_load_on_staging_destination(
+                    #             table
+                    #         )
+                    #         or table["name"] in truncate_staging
+                    #     )

-                    with self.get_staging_destination_client(schema) as staging_client:
+                    with self.get_staging_destination_client(init_schema) as staging_client:
                         init_client(
                             staging_client,
-                            schema,
+                            init_schema,
                             new_jobs,
                             expected_update,
-                            # job_client.should_truncate_table_before_load_on_staging_destination,
-                            should_truncate_staging,
+                            job_client.should_truncate_table_before_load_on_staging_destination,
+                            # should_truncate_staging,
                             job_client.should_load_data_to_staging_dataset_on_staging_destination,
+                            refresh=self.refresh,
                         )

                 self.load_storage.commit_schema_update(load_id, applied_update)
diff --git a/dlt/load/utils.py b/dlt/load/utils.py
index 067ae33613..fe3629b432 100644
--- a/dlt/load/utils.py
+++ b/dlt/load/utils.py
@@ -1,4 +1,4 @@
-from typing import List, Set, Iterable, Callable
+from typing import List, Set, Iterable, Callable, Optional

 from dlt.common import logger
 from dlt.common.storages.load_package import LoadJobInfo, PackageStorage
@@ -15,6 +15,7 @@
     JobClientBase,
     WithStagingDataset,
 )
+from dlt.common.pipeline import TRefreshMode


 def get_completed_table_chain(
@@ -66,6 +67,7 @@ def init_client(
     expected_update: TSchemaTables,
     truncate_filter: Callable[[TTableSchema], bool],
     load_staging_filter: Callable[[TTableSchema], bool],
+    refresh: Optional[TRefreshMode] = None,
 ) -> TSchemaTables:
     """Initializes destination storage including staging dataset if supported
@@ -84,6 +86,8 @@
     """
     # get dlt/internal tables
     dlt_tables = set(schema.dlt_table_names())
+
+    all_tables = set(schema.tables.keys())
     # tables without data (TODO: normalizer removes such jobs, write tests and remove the line below)
     tables_no_data = set(
table["name"] for table in schema.data_tables() if not has_table_seen_data(table) @@ -92,12 +96,23 @@ def init_client( tables_with_jobs = set(job.table_name for job in new_jobs) - tables_no_data # get tables to truncate by extending tables with jobs with all their child tables + if refresh == "drop_data": + truncate_filter = lambda t: True truncate_tables = set( _extend_tables_with_table_chain(schema, tables_with_jobs, tables_with_jobs, truncate_filter) ) + if refresh in ("drop_dataset", "drop_tables"): + drop_tables = all_tables - dlt_tables - tables_no_data + else: + drop_tables = set() + applied_update = _init_dataset_and_update_schema( - job_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables + job_client, + expected_update, + tables_with_jobs | dlt_tables, + truncate_tables, + drop_tables=drop_tables, ) # update the staging dataset if client supports this @@ -128,6 +143,7 @@ def _init_dataset_and_update_schema( update_tables: Iterable[str], truncate_tables: Iterable[str] = None, staging_info: bool = False, + drop_tables: Optional[Iterable[str]] = None, ) -> TSchemaTables: staging_text = "for staging dataset" if staging_info else "" logger.info( @@ -146,6 +162,12 @@ def _init_dataset_and_update_schema( f"Client for {job_client.config.destination_type} will truncate tables {staging_text}" ) job_client.initialize_storage(truncate_tables=truncate_tables) + if drop_tables: + if hasattr(job_client, "drop_tables"): + logger.info( + f"Client for {job_client.config.destination_type} will drop tables {staging_text}" + ) + job_client.drop_tables(*drop_tables) return applied_update diff --git a/dlt/pipeline/drop.py b/dlt/pipeline/drop.py index 2b6da13a6a..7a64bb14b9 100644 --- a/dlt/pipeline/drop.py +++ b/dlt/pipeline/drop.py @@ -69,6 +69,22 @@ def drop_resources( drop_all: bool = False, state_only: bool = False, ) -> Tuple[Schema, TPipelineState, _DropInfo]: + """Generate a new schema and pipeline state with the requested resources removed. + + Args: + schema: The schema to modify. + state: The pipeline state to modify. + resources: Resource name(s) or regex pattern(s) matching resource names to drop. + If empty, no resources will be dropped unless `drop_all` is True. + state_paths: JSON path(s) relative to the source state to drop. + drop_all: If True, all resources will be dropped (supeseeds `resources`). + state_only: If True, only modify the pipeline state, not schema + + Returns: + A 3 part tuple containing the new schema, the new pipeline state, and a dictionary + containing information about the drop operation. + """ + if isinstance(resources, str): resources = [resources] resources = list(resources) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index e7a1cd1992..413c984f44 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -540,6 +540,7 @@ def load( config=load_config, initial_client_config=client.config, initial_staging_client_config=staging_client.config if staging_client else None, + refresh=self.refresh if not self.first_run else None, ) try: with signals.delayed_signals(): @@ -548,6 +549,7 @@ def load( self.first_run = False return info except Exception as l_ex: + raise step_info = self._get_step_info(load_step) raise PipelineStepFailed( self, "load", load_step.current_load_id, l_ex, step_info