Commit

mssql destination (#611)
* remove snipsync (#613)

* add custom snippets element

* Add pymssql dependency

* Quote keyword in version table columns

* Init ms sql (not working)

* Partially working mssql

* ms sql ci template init

* Use pyodbc instead of pymssql

* make database exception

* ms sql escaping

* Replace pymysql with pyodbc dependency

* Fix binary+json escape, time type

* autocommit mode

* Include pyarrow

* test_sql_client running

* Add dbt map

* try sql ci all OS

* Run with x

* Try install driver

* Configure + detect odbc_driver, prefer v18

* skip dbt tests for now

* no arrow

* escape time

* Temp table merge queries mssql compat

* escape all version/state columns

* Remove fs credentials

* Fix double ;

* Fix py 3.8

* Init ms sql docs

* Forward query params

* All load tests running

* Convert to odbc dsn

* Skip marker for filesystem pipeline test

* Limit ms sql inserts to 1000 rows

* Cleanup

* Do NULLS FIRST sorting in python

* Lint pq import

* Fix athena tests

* -x

* disables index creation by default, updates docs

* Revert "remove snipsync (#613)"

This reverts commit 34dcc91.

---------

Co-authored-by: David Scharf <[email protected]>
Co-authored-by: Marcin Rudolf <[email protected]>
3 people authored Sep 12, 2023
1 parent 82d8b9d commit e8fecb7
Showing 27 changed files with 4,162 additions and 3,663 deletions.
89 changes: 89 additions & 0 deletions .github/workflows/test_destination_mssql.yml
@@ -0,0 +1,89 @@

name: test MS SQL

on:
  pull_request:
    branches:
      - master
      - devel
  workflow_dispatch:

env:
  DESTINATION__MSSQL__CREDENTIALS: mssql://[email protected]:1433/dlt_ci
  DESTINATION__MSSQL__CREDENTIALS__PASSWORD: ${{ secrets.MSSQL_PASSWORD }}

  RUNTIME__SENTRY_DSN: https://[email protected]/4504819859914752
  RUNTIME__LOG_LEVEL: ERROR

  ACTIVE_DESTINATIONS: "[\"mssql\"]"
  ALL_FILESYSTEM_DRIVERS: "[\"memory\"]"

jobs:
  get_docs_changes:
    uses: ./.github/workflows/get_docs_changes.yml
    if: ${{ !github.event.pull_request.head.repo.fork }}

  run_loader:
    name: Tests MS SQL loader
    needs: get_docs_changes
    if: needs.get_docs_changes.outputs.changes_outside_docs == 'true'
    strategy:
      fail-fast: false
      matrix:
        os: ["ubuntu-latest"]
    defaults:
      run:
        shell: bash
    runs-on: ${{ matrix.os }}

    steps:

      - name: Check out
        uses: actions/checkout@master

      - name: Install ODBC driver for SQL Server
        run: |
          sudo ACCEPT_EULA=Y apt-get install --yes msodbcsql18

      - name: Setup Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.10.x"

      - name: Install Poetry
        uses: snok/[email protected]
        with:
          virtualenvs-create: true
          virtualenvs-in-project: true
          installer-parallel: true

      - name: Load cached venv
        id: cached-poetry-dependencies
        uses: actions/cache@v3
        with:
          path: .venv
          key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-gcp

      - name: Install dependencies
        run: poetry install --no-interaction -E mssql -E s3 -E gs -E az

      - run: |
          poetry run pytest tests/load --ignore tests/load/pipeline/test_dbt_helper.py
        if: runner.os != 'Windows'
        name: Run tests Linux/MAC

      - run: |
          poetry run pytest tests/load --ignore tests/load/pipeline/test_dbt_helper.py
        if: runner.os == 'Windows'
        name: Run tests Windows
        shell: cmd

  matrix_job_required_check:
    name: MS SQL loader tests
    needs: run_loader
    runs-on: ubuntu-latest
    if: always()
    steps:
      - name: Check matrix job results
        if: contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled')
        run: |
          echo "One or more matrix job tests failed or were cancelled. You may need to re-run them." && exit 1
41 changes: 36 additions & 5 deletions dlt/common/data_writers/escape.py
@@ -1,17 +1,25 @@
 import re
 import base64
-from typing import Any
-from datetime import date, datetime  # noqa: I251
+from typing import Any, Dict
+from datetime import date, datetime, time  # noqa: I251

 from dlt.common.json import json

 # use regex to escape characters in single pass
 SQL_ESCAPE_DICT = {"'": "''", "\\": "\\\\", "\n": "\\n", "\r": "\\r"}
-SQL_ESCAPE_RE = re.compile("|".join([re.escape(k) for k in sorted(SQL_ESCAPE_DICT, key=len, reverse=True)]), flags=re.DOTALL)

+def _make_sql_escape_re(escape_dict: Dict[str, str]) -> re.Pattern: # type: ignore[type-arg]
+    return re.compile("|".join([re.escape(k) for k in sorted(escape_dict, key=len, reverse=True)]), flags=re.DOTALL)

-def _escape_extended(v: str, prefix:str = "E'") -> str:
-    return "{}{}{}".format(prefix, SQL_ESCAPE_RE.sub(lambda x: SQL_ESCAPE_DICT[x.group(0)], v), "'")
+SQL_ESCAPE_RE = _make_sql_escape_re(SQL_ESCAPE_DICT)
+
+def _escape_extended(
+    v: str, prefix:str = "E'", escape_dict: Dict[str, str] = None, escape_re: re.Pattern = None # type: ignore[type-arg]
+) -> str:
+    escape_dict = escape_dict or SQL_ESCAPE_DICT
+    escape_re = escape_re or SQL_ESCAPE_RE
+    return "{}{}{}".format(prefix, escape_re.sub(lambda x: escape_dict[x.group(0)], v), "'")


 def escape_redshift_literal(v: Any) -> Any:
@@ -58,6 +66,29 @@ def escape_duckdb_literal(v: Any) -> Any:
     return str(v)


+MS_SQL_ESCAPE_DICT = {
+    "'": "''",
+    '\n': "' + CHAR(10) + N'",
+    '\r': "' + CHAR(13) + N'",
+    '\t': "' + CHAR(9) + N'",
+}
+MS_SQL_ESCAPE_RE = _make_sql_escape_re(MS_SQL_ESCAPE_DICT)
+
+def escape_mssql_literal(v: Any) -> Any:
+    if isinstance(v, str):
+        return _escape_extended(v, prefix="N'", escape_dict=MS_SQL_ESCAPE_DICT, escape_re=MS_SQL_ESCAPE_RE)
+    if isinstance(v, (datetime, date, time)):
+        return f"'{v.isoformat()}'"
+    if isinstance(v, (list, dict)):
+        return _escape_extended(json.dumps(v), prefix="N'", escape_dict=MS_SQL_ESCAPE_DICT, escape_re=MS_SQL_ESCAPE_RE)
+    if isinstance(v, bytes):
+        base_64_string = base64.b64encode(v).decode('ascii')
+        return f"""CAST('' AS XML).value('xs:base64Binary("{base_64_string}")', 'VARBINARY(MAX)')"""
+    if isinstance(v, bool):
+        return str(int(v))
+    return str(v)
+
+
 def escape_redshift_identifier(v: str) -> str:
     return '"' + v.replace('"', '""').replace("\\", "\\\\") + '"'
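A quick sanity check of the new escaper (an illustrative sketch, not part of the diff; it assumes the package is installed so the import resolves):

from dlt.common.data_writers.escape import escape_mssql_literal

# single quotes are doubled; control characters are spliced in via CHAR(n)
assert escape_mssql_literal("it's\nfine") == "N'it''s' + CHAR(10) + N'fine'"
# booleans are rendered as BIT-compatible integers
assert escape_mssql_literal(True) == "1"
# bytes round-trip through base64 and an XML cast to VARBINARY(MAX)
assert escape_mssql_literal(b"\x01\x02") == (
    "CAST('' AS XML).value('xs:base64Binary(\"AQI=\")', 'VARBINARY(MAX)')"
)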
1 change: 1 addition & 0 deletions dlt/common/destination/capabilities.py
@@ -45,6 +45,7 @@ class DestinationCapabilitiesContext(ContainerInjectableContext):
     supports_truncate_command: bool = True
     schema_supports_numeric_precision: bool = True
     timestamp_precision: int = 6
+    max_rows_per_insert: Optional[int] = None

     # do not allow to create default value, destination caps must be always explicitly inserted into container
     can_create_default: ClassVar[bool] = False
4 changes: 4 additions & 0 deletions dlt/common/exceptions.py
@@ -90,6 +90,10 @@ def _to_pip_install(self) -> str:
         return "\n".join([f"pip install \"{d}\"" for d in self.dependencies])


+class SystemConfigurationException(DltException):
+    pass
+
+
 class DestinationException(DltException):
     pass
32 changes: 27 additions & 5 deletions dlt/destinations/insert_job_client.py
@@ -5,6 +5,7 @@
 from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState
 from dlt.common.schema.typing import TTableSchema
 from dlt.common.storages import FileStorage
+from dlt.common.utils import chunks

 from dlt.destinations.sql_client import SqlClientBase
 from dlt.destinations.job_impl import EmptyLoadJob
@@ -37,10 +38,10 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[str]]:
             # properly formatted file has a values marker at the beginning
             assert values_mark == "VALUES\n"

+            max_rows = self._sql_client.capabilities.max_rows_per_insert
+
             insert_sql = []
             while content := f.read(self._sql_client.capabilities.max_query_length // 2):
-                # write INSERT
-                insert_sql.extend([header.format(qualified_table_name), values_mark, content])
                 # read one more line in order to
                 # 1. complete the content which ends at "random" position, not an end line
                 # 2. to modify its ending without a need to re-allocating the 8MB of "content"
@@ -55,13 +56,35 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[str]]:
                 if not is_eof:
                     # print(f'replace the "," with " {until_nl} {len(insert_sql)}')
                     until_nl = until_nl[:-1] + ";"

+                if max_rows is not None:
+                    # mssql has a limit of 1000 rows per INSERT, so we need to split into separate statements
+                    values_rows = content.splitlines(keepends=True)
+                    len_rows = len(values_rows)
+                    processed = 0
+                    # Chunk by max_rows - 1 for simplicity because one more row may be added
+                    for chunk in chunks(values_rows, max_rows - 1):
+                        processed += len(chunk)
+                        insert_sql.extend([header.format(qualified_table_name), values_mark])
+                        if processed == len_rows:
+                            # On the last chunk we need to add the extra row read
+                            insert_sql.append("".join(chunk) + until_nl)
+                        else:
+                            # Replace the , with ;
+                            insert_sql.append("".join(chunk).strip()[:-1] + ";\n")
+                else:
+                    # otherwise write all content in a single INSERT INTO
+                    insert_sql.extend([header.format(qualified_table_name), values_mark, content])
+
+                    if until_nl:
+                        insert_sql.append(until_nl)
+
-                # actually this may be empty if we were able to read a full file into content
-                if until_nl:
-                    insert_sql.append(until_nl)
                 if not is_eof:
                     # execute chunk of insert
                     yield insert_sql
                     insert_sql = []

         if insert_sql:
             yield insert_sql
@@ -101,4 +124,3 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
     # def _get_out_table_constrains_sql(self, t: TTableSchema) -> str:
     #     # set non unique indexes
     #     pass
-
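To see the splitting in action (sketch only, not part of the commit; it reuses the chunks helper imported in the diff above): 2,500 VALUES rows under the 1,000-row cap become three INSERT statements, chunked by 999 because the extra row read from the next line may still be appended to the last chunk.

from dlt.common.utils import chunks

max_rows = 1000
values_rows = [f"({i}, 'payload {i}'),\n" for i in range(2500)]
row_counts = [len(chunk) for chunk in chunks(values_rows, max_rows - 1)]
print(row_counts)  # [999, 999, 502]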
15 changes: 9 additions & 6 deletions dlt/destinations/job_client_impl.py
@@ -86,10 +86,13 @@ def state(self) -> TLoadJobState:

 class SqlJobClientBase(JobClientBase, WithStateSync):

-    VERSION_TABLE_SCHEMA_COLUMNS: ClassVar[str] = "version_hash, schema_name, version, engine_version, inserted_at, schema"
-    STATE_TABLE_COLUMNS: ClassVar[str] = "version, engine_version, pipeline_name, state, created_at, _dlt_load_id"
+    _VERSION_TABLE_SCHEMA_COLUMNS: ClassVar[Tuple[str, ...]] = ('version_hash', 'schema_name', 'version', 'engine_version', 'inserted_at', 'schema')
+    _STATE_TABLE_COLUMNS: ClassVar[Tuple[str, ...]] = ('version', 'engine_version', 'pipeline_name', 'state', 'created_at', '_dlt_load_id')

     def __init__(self, schema: Schema, config: DestinationClientConfiguration, sql_client: SqlClientBase[TNativeConn]) -> None:
+        self.version_table_schema_columns = ", ".join(sql_client.escape_column_name(col) for col in self._VERSION_TABLE_SCHEMA_COLUMNS)
+        self.state_table_columns = ", ".join(sql_client.escape_column_name(col) for col in self._STATE_TABLE_COLUMNS)
+
         super().__init__(schema, config)
         self.sql_client = sql_client
         assert isinstance(config, DestinationClientDwhConfiguration)
@@ -256,13 +259,13 @@ def _from_db_type(cls, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:

     def get_stored_schema(self) -> StorageSchemaInfo:
         name = self.sql_client.make_qualified_table_name(self.schema.version_table_name)
-        query = f"SELECT {self.VERSION_TABLE_SCHEMA_COLUMNS} FROM {name} WHERE schema_name = %s ORDER BY inserted_at DESC;"
+        query = f"SELECT {self.version_table_schema_columns} FROM {name} WHERE schema_name = %s ORDER BY inserted_at DESC;"
         return self._row_to_schema_info(query, self.schema.name)

     def get_stored_state(self, pipeline_name: str) -> StateInfo:
         state_table = self.sql_client.make_qualified_table_name(self.schema.state_table_name)
         loads_table = self.sql_client.make_qualified_table_name(self.schema.loads_table_name)
-        query = f"SELECT {self.STATE_TABLE_COLUMNS} FROM {state_table} AS s JOIN {loads_table} AS l ON l.load_id = s._dlt_load_id WHERE pipeline_name = %s AND l.status = 0 ORDER BY created_at DESC"
+        query = f"SELECT {self.state_table_columns} FROM {state_table} AS s JOIN {loads_table} AS l ON l.load_id = s._dlt_load_id WHERE pipeline_name = %s AND l.status = 0 ORDER BY created_at DESC"
         with self.sql_client.execute_query(query, pipeline_name) as cur:
             row = cur.fetchone()
             if not row:
@@ -280,7 +283,7 @@ def get_stored_state(self, pipeline_name: str) -> StateInfo:

     def get_stored_schema_by_hash(self, version_hash: str) -> StorageSchemaInfo:
         name = self.sql_client.make_qualified_table_name(self.schema.version_table_name)
-        query = f"SELECT {self.VERSION_TABLE_SCHEMA_COLUMNS} FROM {name} WHERE version_hash = %s;"
+        query = f"SELECT {self.version_table_schema_columns} FROM {name} WHERE version_hash = %s;"
         return self._row_to_schema_info(query, version_hash)

     def _execute_schema_update_sql(self, only_tables: Iterable[str]) -> TSchemaTables:
def _execute_schema_update_sql(self, only_tables: Iterable[str]) -> TSchemaTables:
Expand Down Expand Up @@ -429,7 +432,7 @@ def _commit_schema_update(self, schema: Schema, schema_str: str) -> None:
name = self.sql_client.make_qualified_table_name(self.schema.version_table_name)
# values = schema.version_hash, schema.name, schema.version, schema.ENGINE_VERSION, str(now_ts), schema_str
self.sql_client.execute_sql(
f"INSERT INTO {name}({self.VERSION_TABLE_SCHEMA_COLUMNS}) VALUES (%s, %s, %s, %s, %s, %s);", schema.stored_version_hash, schema.name, schema.version, schema.ENGINE_VERSION, now_ts, schema_str
f"INSERT INTO {name}({self.version_table_schema_columns}) VALUES (%s, %s, %s, %s, %s, %s);", schema.stored_version_hash, schema.name, schema.version, schema.ENGINE_VERSION, now_ts, schema_str
)


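Why the quoting matters (an illustrative sketch, not part of the diff): names like schema collide with SQL keywords unless escaped, and the mssql destination registers escape_postgres_identifier for identifiers (see dlt/destinations/mssql/__init__.py below). Assuming escape_column_name delegates to that escaper, the constructor now renders a fully quoted column list:

from dlt.common.data_writers.escape import escape_postgres_identifier

cols = ('version_hash', 'schema_name', 'version', 'engine_version', 'inserted_at', 'schema')
print(", ".join(escape_postgres_identifier(c) for c in cols))
# -> "version_hash", "schema_name", "version", "engine_version", "inserted_at", "schema"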
5 changes: 5 additions & 0 deletions dlt/destinations/mssql/README.md
@@ -0,0 +1,5 @@
# loader account setup

1. Create new database: CREATE DATABASE dlt_data;
2. Create new login and user, set password: CREATE LOGIN loader WITH PASSWORD = 'loader'; then, in dlt_data: CREATE USER loader FOR LOGIN loader;
3. Set as database owner (we could set lower permission): ALTER AUTHORIZATION ON DATABASE::dlt_data TO loader;
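A minimal end-to-end sketch with the account above (the localhost URL, pipeline and dataset names are illustrative, not part of the commit):

import dlt

# credentials can also come from secrets.toml or the environment, e.g.
# DESTINATION__MSSQL__CREDENTIALS="mssql://loader:loader@localhost:1433/dlt_data"
pipeline = dlt.pipeline(
    pipeline_name="mssql_demo",
    destination="mssql",
    dataset_name="demo_data",
)
info = pipeline.run([{"id": 1, "name": "alpha"}], table_name="items")
print(info)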
51 changes: 51 additions & 0 deletions dlt/destinations/mssql/__init__.py
@@ -0,0 +1,51 @@
from typing import Type

from dlt.common.schema.schema import Schema
from dlt.common.configuration import with_config, known_sections
from dlt.common.configuration.accessors import config
from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration
from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE
from dlt.common.wei import EVM_DECIMAL_PRECISION

from dlt.destinations.mssql.configuration import MsSqlClientConfiguration


@with_config(spec=MsSqlClientConfiguration, sections=(known_sections.DESTINATION, "mssql",))
def _configure(config: MsSqlClientConfiguration = config.value) -> MsSqlClientConfiguration:
    return config


def capabilities() -> DestinationCapabilitiesContext:
    caps = DestinationCapabilitiesContext()
    caps.preferred_loader_file_format = "insert_values"
    caps.supported_loader_file_formats = ["insert_values"]
    caps.preferred_staging_file_format = None
    caps.supported_staging_file_formats = []
    caps.escape_identifier = escape_postgres_identifier
    caps.escape_literal = escape_mssql_literal
    caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE)
    caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0)
    # https://learn.microsoft.com/en-us/sql/sql-server/maximum-capacity-specifications-for-sql-server?view=sql-server-ver16&redirectedfrom=MSDN
    caps.max_identifier_length = 128
    caps.max_column_identifier_length = 128
    caps.max_query_length = 4 * 1024 * 64 * 1024
    caps.is_max_query_length_in_bytes = True
    caps.max_text_data_type_length = 2 ** 30 - 1
    caps.is_max_text_data_type_length_in_bytes = False
    caps.supports_ddl_transactions = True
    caps.max_rows_per_insert = 1000

    return caps


def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase:
    # import client when creating instance so capabilities and config specs can be accessed without dependencies installed
    from dlt.destinations.mssql.mssql import MsSqlClient

    return MsSqlClient(schema, _configure(initial_config))  # type: ignore[arg-type]


def spec() -> Type[DestinationClientConfiguration]:
    return MsSqlClientConfiguration
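A quick check of the declared capabilities (illustrative, not part of the commit; runnable once the mssql extra is installed):

from dlt.destinations.mssql import capabilities

caps = capabilities()
assert caps.max_rows_per_insert == 1000  # drives the INSERT splitting in insert_job_client.py
assert caps.max_query_length == 4 * 1024 * 64 * 1024  # 256 MiB, measured in bytes
print(caps.preferred_loader_file_format)  # insert_values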
