Revert back to merge implementation #1055
Signed-off-by: Marcel Coetzee <[email protected]>
Pipboyguy committed Apr 3, 2024
1 parent 34646ad commit 8bc3c21
Showing 2 changed files with 238 additions and 13 deletions.
180 changes: 167 additions & 13 deletions dlt/destinations/impl/clickhouse/clickhouse.py
@@ -27,7 +27,13 @@
TTableSchemaColumns,
TColumnSchemaBase,
)
from dlt.common.schema.utils import (
get_columns_names_with_prop,
get_first_column_name_with_prop,
get_dedup_sort_tuple,
)
from dlt.common.storages import FileStorage
from dlt.destinations.exceptions import MergeDispositionException
from dlt.destinations.impl.clickhouse import capabilities
from dlt.destinations.impl.clickhouse.clickhouse_adapter import (
TTableEngineType,
@@ -48,7 +54,7 @@
SqlJobClientBase,
)
from dlt.destinations.job_impl import NewReferenceJob, EmptyLoadJob
from dlt.destinations.sql_jobs import SqlStagingCopyJob
from dlt.destinations.sql_jobs import SqlMergeJob
from dlt.destinations.type_mapping import TypeMapper


@@ -200,6 +206,165 @@ def exception(self) -> str:
raise NotImplementedError()


class ClickhouseMergeJob(SqlMergeJob):
@classmethod
def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str:
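        """Materialize *select_sql* into a named table that survives across load sessions.

        Illustrative shape of the emitted statement (the table name is generated by the
        base job and is hypothetical here):
        CREATE TABLE delete_items_tmp ENGINE = Memory AS SELECT ...;
        """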
        # The load process opens multiple sessions, and ClickHouse temporary tables do not
        # persist across sessions, so materialize the result in a plain MEMORY-engine
        # table instead.
return f"CREATE TABLE {temp_table_name} ENGINE = Memory AS {select_sql};"

@classmethod
def gen_merge_sql(
cls, table_chain: Sequence[TTableSchema], sql_client: ClickhouseSqlClient # type: ignore[override]
) -> List[str]:
sql: List[str] = []
root_table = table_chain[0]

escape_id = sql_client.capabilities.escape_identifier
escape_lit = sql_client.capabilities.escape_literal
if escape_id is None:
escape_id = DestinationCapabilitiesContext.generic_capabilities().escape_identifier
if escape_lit is None:
escape_lit = DestinationCapabilitiesContext.generic_capabilities().escape_literal

root_table_name = sql_client.make_qualified_table_name(root_table["name"])
with sql_client.with_staging_dataset(staging=True):
staging_root_table_name = sql_client.make_qualified_table_name(root_table["name"])
primary_keys = list(
map(
escape_id,
get_columns_names_with_prop(root_table, "primary_key"),
)
)
merge_keys = list(
map(
escape_id,
get_columns_names_with_prop(root_table, "merge_key"),
)
)
key_clauses = cls._gen_key_table_clauses(primary_keys, merge_keys)

unique_column: str = None
root_key_column: str = None

if len(table_chain) == 1:
key_table_clauses = cls.gen_key_table_clauses(
root_table_name, staging_root_table_name, key_clauses, for_delete=True
)
sql.extend(f"DELETE {clause};" for clause in key_table_clauses)
else:
key_table_clauses = cls.gen_key_table_clauses(
root_table_name, staging_root_table_name, key_clauses, for_delete=False
)
unique_columns = get_columns_names_with_prop(root_table, "unique")
if not unique_columns:
raise MergeDispositionException(
sql_client.fully_qualified_dataset_name(),
staging_root_table_name,
[t["name"] for t in table_chain],
f"There is no unique column (ie _dlt_id) in top table {root_table['name']} so"
" it is not possible to link child tables to it.",
)
unique_column = escape_id(unique_columns[0])
create_delete_temp_table_sql, delete_temp_table_name = cls.gen_delete_temp_table_sql(
unique_column, key_table_clauses
)
sql.extend(create_delete_temp_table_sql)

for table in table_chain[1:]:
table_name = sql_client.make_qualified_table_name(table["name"])
root_key_columns = get_columns_names_with_prop(table, "root_key")
if not root_key_columns:
raise MergeDispositionException(
sql_client.fully_qualified_dataset_name(),
staging_root_table_name,
[t["name"] for t in table_chain],
"There is no root foreign key (ie _dlt_root_id) in child table"
f" {table['name']} so it is not possible to refer to top level table"
f" {root_table['name']} unique column {unique_column}",
)
root_key_column = escape_id(root_key_columns[0])
sql.append(
cls.gen_delete_from_sql(
table_name, root_key_column, delete_temp_table_name, unique_column
)
)

sql.append(
cls.gen_delete_from_sql(
root_table_name, unique_column, delete_temp_table_name, unique_column
)
)

not_deleted_cond: str = None
hard_delete_col = get_first_column_name_with_prop(root_table, "hard_delete")
if hard_delete_col is not None:
not_deleted_cond = f"{escape_id(hard_delete_col)} IS NULL"
if root_table["columns"][hard_delete_col]["data_type"] == "bool":
not_deleted_cond += f" OR {escape_id(hard_delete_col)} = {escape_lit(False)}"

dedup_sort = get_dedup_sort_tuple(root_table)

insert_temp_table_name: str = None
if len(table_chain) > 1 and (primary_keys or hard_delete_col is not None):
condition_columns = [hard_delete_col] if not_deleted_cond is not None else None
(
create_insert_temp_table_sql,
insert_temp_table_name,
) = cls.gen_insert_temp_table_sql(
staging_root_table_name,
primary_keys,
unique_column,
dedup_sort,
not_deleted_cond,
condition_columns,
)
sql.extend(create_insert_temp_table_sql)

to_delete: List[str] = []

for table in table_chain:
table_name = sql_client.make_qualified_table_name(table["name"])
with sql_client.with_staging_dataset(staging=True):
staging_table_name = sql_client.make_qualified_table_name(table["name"])

insert_cond = not_deleted_cond if hard_delete_col is not None else "1 = 1"
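            # Restrict the insert to keys staged in the temp table when primary keys are
            # set on a multi-table chain, or when a child table must honor a hard-delete
            # column without primary keys ("and" binds tighter than "or" below).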
if (
primary_keys
and len(table_chain) > 1
or not primary_keys
and table.get("parent") is not None
and hard_delete_col is not None
):
uniq_column = unique_column if table.get("parent") is None else root_key_column
insert_cond = f"{uniq_column} IN (SELECT * FROM {insert_temp_table_name})"

columns = list(map(escape_id, get_columns_names_with_prop(table, "name")))
col_str = ", ".join(columns)
select_sql = f"SELECT {col_str} FROM {staging_table_name} WHERE {insert_cond}"
if primary_keys and len(table_chain) == 1:
select_sql = cls.gen_select_from_dedup_sql(
staging_table_name, primary_keys, columns, dedup_sort, insert_cond
)

sql.extend([f"INSERT INTO {table_name}({col_str}) {select_sql};"])

if table_name is not None and table_name.startswith("delete_"):
to_delete.extend([table_name])
if insert_temp_table_name is not None and insert_temp_table_name.startswith("delete_"):
to_delete.extend([insert_temp_table_name])

# TODO: Doesn't remove all `delete_` tables.
for delete_table_name in to_delete:
sql.extend(
[f"DROP TABLE IF EXISTS {sql_client.make_qualified_table_name(delete_table_name)};"]
)

return sql
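
# For orientation, a hedged sketch of the statement sequence gen_merge_sql returns for a
# hypothetical "items" / "items__sub_items" chain linked by "_dlt_id" / "_dlt_root_id".
# The exact text of the temp-table and DELETE statements comes from the SqlMergeJob
# helpers and will differ in practice:
#
#   CREATE TABLE delete_items_<suffix> ENGINE = Memory AS <select of affected root keys>;
#   DELETE FROM <dataset>."items__sub_items" WHERE "_dlt_root_id" IN (SELECT * FROM delete_items_<suffix>);
#   DELETE FROM <dataset>."items" WHERE "_dlt_id" IN (SELECT * FROM delete_items_<suffix>);
#   INSERT INTO <dataset>."items"(<cols>) SELECT <cols> FROM <staging>."items" WHERE <cond>;
#   INSERT INTO <dataset>."items__sub_items"(<cols>) SELECT <cols> FROM <staging>."items__sub_items" WHERE <cond>;
#   DROP TABLE IF EXISTS <dataset>.<collected delete_ tables>;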


class ClickhouseClient(SqlJobClientWithStaging, SupportsStagingDestination):
capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()

@@ -319,16 +484,5 @@ def _from_db_type(
def restore_file_load(self, file_path: str) -> LoadJob:
return EmptyLoadJob.from_file_path(file_path, "completed")

def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
return [
SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": False})
]

def _create_replace_followup_jobs(
self, table_chain: Sequence[TTableSchema]
) -> List[NewLoadJob]:
return [SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True})]

def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
# Fall back to append jobs for merge.
return self._create_append_followup_jobs(table_chain)
return [ClickhouseMergeJob.from_table_chain(table_chain, self.sql_client)]
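
A minimal usage sketch (hedged: the pipeline and dataset names are hypothetical and a configured ClickHouse destination with credentials is assumed). With this revert, a merge-disposition resource is planned through ClickhouseMergeJob rather than falling back to append:

import dlt

@dlt.resource(name="items", write_disposition="merge", primary_key="id")
def items():
    # Re-running this resource upserts on "id" instead of appending duplicates.
    yield {"id": 1, "name": "item"}

pipeline = dlt.pipeline(
    pipeline_name="ch_merge_demo", destination="clickhouse", dataset_name="demo"
)
pipeline.run(items)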
71 changes: 71 additions & 0 deletions tests/load/pipeline/test_clickhouse.py
@@ -81,3 +81,74 @@ def items2() -> Iterator[TDataItem]:
finally:
with pipeline.sql_client() as client:
client.drop_dataset()


@pytest.mark.parametrize(
"destination_config",
destinations_configs(all_staging_configs=True, subset=["clickhouse"]),
ids=lambda x: x.name,
)
def test_clickhouse_destination_merge(destination_config: DestinationTestConfiguration) -> None:
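    """Load a parent/child resource twice (append first, then merge on "id") and assert
    row counts stay stable while _dlt_loads records both loads."""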
pipeline = destination_config.setup_pipeline(f"clickhouse_{uniq_id()}", full_refresh=True)

try:

@dlt.resource(name="items")
def items() -> Iterator[TDataItem]:
yield {
"id": 1,
"name": "item",
"sub_items": [
{"id": 101, "name": "sub item 101"},
{"id": 101, "name": "sub item 102"},
],
}

pipeline.run(
items,
loader_file_format=destination_config.file_format,
staging=destination_config.staging,
)

table_counts = load_table_counts(
pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values()]
)
assert table_counts["items"] == 1
assert table_counts["items__sub_items"] == 2
assert table_counts["_dlt_loads"] == 1

# Load again with schema evolution.
@dlt.resource(name="items", write_disposition="merge", primary_key="id")
def items2() -> Iterator[TDataItem]:
yield {
"id": 1,
"name": "item",
"new_field": "hello",
"sub_items": [
{
"id": 101,
"name": "sub item 101",
"other_new_field": "hello 101",
},
{
"id": 101,
"name": "sub item 102",
"other_new_field": "hello 102",
},
],
}

pipeline.run(items2)
table_counts = load_table_counts(
pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values()]
)
assert table_counts["items"] == 1
assert table_counts["items__sub_items"] == 2
assert table_counts["_dlt_loads"] == 2

finally:
with pipeline.sql_client() as client:
client.drop_dataset()
