Preliminary wireframe #1055
Signed-off-by: Marcel Coetzee <[email protected]>
Pipboyguy committed Mar 9, 2024
1 parent dc5d4b0 commit d736dee
Showing 7 changed files with 125 additions and 44 deletions.
85 changes: 55 additions & 30 deletions dlt/common/data_writers/escape.py
@@ -1,10 +1,11 @@
import re
import base64
from typing import Any, Dict
import re
from datetime import date, datetime, time # noqa: I251
from typing import Any, Dict

from dlt.common.json import json


# use regex to escape characters in single pass
SQL_ESCAPE_DICT = {"'": "''", "\\": "\\\\", "\n": "\\n", "\r": "\\r"}

@@ -24,57 +25,48 @@ def _escape_extended(
) -> str:
escape_dict = escape_dict or SQL_ESCAPE_DICT
escape_re = escape_re or SQL_ESCAPE_RE
return "{}{}{}".format(prefix, escape_re.sub(lambda x: escape_dict[x.group(0)], v), "'")
return f"{prefix}{escape_re.sub(lambda x: escape_dict[x.group(0)], v)}'"


def escape_redshift_literal(v: Any) -> Any:
if isinstance(v, str):
# https://www.postgresql.org/docs/9.3/sql-syntax-lexical.html
# looks like this is the only thing we need to escape for Postgres > 9.1
# redshift keeps \ as escape character which is pre 9 behavior
# redshift keeps \ as escape character which is pre-9 behavior.
return _escape_extended(v, prefix="'")
if isinstance(v, bytes):
return f"from_hex('{v.hex()}')"
if isinstance(v, (datetime, date, time)):
return f"'{v.isoformat()}'"
if isinstance(v, (list, dict)):
return "json_parse(%s)" % _escape_extended(json.dumps(v), prefix="'")
if v is None:
return "NULL"

return str(v)
return "NULL" if v is None else str(v)


def escape_postgres_literal(v: Any) -> Any:
if isinstance(v, str):
# we escape extended string which behave like the redshift string
# we escape extended string which behaves like the redshift string.
return _escape_extended(v)
if isinstance(v, (datetime, date, time)):
return f"'{v.isoformat()}'"
if isinstance(v, (list, dict)):
return _escape_extended(json.dumps(v))
if isinstance(v, bytes):
return f"'\\x{v.hex()}'"
if v is None:
return "NULL"

return str(v)
return "NULL" if v is None else str(v)


def escape_duckdb_literal(v: Any) -> Any:
if isinstance(v, str):
# we escape extended string which behave like the redshift string
# We escape extended string which behaves like the redshift string.
return _escape_extended(v)
if isinstance(v, (datetime, date, time)):
return f"'{v.isoformat()}'"
if isinstance(v, (list, dict)):
return _escape_extended(json.dumps(v))
if isinstance(v, bytes):
return f"from_base64('{base64.b64encode(v).decode('ascii')}')"
if v is None:
return "NULL"

return str(v)
return "NULL" if v is None else str(v)


MS_SQL_ESCAPE_DICT = {
@@ -100,17 +92,12 @@ def escape_mssql_literal(v: Any) -> Any:
if isinstance(v, bytes):
from dlt.destinations.impl.mssql.mssql import VARBINARY_MAX_N

if len(v) <= VARBINARY_MAX_N:
n = str(len(v))
else:
n = "MAX"
n = str(len(v)) if len(v) <= VARBINARY_MAX_N else "MAX"
return f"CONVERT(VARBINARY({n}), '{v.hex()}', 2)"

if isinstance(v, bool):
return str(int(v))
if v is None:
return "NULL"
return str(v)
return "NULL" if v is None else str(v)


def escape_redshift_identifier(v: str) -> str:
@@ -127,8 +114,8 @@ def escape_bigquery_identifier(v: str) -> str:


def escape_snowflake_identifier(v: str) -> str:
# Snowcase uppercase all identifiers unless quoted. Match this here so queries on information schema work without issue
# See also https://docs.snowflake.com/en/sql-reference/identifiers-syntax#double-quoted-identifiers
# Snowflake uppercases all identifiers unless quoted. Match this here so queries on the information schema work without issue.
# See https://docs.snowflake.com/en/sql-reference/identifiers-syntax#double-quoted-identifiers.
return escape_postgres_identifier(v.upper())


@@ -147,7 +134,45 @@ def escape_databricks_literal(v: Any) -> Any:
return _escape_extended(json.dumps(v), prefix="'", escape_dict=DATABRICKS_ESCAPE_DICT)
if isinstance(v, bytes):
return f"X'{v.hex()}'"
if v is None:
return "NULL"
return "NULL" if v is None else str(v)


# https://github.com/ClickHouse/ClickHouse/blob/master/docs/en/sql-reference/syntax.md#string
CLICKHOUSE_ESCAPE_DICT = {
"'": "''",
"\\": "\\\\",
"\n": "\\n",
"\t": "\\t",
"\b": "\\b",
"\f": "\\f",
"\r": "\\r",
"\0": "\\0",
"\a": "\\a",
"\v": "\\v",
}

CLICKHOUSE_ESCAPE_RE = _make_sql_escape_re(CLICKHOUSE_ESCAPE_DICT)


def escape_clickhouse_literal(v: Any) -> Any:
if isinstance(v, str):
return _escape_extended(
v, prefix="'", escape_dict=CLICKHOUSE_ESCAPE_DICT, escape_re=CLICKHOUSE_ESCAPE_RE
)
if isinstance(v, (datetime, date, time)):
return f"'{v.isoformat()}'"
if isinstance(v, (list, dict)):
return _escape_extended(
json.dumps(v),
prefix="'",
escape_dict=CLICKHOUSE_ESCAPE_DICT,
escape_re=CLICKHOUSE_ESCAPE_RE,
)
if isinstance(v, bytes):
return f"'{v.hex()}'"
return "NULL" if v is None else str(v)


def escape_clickhouse_identifier(v: str, quote_char: str = "`") -> str:
quote_char = quote_char if quote_char in {'"', "`"} else "`"
return quote_char + v.replace(quote_char, quote_char * 2).replace("\\", "\\\\") + quote_char
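For orientation, a minimal sketch (not part of the diff) of how the new ClickHouse escape helpers behave, assuming the module path used by the destination package below; the sample values are purely illustrative:

from datetime import datetime
from dlt.common.data_writers.escape import (
    escape_clickhouse_identifier,
    escape_clickhouse_literal,
)

# String literals get single-quote doubling plus the backslash escapes
# defined in CLICKHOUSE_ESCAPE_DICT.
escape_clickhouse_literal("O'Brien\nrow")        # "'O''Brien\\nrow'"

# Datetimes are rendered as quoted ISO-8601 strings.
escape_clickhouse_literal(datetime(2024, 3, 9))  # "'2024-03-09T00:00:00'"

# Identifiers default to backtick quoting; only a backtick or double quote
# is accepted as an alternative quote character.
escape_clickhouse_identifier("weird`name")             # '`weird``name`'
escape_clickhouse_identifier("col", quote_char='"')    # '"col"'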
24 changes: 24 additions & 0 deletions dlt/destinations/impl/clickhouse/__init__.py
@@ -0,0 +1,24 @@
from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE
from dlt.common.data_writers.escape import escape_clickhouse_identifier, escape_clickhouse_literal
from dlt.common.destination import DestinationCapabilitiesContext


def capabilities() -> DestinationCapabilitiesContext:
caps = DestinationCapabilitiesContext()
caps.preferred_loader_file_format = "jsonl"
caps.supported_loader_file_formats = ["jsonl", "parquet"]
caps.preferred_staging_file_format = "parquet"
caps.supported_staging_file_formats = ["parquet", "jsonl"]
caps.escape_identifier = escape_clickhouse_identifier
caps.escape_literal = escape_clickhouse_literal
caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE)
caps.wei_precision = (76, 38)
caps.max_identifier_length = 1024
caps.max_column_identifier_length = 300
caps.max_query_length = 1024 * 1024
caps.is_max_query_length_in_bytes = False
caps.max_text_data_type_length = 10 * 1024 * 1024
caps.is_max_text_data_type_length_in_bytes = True
caps.supports_ddl_transactions = False

return caps
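As a quick sanity check (not part of the diff), the capabilities context can be inspected directly; this merely restates values set in the function above:

from dlt.destinations.impl.clickhouse import capabilities

caps = capabilities()
assert caps.preferred_loader_file_format == "jsonl"
assert caps.wei_precision == (76, 38)        # Decimal256 range is wide enough for wei values
assert not caps.supports_ddl_transactions    # ClickHouse DDL is not transactional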
Empty file.
39 changes: 39 additions & 0 deletions dlt/destinations/impl/clickhouse/configuration.py
@@ -0,0 +1,39 @@
from typing import TYPE_CHECKING, ClassVar, List, Optional, Final

from dlt.common.configuration import configspec
from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration


@configspec
class ClickhouseClientConfiguration(DestinationClientDwhWithStagingConfiguration):
destination_type: Final[str] = "clickhouse" # type: ignore

http_timeout: float = 15.0
file_upload_timeout: float = 30 * 60.0
retry_deadline: float = 60.0

__config_gen_annotations__: ClassVar[List[str]] = []

if TYPE_CHECKING:

def __init__(
self,
*,
dataset_name: str = None,
default_schema_name: Optional[str],
http_timeout: float = 15.0,
file_upload_timeout: float = 30 * 60.0,
retry_deadline: float = 60.0,
destination_name: str = None,
environment: str = None
) -> None:
super().__init__(
dataset_name=dataset_name,
default_schema_name=default_schema_name,
destination_name=destination_name,
environment=environment,
)
self.retry_deadline = retry_deadline
self.file_upload_timeout = file_upload_timeout
self.http_timeout = http_timeout
...
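A minimal construction sketch (not part of the diff); the keyword arguments mirror the TYPE_CHECKING stub above, and in a real pipeline dlt would normally resolve these values from config.toml or environment variables rather than explicit construction. The dataset name is hypothetical:

from dlt.destinations.impl.clickhouse.configuration import ClickhouseClientConfiguration

config = ClickhouseClientConfiguration(
    dataset_name="analytics_raw",      # hypothetical dataset name
    default_schema_name=None,
    http_timeout=15.0,
    file_upload_timeout=30 * 60.0,
    retry_deadline=60.0,
)
assert config.destination_type == "clickhouse"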
Empty file.
Empty file.
21 changes: 7 additions & 14 deletions dlt/destinations/job_client_impl.py
@@ -67,7 +67,7 @@


class SqlLoadJob(LoadJob):
"""A job executing sql statement, without followup trait"""
"""A job executing sql statement, without followup trait."""

def __init__(self, file_path: str, sql_client: SqlClientBase[Any]) -> None:
super().__init__(FileStorage.get_file_name_from_file_path(file_path))
@@ -98,13 +98,10 @@ def exception(self) -> str:
raise NotImplementedError()

def _string_containts_ddl_queries(self, sql: str) -> bool:
for cmd in DDL_COMMANDS:
if re.search(cmd, sql, re.IGNORECASE):
return True
return False
return any(re.search(cmd, sql, re.IGNORECASE) for cmd in DDL_COMMANDS)

def _split_fragments(self, sql: str) -> List[str]:
return [s + (";" if not s.endswith(";") else "") for s in sql.split(";") if s.strip()]
return [s + ("" if s.endswith(";") else ";") for s in sql.split(";") if s.strip()]

@staticmethod
def is_sql_job(file_path: str) -> bool:
@@ -496,7 +493,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non

@staticmethod
def _gen_not_null(v: bool) -> str:
return "NOT NULL" if not v else ""
return "" if v else "NOT NULL"

def _create_table_update(
self, table_name: str, storage_columns: TTableSchemaColumns
@@ -518,13 +515,9 @@ def _row_to_schema_info(self, query: str, *args: Any) -> StorageSchemaInfo:
# get schema as string
# TODO: Re-use decompress/compress_state() implementation from dlt.pipeline.state_sync
schema_str: str = row[5]
try:
with contextlib.suppress(ValueError):
schema_bytes = base64.b64decode(schema_str, validate=True)
schema_str = zlib.decompress(schema_bytes).decode("utf-8")
except ValueError:
# not a base64 string
pass

# make utc datetime
inserted_at = pendulum.instance(row[4])

@@ -539,13 +532,13 @@ def _replace_schema_in_storage(self, schema: Schema) -> None:
self._update_schema_in_storage(schema)

def _update_schema_in_storage(self, schema: Schema) -> None:
# make sure that schema being saved was not modified from the moment it was loaded from storage
# Make sure the schema being saved wasn't modified from the moment it was loaded from storage.
version_hash = schema.version_hash
if version_hash != schema.stored_version_hash:
raise DestinationSchemaTampered(schema.name, version_hash, schema.stored_version_hash)
# get schema string or zip
schema_str = json.dumps(schema.to_dict())
# TODO: not all databases store data as utf-8 but this exception is mostly for redshift
# TODO: not all databases store data as utf-8 but this exception is mostly for redshift.
schema_bytes = schema_str.encode("utf-8")
if len(schema_bytes) > self.capabilities.max_text_data_type_length:
# compress and to base64
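The job_client_impl.py edits above are behavior-preserving refactors; here is a standalone sketch (not part of the diff) of the two rewritten helpers, written at module level for brevity and with an assumed, illustrative DDL_COMMANDS list:

import re
from typing import List

DDL_COMMANDS = ["ALTER", "CREATE", "DROP"]  # assumption: representative subset only

def string_contains_ddl_queries(sql: str) -> bool:
    # any() short-circuits exactly like the original for-loop with its early return
    return any(re.search(cmd, sql, re.IGNORECASE) for cmd in DDL_COMMANDS)

def split_fragments(sql: str) -> List[str]:
    # every non-empty fragment keeps (or regains) a trailing semicolon
    return [s + ("" if s.endswith(";") else ";") for s in sql.split(";") if s.strip()]

string_contains_ddl_queries("create table events (id Int64)")  # True
split_fragments("SELECT 1; SELECT 2")                          # ['SELECT 1;', ' SELECT 2;']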
