diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py
index 4fb97e51d2..ae2f8a10e1 100644
--- a/dlt/destinations/impl/clickhouse/clickhouse.py
+++ b/dlt/destinations/impl/clickhouse/clickhouse.py
@@ -27,7 +27,13 @@
     TTableSchemaColumns,
     TColumnSchemaBase,
 )
+from dlt.common.schema.utils import (
+    get_columns_names_with_prop,
+    get_first_column_name_with_prop,
+    get_dedup_sort_tuple,
+)
 from dlt.common.storages import FileStorage
+from dlt.destinations.exceptions import MergeDispositionException
 from dlt.destinations.impl.clickhouse import capabilities
 from dlt.destinations.impl.clickhouse.clickhouse_adapter import (
     TTableEngineType,
@@ -48,7 +54,7 @@
     SqlJobClientBase,
 )
 from dlt.destinations.job_impl import NewReferenceJob, EmptyLoadJob
-from dlt.destinations.sql_jobs import SqlStagingCopyJob
+from dlt.destinations.sql_jobs import SqlMergeJob
 from dlt.destinations.type_mapping import TypeMapper
 
 
@@ -200,6 +206,165 @@ def exception(self) -> str:
         raise NotImplementedError()
 
 
+class ClickhouseMergeJob(SqlMergeJob):
+    @classmethod
+    def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str:
+        # Different sessions are created during the load process, and ClickHouse temporary
+        # tables do not persist between sessions, so a persisted Memory-engine table is used
+        # instead.
+        return f"CREATE TABLE {temp_table_name} ENGINE = Memory AS {select_sql};"
+
+    @classmethod
+    def gen_merge_sql(
+        cls, table_chain: Sequence[TTableSchema], sql_client: ClickhouseSqlClient  # type: ignore[override]
+    ) -> List[str]:
+        sql: List[str] = []
+        root_table = table_chain[0]
+
+        escape_id = sql_client.capabilities.escape_identifier
+        escape_lit = sql_client.capabilities.escape_literal
+        if escape_id is None:
+            escape_id = DestinationCapabilitiesContext.generic_capabilities().escape_identifier
+        if escape_lit is None:
+            escape_lit = DestinationCapabilitiesContext.generic_capabilities().escape_literal
+
+        root_table_name = sql_client.make_qualified_table_name(root_table["name"])
+        with sql_client.with_staging_dataset(staging=True):
+            staging_root_table_name = sql_client.make_qualified_table_name(root_table["name"])
+        primary_keys = list(
+            map(
+                escape_id,
+                get_columns_names_with_prop(root_table, "primary_key"),
+            )
+        )
+        merge_keys = list(
+            map(
+                escape_id,
+                get_columns_names_with_prop(root_table, "merge_key"),
+            )
+        )
+        key_clauses = cls._gen_key_table_clauses(primary_keys, merge_keys)
+
+        unique_column: str = None
+        root_key_column: str = None
+
+        if len(table_chain) == 1:
+            key_table_clauses = cls.gen_key_table_clauses(
+                root_table_name, staging_root_table_name, key_clauses, for_delete=True
+            )
+            sql.extend(f"DELETE {clause};" for clause in key_table_clauses)
+        else:
+            key_table_clauses = cls.gen_key_table_clauses(
+                root_table_name, staging_root_table_name, key_clauses, for_delete=False
+            )
+            unique_columns = get_columns_names_with_prop(root_table, "unique")
+            if not unique_columns:
+                raise MergeDispositionException(
+                    sql_client.fully_qualified_dataset_name(),
+                    staging_root_table_name,
+                    [t["name"] for t in table_chain],
+                    f"There is no unique column (ie _dlt_id) in top table {root_table['name']} so"
+                    " it is not possible to link child tables to it.",
+                )
+            unique_column = escape_id(unique_columns[0])
+            create_delete_temp_table_sql, delete_temp_table_name = cls.gen_delete_temp_table_sql(
+                unique_column, key_table_clauses
+            )
+            sql.extend(create_delete_temp_table_sql)
+
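+            # Remove child-table rows that belong to the root rows collected in the delete
+            # temp table, following each child's root key back to the root table's unique column.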
+            for table in table_chain[1:]:
+                table_name = sql_client.make_qualified_table_name(table["name"])
+                root_key_columns = get_columns_names_with_prop(table, "root_key")
+                if not root_key_columns:
+                    raise MergeDispositionException(
+                        sql_client.fully_qualified_dataset_name(),
+                        staging_root_table_name,
+                        [t["name"] for t in table_chain],
+                        "There is no root foreign key (ie _dlt_root_id) in child table"
+                        f" {table['name']} so it is not possible to refer to top level table"
+                        f" {root_table['name']} unique column {unique_column}",
+                    )
+                root_key_column = escape_id(root_key_columns[0])
+                sql.append(
+                    cls.gen_delete_from_sql(
+                        table_name, root_key_column, delete_temp_table_name, unique_column
+                    )
+                )
+
+            sql.append(
+                cls.gen_delete_from_sql(
+                    root_table_name, unique_column, delete_temp_table_name, unique_column
+                )
+            )
+
+        not_deleted_cond: str = None
+        hard_delete_col = get_first_column_name_with_prop(root_table, "hard_delete")
+        if hard_delete_col is not None:
+            not_deleted_cond = f"{escape_id(hard_delete_col)} IS NULL"
+            if root_table["columns"][hard_delete_col]["data_type"] == "bool":
+                not_deleted_cond += f" OR {escape_id(hard_delete_col)} = {escape_lit(False)}"
+
+        dedup_sort = get_dedup_sort_tuple(root_table)
+
+        insert_temp_table_name: str = None
+        if len(table_chain) > 1 and (primary_keys or hard_delete_col is not None):
+            condition_columns = [hard_delete_col] if not_deleted_cond is not None else None
+            (
+                create_insert_temp_table_sql,
+                insert_temp_table_name,
+            ) = cls.gen_insert_temp_table_sql(
+                staging_root_table_name,
+                primary_keys,
+                unique_column,
+                dedup_sort,
+                not_deleted_cond,
+                condition_columns,
+            )
+            sql.extend(create_insert_temp_table_sql)
+
+        to_delete: List[str] = []
+
+        for table in table_chain:
+            table_name = sql_client.make_qualified_table_name(table["name"])
+            with sql_client.with_staging_dataset(staging=True):
+                staging_table_name = sql_client.make_qualified_table_name(table["name"])
+
+            insert_cond = not_deleted_cond if hard_delete_col is not None else "1 = 1"
+            if (
+                primary_keys
+                and len(table_chain) > 1
+                or not primary_keys
+                and table.get("parent") is not None
+                and hard_delete_col is not None
+            ):
+                uniq_column = unique_column if table.get("parent") is None else root_key_column
+                insert_cond = f"{uniq_column} IN (SELECT * FROM {insert_temp_table_name})"
+
+            columns = list(map(escape_id, get_columns_names_with_prop(table, "name")))
+            col_str = ", ".join(columns)
+            select_sql = f"SELECT {col_str} FROM {staging_table_name} WHERE {insert_cond}"
+            if primary_keys and len(table_chain) == 1:
+                select_sql = cls.gen_select_from_dedup_sql(
+                    staging_table_name, primary_keys, columns, dedup_sort, insert_cond
+                )
+
+            sql.extend([f"INSERT INTO {table_name}({col_str}) {select_sql};"])
+
+            if table_name is not None and table_name.startswith("delete_"):
+                to_delete.extend([table_name])
+        if insert_temp_table_name is not None and insert_temp_table_name.startswith("delete_"):
+            to_delete.extend([insert_temp_table_name])
+
+        # TODO: Doesn't remove all `delete_` tables.
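+        # Drop the Memory-engine helper tables created above; they are regular
+        # (non-temporary) tables, so they must be cleaned up explicitly.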
+        for delete_table_name in to_delete:
+            sql.extend(
+                [f"DROP TABLE IF EXISTS {sql_client.make_qualified_table_name(delete_table_name)};"]
+            )
+
+        return sql
+
+
 class ClickhouseClient(SqlJobClientWithStaging, SupportsStagingDestination):
     capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()
 
@@ -319,16 +484,5 @@ def _from_db_type(
     def restore_file_load(self, file_path: str) -> LoadJob:
         return EmptyLoadJob.from_file_path(file_path, "completed")
 
-    def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
-        return [
-            SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": False})
-        ]
-
-    def _create_replace_followup_jobs(
-        self, table_chain: Sequence[TTableSchema]
-    ) -> List[NewLoadJob]:
-        return [SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True})]
-
     def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
-        # Fall back to append jobs for merge.
-        return self._create_append_followup_jobs(table_chain)
+        return [ClickhouseMergeJob.from_table_chain(table_chain, self.sql_client)]
diff --git a/tests/load/pipeline/test_clickhouse.py b/tests/load/pipeline/test_clickhouse.py
index 4a5903137b..1fd834389f 100644
--- a/tests/load/pipeline/test_clickhouse.py
+++ b/tests/load/pipeline/test_clickhouse.py
@@ -81,3 +81,74 @@ def items2() -> Iterator[TDataItem]:
     finally:
         with pipeline.sql_client() as client:
             client.drop_dataset()
+
+
+@pytest.mark.parametrize(
+    "destination_config",
+    destinations_configs(all_staging_configs=True, subset=["clickhouse"]),
+    ids=lambda x: x.name,
+)
+def test_clickhouse_destination_merge(destination_config: DestinationTestConfiguration) -> None:
+    pipeline = destination_config.setup_pipeline(f"clickhouse_{uniq_id()}", full_refresh=True)
+
+    try:
+
+        @dlt.resource(name="items")
+        def items() -> Iterator[TDataItem]:
+            yield {
+                "id": 1,
+                "name": "item",
+                "sub_items": [
+                    {"id": 101, "name": "sub item 101"},
+                    {"id": 101, "name": "sub item 102"},
+                ],
+            }
+
+        pipeline.run(
+            items,
+            loader_file_format=destination_config.file_format,
+            staging=destination_config.staging,
+        )
+
+        table_counts = load_table_counts(
+            pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values()]
+        )
+        assert table_counts["items"] == 1
+        assert table_counts["items__sub_items"] == 2
+        assert table_counts["_dlt_loads"] == 1
+
+        # Load again with schema evolution.
+        @dlt.resource(name="items", write_disposition="merge", primary_key="id")
+        def items2() -> Iterator[TDataItem]:
+            yield {
+                "id": 1,
+                "name": "item",
+                "new_field": "hello",
+                "sub_items": [
+                    {
+                        "id": 101,
+                        "name": "sub item 101",
+                        "other_new_field": "hello 101",
+                    },
+                    {
+                        "id": 101,
+                        "name": "sub item 102",
+                        "other_new_field": "hello 102",
+                    },
+                ],
+            }
+
+        pipeline.run(items2)
+        table_counts = load_table_counts(
+            pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values()]
+        )
+        assert table_counts["items"] == 1
+        assert table_counts["items__sub_items"] == 2
+        assert table_counts["_dlt_loads"] == 2
+
+    except Exception as e:
+        raise e
+
+    finally:
+        with pipeline.sql_client() as client:
+            client.drop_dataset()