diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index ae2f8a10e1..4fb97e51d2 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -27,13 +27,7 @@ TTableSchemaColumns, TColumnSchemaBase, ) -from dlt.common.schema.utils import ( - get_columns_names_with_prop, - get_first_column_name_with_prop, - get_dedup_sort_tuple, -) from dlt.common.storages import FileStorage -from dlt.destinations.exceptions import MergeDispositionException from dlt.destinations.impl.clickhouse import capabilities from dlt.destinations.impl.clickhouse.clickhouse_adapter import ( TTableEngineType, @@ -54,7 +48,7 @@ SqlJobClientBase, ) from dlt.destinations.job_impl import NewReferenceJob, EmptyLoadJob -from dlt.destinations.sql_jobs import SqlMergeJob +from dlt.destinations.sql_jobs import SqlStagingCopyJob from dlt.destinations.type_mapping import TypeMapper @@ -206,165 +200,6 @@ def exception(self) -> str: raise NotImplementedError() -class ClickhouseMergeJob(SqlMergeJob): - @classmethod - def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: - # Different sessions are created during the load process, and temporary tables - # do not persist between sessions. - # Resorting to persisted in-memory table to fix. - # return f"CREATE TABLE {temp_table_name} ENGINE = Memory AS {select_sql};" - return f"CREATE TABLE {temp_table_name} ENGINE = Memory AS {select_sql};" - - @classmethod - def gen_merge_sql( - cls, table_chain: Sequence[TTableSchema], sql_client: ClickhouseSqlClient # type: ignore[override] - ) -> List[str]: - sql: List[str] = [] - root_table = table_chain[0] - - escape_id = sql_client.capabilities.escape_identifier - escape_lit = sql_client.capabilities.escape_literal - if escape_id is None: - escape_id = DestinationCapabilitiesContext.generic_capabilities().escape_identifier - if escape_lit is None: - escape_lit = DestinationCapabilitiesContext.generic_capabilities().escape_literal - - root_table_name = sql_client.make_qualified_table_name(root_table["name"]) - with sql_client.with_staging_dataset(staging=True): - staging_root_table_name = sql_client.make_qualified_table_name(root_table["name"]) - primary_keys = list( - map( - escape_id, - get_columns_names_with_prop(root_table, "primary_key"), - ) - ) - merge_keys = list( - map( - escape_id, - get_columns_names_with_prop(root_table, "merge_key"), - ) - ) - key_clauses = cls._gen_key_table_clauses(primary_keys, merge_keys) - - unique_column: str = None - root_key_column: str = None - - if len(table_chain) == 1: - key_table_clauses = cls.gen_key_table_clauses( - root_table_name, staging_root_table_name, key_clauses, for_delete=True - ) - sql.extend(f"DELETE {clause};" for clause in key_table_clauses) - else: - key_table_clauses = cls.gen_key_table_clauses( - root_table_name, staging_root_table_name, key_clauses, for_delete=False - ) - unique_columns = get_columns_names_with_prop(root_table, "unique") - if not unique_columns: - raise MergeDispositionException( - sql_client.fully_qualified_dataset_name(), - staging_root_table_name, - [t["name"] for t in table_chain], - f"There is no unique column (ie _dlt_id) in top table {root_table['name']} so" - " it is not possible to link child tables to it.", - ) - unique_column = escape_id(unique_columns[0]) - create_delete_temp_table_sql, delete_temp_table_name = cls.gen_delete_temp_table_sql( - unique_column, key_table_clauses - ) - sql.extend(create_delete_temp_table_sql) - - for table in table_chain[1:]: - table_name = sql_client.make_qualified_table_name(table["name"]) - root_key_columns = get_columns_names_with_prop(table, "root_key") - if not root_key_columns: - raise MergeDispositionException( - sql_client.fully_qualified_dataset_name(), - staging_root_table_name, - [t["name"] for t in table_chain], - "There is no root foreign key (ie _dlt_root_id) in child table" - f" {table['name']} so it is not possible to refer to top level table" - f" {root_table['name']} unique column {unique_column}", - ) - root_key_column = escape_id(root_key_columns[0]) - sql.append( - cls.gen_delete_from_sql( - table_name, root_key_column, delete_temp_table_name, unique_column - ) - ) - - sql.append( - cls.gen_delete_from_sql( - root_table_name, unique_column, delete_temp_table_name, unique_column - ) - ) - - not_deleted_cond: str = None - hard_delete_col = get_first_column_name_with_prop(root_table, "hard_delete") - if hard_delete_col is not None: - not_deleted_cond = f"{escape_id(hard_delete_col)} IS NULL" - if root_table["columns"][hard_delete_col]["data_type"] == "bool": - not_deleted_cond += f" OR {escape_id(hard_delete_col)} = {escape_lit(False)}" - - dedup_sort = get_dedup_sort_tuple(root_table) - - insert_temp_table_name: str = None - if len(table_chain) > 1 and (primary_keys or hard_delete_col is not None): - condition_columns = [hard_delete_col] if not_deleted_cond is not None else None - ( - create_insert_temp_table_sql, - insert_temp_table_name, - ) = cls.gen_insert_temp_table_sql( - staging_root_table_name, - primary_keys, - unique_column, - dedup_sort, - not_deleted_cond, - condition_columns, - ) - sql.extend(create_insert_temp_table_sql) - - to_delete: List[str] = [] - - for table in table_chain: - table_name = sql_client.make_qualified_table_name(table["name"]) - with sql_client.with_staging_dataset(staging=True): - staging_table_name = sql_client.make_qualified_table_name(table["name"]) - - insert_cond = not_deleted_cond if hard_delete_col is not None else "1 = 1" - if ( - primary_keys - and len(table_chain) > 1 - or not primary_keys - and table.get("parent") is not None - and hard_delete_col is not None - ): - uniq_column = unique_column if table.get("parent") is None else root_key_column - insert_cond = f"{uniq_column} IN (SELECT * FROM {insert_temp_table_name})" - - columns = list(map(escape_id, get_columns_names_with_prop(table, "name"))) - col_str = ", ".join(columns) - select_sql = f"SELECT {col_str} FROM {staging_table_name} WHERE {insert_cond}" - if primary_keys and len(table_chain) == 1: - select_sql = cls.gen_select_from_dedup_sql( - staging_table_name, primary_keys, columns, dedup_sort, insert_cond - ) - - sql.extend([f"INSERT INTO {table_name}({col_str}) {select_sql};"]) - - if table_name is not None and table_name.startswith("delete_"): - to_delete.extend([table_name]) - if insert_temp_table_name is not None and insert_temp_table_name.startswith("delete_"): - to_delete.extend([insert_temp_table_name]) - - # TODO: Doesn't remove all `delete_` tables. - for delete_table_name in to_delete: - sql.extend( - [f"DROP TABLE IF EXISTS {sql_client.make_qualified_table_name(delete_table_name)};"] - ) - - return sql - - class ClickhouseClient(SqlJobClientWithStaging, SupportsStagingDestination): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -484,5 +319,16 @@ def _from_db_type( def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") + def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + return [ + SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": False}) + ] + + def _create_replace_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: + return [SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True})] + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: - return [ClickhouseMergeJob.from_table_chain(table_chain, self.sql_client)] + # Fall back to append jobs for merge. + return self._create_append_followup_jobs(table_chain) diff --git a/tests/load/pipeline/test_clickhouse.py b/tests/load/pipeline/test_clickhouse.py index 1fd834389f..4a5903137b 100644 --- a/tests/load/pipeline/test_clickhouse.py +++ b/tests/load/pipeline/test_clickhouse.py @@ -81,74 +81,3 @@ def items2() -> Iterator[TDataItem]: finally: with pipeline.sql_client() as client: client.drop_dataset() - - -@pytest.mark.parametrize( - "destination_config", - destinations_configs(all_staging_configs=True, subset=["clickhouse"]), - ids=lambda x: x.name, -) -def test_clickhouse_destination_merge(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline(f"clickhouse_{uniq_id()}", full_refresh=True) - - try: - - @dlt.resource(name="items") - def items() -> Iterator[TDataItem]: - yield { - "id": 1, - "name": "item", - "sub_items": [ - {"id": 101, "name": "sub item 101"}, - {"id": 101, "name": "sub item 102"}, - ], - } - - pipeline.run( - items, - loader_file_format=destination_config.file_format, - staging=destination_config.staging, - ) - - table_counts = load_table_counts( - pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values()] - ) - assert table_counts["items"] == 1 - assert table_counts["items__sub_items"] == 2 - assert table_counts["_dlt_loads"] == 1 - - # Load again with schema evolution. - @dlt.resource(name="items", write_disposition="merge", primary_key="id") - def items2() -> Iterator[TDataItem]: - yield { - "id": 1, - "name": "item", - "new_field": "hello", - "sub_items": [ - { - "id": 101, - "name": "sub item 101", - "other_new_field": "hello 101", - }, - { - "id": 101, - "name": "sub item 102", - "other_new_field": "hello 102", - }, - ], - } - - pipeline.run(items2) - table_counts = load_table_counts( - pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values()] - ) - assert table_counts["items"] == 1 - assert table_counts["items__sub_items"] == 2 - assert table_counts["_dlt_loads"] == 2 - - except Exception as e: - raise e - - finally: - with pipeline.sql_client() as client: - client.drop_dataset()