Commit

Improve DRYness #1055
Signed-off-by: Marcel Coetzee <[email protected]>
Pipboyguy committed Mar 23, 2024
1 parent 9d905ec commit 86685fd
Showing 16 changed files with 149 additions and 79 deletions.
18 changes: 9 additions & 9 deletions dlt/destinations/impl/clickhouse/clickhouse.py
@@ -1,4 +1,3 @@
- import logging
import os
from copy import deepcopy
from typing import ClassVar, Optional, Dict, List, Sequence, cast
@@ -30,7 +29,9 @@
from dlt.destinations.impl.clickhouse.sql_client import ClickhouseSqlClient
from dlt.destinations.impl.clickhouse.utils import (
convert_storage_to_http_scheme,
- render_s3_table_function,
+ render_object_storage_table_function,
+ FILE_FORMAT_TO_TABLE_FUNCTION_MAPPING,
+ SUPPORTED_FILE_FORMATS,
)
from dlt.destinations.job_client_impl import (
SqlJobClientWithStaging,
@@ -130,22 +131,22 @@ def __init__(
bucket_url = urlparse(bucket_path)
bucket_scheme = bucket_url.scheme

+ file_extension = cast(SUPPORTED_FILE_FORMATS, file_extension)
+ table_function: str

if bucket_scheme in ("s3", "gs", "gcs"):
bucket_http_url = convert_storage_to_http_scheme(bucket_url)

table_function = (
- render_s3_table_function(
+ render_object_storage_table_function(
bucket_http_url,
staging_credentials.aws_secret_access_key,
staging_credentials.aws_secret_access_key,
- file_format=file_extension, # type: ignore[arg-type]
+ file_format=file_extension,
)
if isinstance(staging_credentials, AwsCredentialsWithoutDefaults)
- else render_s3_table_function(
- bucket_http_url,
- file_format=file_extension, # type: ignore[arg-type]
+ else render_object_storage_table_function(
+ bucket_http_url, file_format=file_extension
)
)
elif bucket_scheme in ("az", "abfs"):
@@ -159,8 +160,7 @@ def __init__(
container_name = bucket_url.netloc
blobpath = bucket_url.path

- format_mapping = {"jsonl": "JSONEachRow", "parquet": "Parquet"}
- clickhouse_format = format_mapping[file_extension]
+ clickhouse_format = FILE_FORMAT_TO_TABLE_FUNCTION_MAPPING[file_extension]

table_function = (
f"azureBlobStorage('{storage_account_url}','{container_name}','{ blobpath }','{ account_name }','{ account_key }','{ clickhouse_format}')"
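For context on what this hunk centralises: the format lookup that used to be a dict literal at each call site now comes from the shared constant in utils.py, and the same lookup feeds the azureBlobStorage string above. A minimal sketch of that pattern; only FILE_FORMAT_TO_TABLE_FUNCTION_MAPPING, SUPPORTED_FILE_FORMATS and the shape of the azureBlobStorage string come from the diff, while the helper name and its parameters are illustrative placeholders:

from typing import Dict, Literal

SUPPORTED_FILE_FORMATS = Literal["jsonl", "parquet"]
FILE_FORMAT_TO_TABLE_FUNCTION_MAPPING: Dict[SUPPORTED_FILE_FORMATS, str] = {
    "jsonl": "JSONEachRow",
    "parquet": "Parquet",
}

def example_azure_table_function(
    storage_account_url: str,
    container_name: str,
    blobpath: str,
    account_name: str,
    account_key: str,
    file_extension: SUPPORTED_FILE_FORMATS,
) -> str:
    # One shared mapping instead of a dict literal repeated at every call site.
    clickhouse_format = FILE_FORMAT_TO_TABLE_FUNCTION_MAPPING[file_extension]
    return (
        f"azureBlobStorage('{storage_account_url}','{container_name}',"
        f"'{blobpath}','{account_name}','{account_key}','{clickhouse_format}')"
    )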
53 changes: 43 additions & 10 deletions dlt/destinations/impl/clickhouse/sql_client.py
@@ -6,12 +6,10 @@
Optional,
Sequence,
ClassVar,
- Union,
)

import clickhouse_driver # type: ignore[import-untyped]
import clickhouse_driver.errors # type: ignore[import-untyped]
- from clickhouse_driver.dbapi import Connection # type: ignore[import-untyped]
from clickhouse_driver.dbapi.extras import DictCursor # type: ignore[import-untyped]

from dlt.common.destination import DestinationCapabilitiesContext
@@ -54,7 +52,15 @@ def __init__(self, dataset_name: str, credentials: ClickhouseCredentials) -> Non
self.database_name = credentials.database

def open_connection(self) -> clickhouse_driver.dbapi.connection.Connection:
- self._conn = clickhouse_driver.connect(dsn=self.credentials.to_native_representation())
+ self._conn = clickhouse_driver.dbapi.connect(
+ dsn=self.credentials.to_native_representation()
+ )
+ with self._conn.cursor() as cur:
+ # Set session settings. There doesn't seem to be a way to set these
+ # without using the library's top-level, non-dbapi2 client.
+ cur.execute("set allow_experimental_object_type = 1")
+ cur.execute("set allow_experimental_lightweight_delete = 1")

return self._conn

@raise_open_connection_error
@@ -89,6 +95,29 @@ def execute_sql(
with self.execute_query(sql, *args, **kwargs) as curr:
return None if curr.description is None else curr.fetchall()

def create_dataset(self) -> None:
# Clickhouse doesn't have schemas.
pass

def drop_dataset(self) -> None:
# Since Clickhouse doesn't have schemas, we need to drop all tables in our virtual schema,
# or collection of tables that has the `dataset_name` as a prefix.
to_drop_results = self.execute_sql(
"""
SELECT name
FROM system.tables
WHERE database = %(db_name)s
AND name LIKE %(dataset_name)s
""",
{"db_name": self.database_name, "dataset_name": self.dataset_name},
)
for to_drop_result in to_drop_results:
table = to_drop_result[0]
self.execute_sql(
"DROP TABLE %(database)s.%(table)s SYNC",
{"database": self.database_name, "table": table},
)

@contextmanager
@raise_database_error
def execute_query(
@@ -97,7 +126,6 @@ def execute_query(
cur: clickhouse_driver.dbapi.connection.Cursor
with self._conn.cursor() as cur:
try:
- # TODO: Clickhouse driver only accepts pyformat `...WHERE name=%(name)s` parameter marker arguments.
cur.execute(query, args or (kwargs or None))
yield ClickhouseDBApiCursorImpl(cur) # type: ignore[abstract]
except clickhouse_driver.dbapi.Error:
@@ -143,7 +171,7 @@ def _make_database_exception(cls, ex: Exception) -> Exception: # type: ignore[r
clickhouse_driver.dbapi.errors.InternalError,
),
):
- if term := cls._maybe_make_terminal_exception_from_data_error(ex):
+ if term := cls._maybe_make_terminal_exception_from_data_error():
return term
else:
return DatabaseTransientException(ex)
@@ -161,12 +189,17 @@ def _make_database_exception(cls, ex: Exception) -> Exception: # type: ignore[r
else:
return ex

def has_dataset(self) -> bool:
query = """
SELECT 1 FROM INFORMATION_SCHEMA.SCHEMATA WHERE
catalog_name = %(database)s AND schema_name = %(table)s
"""
database, table = self.fully_qualified_dataset_name(escape=False).split(".", 2)
rows = self.execute_sql(query, {"database": database, "table": table})
return len(rows) > 0

@staticmethod
- def _maybe_make_terminal_exception_from_data_error(
- ex: Union[
- clickhouse_driver.dbapi.errors.DataError, clickhouse_driver.dbapi.errors.InternalError
- ]
- ) -> Optional[Exception]:
+ def _maybe_make_terminal_exception_from_data_error() -> Optional[Exception]:
return None

@staticmethod
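Since the removed TODO comment was the only place the parameter style was spelled out: clickhouse_driver's DB-API layer binds pyformat placeholders (%(name)s) from a dict, which is what the new drop_dataset() and has_dataset() queries rely on. A self-contained sketch under that assumption; the DSN, the dataset name, and the trailing % wildcard for prefix matching are illustrative choices, not values taken from this commit:

import clickhouse_driver.dbapi

# Placeholder DSN; point this at a real ClickHouse instance to run the sketch.
conn = clickhouse_driver.dbapi.connect(dsn="clickhouse://default@localhost:9000/default")
try:
    with conn.cursor() as cur:
        # Session-level setting, mirroring the new open_connection() above.
        cur.execute("set allow_experimental_lightweight_delete = 1")
        # Named %(...)s placeholders bound from a dict (pyformat), the style
        # the driver expects and the new dataset helpers use.
        cur.execute(
            """
            SELECT name
            FROM system.tables
            WHERE database = %(db_name)s
            AND name LIKE %(dataset_name)s
            """,
            # The trailing % makes LIKE behave as a prefix match over the
            # "virtual schema" of tables sharing the dataset_name prefix.
            {"db_name": "default", "dataset_name": "my_dataset%"},
        )
        print([row[0] for row in cur.fetchall()])
finally:
    conn.close()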
17 changes: 10 additions & 7 deletions dlt/destinations/impl/clickhouse/utils.py
@@ -1,10 +1,14 @@
- from typing import Union, Optional, Literal
- from urllib.parse import urlparse, ParseResult, urlunparse
+ from typing import Union, Optional, Literal, Dict
+ from urllib.parse import urlparse, ParseResult

from jinja2 import Template


- S3_TABLE_FUNCTION_FILE_FORMATS = Literal["jsonl", "parquet"]
+ SUPPORTED_FILE_FORMATS = Literal["jsonl", "parquet"]
+ FILE_FORMAT_TO_TABLE_FUNCTION_MAPPING: Dict[SUPPORTED_FILE_FORMATS, str] = {
+ "jsonl": "JSONEachRow",
+ "parquet": "Parquet",
+ }


def convert_storage_to_http_scheme(
@@ -40,17 +44,16 @@ def convert_storage_to_http_scheme(
raise Exception(f"Error converting storage URL to HTTP protocol: '{url}'") from e


- def render_s3_table_function(
+ def render_object_storage_table_function(
url: str,
access_key_id: Optional[str] = None,
secret_access_key: Optional[str] = None,
- file_format: Optional[S3_TABLE_FUNCTION_FILE_FORMATS] = "jsonl",
+ file_format: SUPPORTED_FILE_FORMATS = "jsonl",
) -> str:
if file_format not in ["parquet", "jsonl"]:
raise ValueError("Clickhouse s3/gcs staging only supports 'parquet' and 'jsonl'.")

- format_mapping = {"jsonl": "JSONEachRow", "parquet": "Parquet"}
- clickhouse_format = format_mapping[file_format]
+ clickhouse_format = FILE_FORMAT_TO_TABLE_FUNCTION_MAPPING[file_format]

template = Template(
"""s3('{{ url }}'{% if access_key_id and secret_access_key %},'{{ access_key_id }}','{{ secret_access_key }}'{% else %},NOSIGN{% endif %},'{{ clickhouse_format }}')"""
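A usage sketch for the renamed helper, assuming this module is importable at this revision and that the Jinja template shown above is rendered as-is; the bucket URL and credentials are dummy values:

from dlt.destinations.impl.clickhouse.utils import render_object_storage_table_function

# With credentials: both keys are inlined into the s3() table function.
with_keys = render_object_storage_table_function(
    "https://bucket.s3.amazonaws.com/data/file.parquet",
    access_key_id="AKIA_EXAMPLE",
    secret_access_key="example_secret",
    file_format="parquet",
)
# Expected shape, per the template above:
# s3('https://bucket.s3.amazonaws.com/data/file.parquet','AKIA_EXAMPLE','example_secret','Parquet')

# Without credentials: the NOSIGN branch of the template is taken instead.
no_keys = render_object_storage_table_function(
    "https://bucket.s3.amazonaws.com/data/file.jsonl",
    file_format="jsonl",
)
# Expected shape: s3('https://bucket.s3.amazonaws.com/data/file.jsonl',NOSIGN,'JSONEachRow')

print(with_keys, no_keys, sep="\n")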
4 changes: 1 addition & 3 deletions dlt/helpers/streamlit_app/utils.py
@@ -38,9 +38,7 @@ def render_with_pipeline(render_func: Callable[..., None]) -> None:
render_func(pipeline)


- def query_using_cache(
- pipeline: dlt.Pipeline, ttl: int
- ) -> Callable[..., Optional[pd.DataFrame]]:
+ def query_using_cache(pipeline: dlt.Pipeline, ttl: int) -> Callable[..., Optional[pd.DataFrame]]:
@st.cache_data(ttl=ttl)
def do_query( # type: ignore[return]
query: str,
12 changes: 5 additions & 7 deletions docs/examples/chess_production/chess.py
@@ -6,6 +6,7 @@
from dlt.common.typing import StrAny, TDataItems
from dlt.sources.helpers.requests import client


@dlt.source
def chess(
chess_url: str = dlt.config.value,
@@ -56,6 +57,7 @@ def players_games(username: Any) -> Iterator[TDataItems]:

MAX_PLAYERS = 5


def load_data_with_retry(pipeline, data):
try:
for attempt in Retrying(
@@ -65,9 +67,7 @@ def load_data_with_retry(pipeline, data):
reraise=True,
):
with attempt:
- logger.info(
- f"Running the pipeline, attempt={attempt.retry_state.attempt_number}"
- )
+ logger.info(f"Running the pipeline, attempt={attempt.retry_state.attempt_number}")
load_info = pipeline.run(data)
logger.info(str(load_info))

@@ -89,9 +89,7 @@ def load_data_with_retry(pipeline, data):
# print the information on the first load package and all jobs inside
logger.info(f"First load package info: {load_info.load_packages[0]}")
# print the information on the first completed job in first load package
- logger.info(
- f"First completed job info: {load_info.load_packages[0].jobs['completed_jobs'][0]}"
- )
+ logger.info(f"First completed job info: {load_info.load_packages[0].jobs['completed_jobs'][0]}")

# check for schema updates:
schema_updates = [p.schema_update for p in load_info.load_packages]
@@ -149,4 +147,4 @@ def load_data_with_retry(pipeline, data):
)
# get data for a few famous players
data = chess(chess_url="https://api.chess.com/pub/", max_players=MAX_PLAYERS)
- load_data_with_retry(pipeline, data)
+ load_data_with_retry(pipeline, data)
2 changes: 2 additions & 0 deletions docs/examples/connector_x_arrow/load_arrow.py
@@ -3,6 +3,7 @@
import dlt
from dlt.sources.credentials import ConnectionStringCredentials


def read_sql_x(
conn_str: ConnectionStringCredentials = dlt.secrets.value,
query: str = dlt.config.value,
@@ -14,6 +15,7 @@ def read_sql_x(
protocol="binary",
)


def genome_resource():
# create genome resource with merge on `upid` primary key
genome = dlt.resource(
5 changes: 4 additions & 1 deletion docs/examples/google_sheets/google_sheets.py
@@ -9,13 +9,15 @@
)
from dlt.common.typing import DictStrAny, StrAny


def _initialize_sheets(
credentials: Union[GcpOAuthCredentials, GcpServiceAccountCredentials]
) -> Any:
# Build the service object.
service = build("sheets", "v4", credentials=credentials.to_native_credentials())
return service


@dlt.source
def google_spreadsheet(
spreadsheet_id: str,
@@ -55,6 +57,7 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]:
for name in sheet_names
]


if __name__ == "__main__":
pipeline = dlt.pipeline(destination="duckdb")
# see example.secrets.toml to where to put credentials
@@ -67,4 +70,4 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]:
sheet_names=range_names,
)
)
- print(info)
+ print(info)
8 changes: 4 additions & 4 deletions docs/examples/incremental_loading/zendesk.py
@@ -6,12 +6,11 @@
from dlt.common.typing import TAnyDateTime
from dlt.sources.helpers.requests import client


@dlt.source(max_table_nesting=2)
def zendesk_support(
credentials: Dict[str, str] = dlt.secrets.value,
- start_date: Optional[TAnyDateTime] = pendulum.datetime( # noqa: B008
- year=2000, month=1, day=1
- ),
+ start_date: Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008
end_date: Optional[TAnyDateTime] = None,
):
"""
@@ -113,11 +112,12 @@ def get_pages(
if not response_json["end_of_stream"]:
get_url = response_json["next_page"]


if __name__ == "__main__":
# create dlt pipeline
pipeline = dlt.pipeline(
pipeline_name="zendesk", destination="duckdb", dataset_name="zendesk_data"
)

load_info = pipeline.run(zendesk_support())
- print(load_info)
+ print(load_info)
2 changes: 2 additions & 0 deletions docs/examples/nested_data/nested_data.py
@@ -13,6 +13,7 @@

CHUNK_SIZE = 10000


# You can limit how deep dlt goes when generating child tables.
# By default, the library will descend and generate child tables
# for all nested lists, without a limit.
@@ -81,6 +82,7 @@ def load_documents(self) -> Iterator[TDataItem]:
while docs_slice := list(islice(cursor, CHUNK_SIZE)):
yield map_nested_in_place(convert_mongo_objs, docs_slice)


def convert_mongo_objs(value: Any) -> Any:
if isinstance(value, (ObjectId, Decimal128)):
return str(value)
5 changes: 4 additions & 1 deletion docs/examples/pdf_to_weaviate/pdf_to_weaviate.py
@@ -4,6 +4,7 @@
from dlt.destinations.impl.weaviate import weaviate_adapter
from PyPDF2 import PdfReader


@dlt.resource(selected=False)
def list_files(folder_path: str):
folder_path = os.path.abspath(folder_path)
@@ -15,6 +16,7 @@ def list_files(folder_path: str):
"mtime": os.path.getmtime(file_path),
}


@dlt.transformer(primary_key="page_id", write_disposition="merge")
def pdf_to_text(file_item, separate_pages: bool = False):
if not separate_pages:
@@ -28,6 +30,7 @@ def pdf_to_text(file_item, separate_pages: bool = False):
page_item["page_id"] = file_item["file_name"] + "_" + str(page_no)
yield page_item


pipeline = dlt.pipeline(pipeline_name="pdf_to_text", destination="weaviate")

# this constructs a simple pipeline that: (1) reads files from "invoices" folder (2) filters only those ending with ".pdf"
@@ -51,4 +54,4 @@ def pdf_to_text(file_item, separate_pages: bool = False):

client = weaviate.Client("http://localhost:8080")
# get text of all the invoices in InvoiceText class we just created above
- print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do())
+ print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do())