Use load_package_state instead of drop tables file
steinitzu committed Apr 6, 2024
1 parent dae5c37 commit 187d1b0
Showing 6 changed files with 35 additions and 55 deletions.
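
The gist of the commit: dropped-table bookkeeping moves out of a dedicated dropped_tables.json sidecar file and into the load package state that already travels with each package. A minimal before/after sketch using the storage helpers that appear in the hunks below; table names are illustrative and exact signatures in dlt may differ:

    # before: a separate JSON file per package
    storage.save_dropped_tables(load_id, ["users", "events"])
    dropped = storage.load_dropped_tables(load_id)

    # after: piggyback on the package state file (load_package_state.json)
    state = storage.get_load_package_state(load_id)
    state.setdefault("dropped_tables", []).extend(["users", "events"])
    storage.save_load_package_state(load_id, state)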
17 changes: 2 additions & 15 deletions dlt/common/storages/load_package.py
@@ -275,7 +275,6 @@ class PackageStorage:
LOAD_PACKAGE_STATE_FILE_NAME = ( # internal state of the load package, will not be synced to the destination
"load_package_state.json"
)
DROPPED_TABLES_FILE_NAME = "dropped_tables.json"

def __init__(self, storage: FileStorage, initial_state: TLoadPackageStatus) -> None:
"""Creates storage that manages load packages with root at `storage` and initial package state `initial_state`"""
@@ -410,6 +409,7 @@ def create_package(self, load_id: str) -> None:
self.storage.create_folder(os.path.join(load_id, PackageStorage.FAILED_JOBS_FOLDER))
self.storage.create_folder(os.path.join(load_id, PackageStorage.STARTED_JOBS_FOLDER))
# ensure created timestamp is set in state when load package is created

state = self.get_load_package_state(load_id)
if not state.get("created_at"):
state["created_at"] = pendulum.now().to_iso8601_string()
@@ -447,6 +447,7 @@ def load_schema(self, load_id: str) -> Schema:

def schema_name(self, load_id: str) -> str:
"""Gets schema name associated with the package"""

schema_dict: TStoredSchema = self._load_schema(load_id) # type: ignore[assignment]
return schema_dict["name"]

@@ -461,20 +462,6 @@ def save_schema_updates(self, load_id: str, schema_update: TSchemaTables) -> Non
) as f:
json.dump(schema_update, f)

def save_dropped_tables(self, load_id: str, dropped_tables: Sequence[str]) -> None:
with self.storage.open_file(
os.path.join(load_id, PackageStorage.DROPPED_TABLES_FILE_NAME), mode="wb"
) as f:
json.dump(dropped_tables, f)

def load_dropped_tables(self, load_id: str) -> List[str]:
try:
return json.loads( # type: ignore[no-any-return]
self.storage.load(os.path.join(load_id, PackageStorage.DROPPED_TABLES_FILE_NAME))
)
except FileNotFoundError:
return []

#
# Loadpackage state
#
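
With dropped_tables.json gone, load_package_state.json is the remaining per-package sidecar that PackageStorage manages (marked as internal and not synced to the destination). A rough sketch of reading and writing it through the storage methods referenced elsewhere in this diff; the return shape and defaults are assumptions beyond what the hunks show:

    state = package_storage.get_load_package_state(load_id)    # carries "created_at" once the package exists
    state["dropped_tables"] = ["users", "events"]               # illustrative; any JSON-serializable payload
    package_storage.save_load_package_state(load_id, state)     # writes <load_id>/load_package_state.json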
21 changes: 17 additions & 4 deletions dlt/extract/extract.py
@@ -30,7 +30,11 @@
TWriteDisposition,
)
from dlt.common.storages import NormalizeStorageConfiguration, LoadPackageInfo, SchemaStorage
from dlt.common.storages.load_package import ParsedLoadJobFileName
from dlt.common.storages.load_package import (
ParsedLoadJobFileName,
LoadPackageStateInjectableContext,
commit_load_package_state,
)
from dlt.common.utils import get_callable_name, get_full_class_name

from dlt.extract.decorators import SourceInjectableContext, SourceSchemaInjectableContext
@@ -362,7 +366,13 @@ def extract(
load_id = self.extract_storage.create_load_package(source.discover_schema())
with Container().injectable_context(
SourceSchemaInjectableContext(source.schema)
), Container().injectable_context(SourceInjectableContext(source)):
), Container().injectable_context(
SourceInjectableContext(source)
), Container().injectable_context(
LoadPackageStateInjectableContext(
load_id=load_id, storage=self.extract_storage.new_packages
)
) as load_package:
# inject the config section with the current source name
with inject_section(
ConfigSectionContext(
@@ -372,7 +382,7 @@
):
if self.refresh is not None:
_resources_to_drop = (
list(source.resources.extracted) if self.refresh != "drop_dataset" else None
list(source.resources.extracted) if self.refresh != "drop_dataset" else []
)
_state, _ = pipeline_state(Container())
new_schema, new_state, drop_info = drop_resources(
@@ -381,11 +391,13 @@
resources=_resources_to_drop,
drop_all=self.refresh == "drop_dataset",
state_only=self.refresh == "drop_data",
state_paths="*" if self.refresh == "drop_data" else [],
state_paths="*" if self.refresh == "drop_dataset" else [],
)
_state.update(new_state)
source.schema.tables.clear()
source.schema.tables.update(new_schema.tables)
dropped_tables = load_package.state.setdefault("dropped_tables", [])
dropped_tables.extend(drop_info["tables"])

# reset resource states, the `extracted` list contains all the explicit resources and all their parents
for resource in source.resources.extracted.values():
@@ -399,6 +411,7 @@
max_parallel_items=max_parallel_items,
workers=workers,
)
commit_load_package_state()
return load_id

def commit_packages(self) -> None:
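
The extract step now opens a LoadPackageStateInjectableContext next to the schema and source contexts, records tables scheduled for dropping directly in the package state, and flushes it with commit_load_package_state() before returning. A condensed sketch of that flow, with names taken from the hunks above and the refresh handling heavily simplified:

    with Container().injectable_context(
        LoadPackageStateInjectableContext(load_id=load_id, storage=self.extract_storage.new_packages)
    ) as load_package:
        if self.refresh is not None:
            # drop_resources(...) returns the updated schema/state plus drop_info
            dropped = load_package.state.setdefault("dropped_tables", [])
            dropped.extend(drop_info["tables"])
        # ... extract resources into the package ...
        commit_load_package_state()  # persist the state before the package is handed to normalize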
9 changes: 6 additions & 3 deletions dlt/load/load.py
@@ -14,7 +14,10 @@
from dlt.common.schema.utils import get_top_level_table
from dlt.common.schema.typing import TTableSchema
from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState
from dlt.common.storages.load_package import LoadPackageStateInjectableContext
from dlt.common.storages.load_package import (
LoadPackageStateInjectableContext,
load_package as current_load_package,
)
from dlt.common.runners import TRunMetrics, Runnable, workermethod, NullExecutor
from dlt.common.runtime.collector import Collector, NULL_COLLECTOR
from dlt.common.logger import pretty_format_exception
@@ -355,7 +358,7 @@ def _refresh(self, dropped_tables: Sequence[str], schema: Schema) -> Tuple[Set[s
for table_name in drop_tables:
# pop not del: The table may not actually be in the schema if it's not being loaded again
dropped_schema.tables.pop(table_name, None)
dropped_schema.bump_version()
dropped_schema._bump_version()
trunc_dest: Set[str] = set()
trunc_staging: Set[str] = set()
# Drop from destination and replace stored schema so tables will be re-created before load
@@ -379,7 +382,7 @@ def load_single_package(self, load_id: str, schema: Schema) -> None:
def load_single_package(self, load_id: str, schema: Schema) -> None:
new_jobs = self.get_new_jobs_info(load_id)

dropped_tables = self.load_storage.normalized_packages.load_dropped_tables(load_id)
dropped_tables = current_load_package()["state"].get("dropped_tables", [])
# Drop tables before loading if refresh mode is set
truncate_dest, truncate_staging = self._refresh(dropped_tables, schema)

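
On the load side, load_single_package no longer reads dropped_tables.json from normalized package storage; it asks the injected load package for its state. Roughly, per the hunk above (the rest of the method body is elided):

    from dlt.common.storages.load_package import load_package as current_load_package

    dropped_tables = current_load_package()["state"].get("dropped_tables", [])
    truncate_dest, truncate_staging = self._refresh(dropped_tables, schema)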
7 changes: 5 additions & 2 deletions dlt/normalize/normalize.py
@@ -325,8 +325,11 @@ def spool_files(
self.load_storage.new_packages.save_schema_updates(
load_id, merge_schema_updates(schema_updates)
)
self.load_storage.new_packages.save_dropped_tables(
load_id, self.normalize_storage.extracted_packages.load_dropped_tables(load_id)
# self.load_storage.new_packages.save_dropped_tables(
# load_id, self.normalize_storage.extracted_packages.load_dropped_tables(load_id)
# )
self.load_storage.new_packages.save_load_package_state(
load_id, self.normalize_storage.extracted_packages.get_load_package_state(load_id)
)
# files must be renamed and deleted together so do not attempt that when process is about to be terminated
signals.raise_if_signalled()
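
Because the dropped-table list now rides in the package state, normalize has to copy that state forward when it turns an extracted package into a normalized one; otherwise the load step would see an empty state. A minimal sketch mirroring the hunk above:

    extracted_state = self.normalize_storage.extracted_packages.get_load_package_state(load_id)
    self.load_storage.new_packages.save_load_package_state(load_id, extracted_state)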
31 changes: 1 addition & 30 deletions dlt/pipeline/pipeline.py
@@ -430,36 +430,7 @@ def extract(
):
if source.exhausted:
raise SourceExhausted(source.name)
# dropped_tables = []
# if not self.first_run:
# if self.refresh == "full":
# # Drop all tables from all resources and all source state paths
# d = DropCommand(
# self,
# drop_all=True,
# extract_only=True, # Only modify local state/schema, destination drop tables is done in load step
# state_paths="*",
# schema_name=source.schema.name,
# )
# dropped_tables = d.info["tables"]
# d()
# elif self.refresh == "replace":
# # Drop tables from resources being currently extracted
# d = DropCommand(
# self,
# resources=source.resources.extracted,
# extract_only=True,
# schema_name=source.schema.name,
# )
# dropped_tables = d.info["tables"]
# d()
load_id = self._extract_source(
extract_step, source, max_parallel_items, workers
)
# Save the tables that were dropped locally (to be dropped from destination in load step)
# extract_step.extract_storage.new_packages.save_dropped_tables(
# load_id, dropped_tables
# )
self._extract_source(extract_step, source, max_parallel_items, workers)
# extract state
if self.config.restore_from_destination:
# this will update state version hash so it will not be extracted again by with_state_sync
5 changes: 4 additions & 1 deletion tests/pipeline/test_refresh_modes.py
@@ -50,7 +50,10 @@ def some_data_3():

# First run pipeline with load to destination so tables are created
pipeline = dlt.pipeline(
"refresh_full_test", destination="duckdb", refresh="full", dataset_name="refresh_full_test"
"refresh_full_test",
destination="duckdb",
refresh="drop_dataset",
dataset_name="refresh_full_test",
)

info = pipeline.run(my_source())
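
The test now passes the renamed refresh mode "drop_dataset" instead of the old "full" value (the same string that gates drop_resources in the extract.py hunks, alongside "drop_data"). A hedged usage sketch; the exact set of supported modes may differ from what this diff shows:

    import dlt

    pipeline = dlt.pipeline(
        "refresh_full_test",
        destination="duckdb",
        refresh="drop_dataset",   # "drop_data" also appears in extract.py as a state-only variant
        dataset_name="refresh_full_test",
    )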
