WIP: #2166 databricks direct loading #2219
base: devel
Changes from all commits
9d560d9
902c49d
1efe565
b60b3d3
e772d20
2bd0be0
@@ -5,7 +5,7 @@
from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec
from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration
from dlt.common.configuration.exceptions import ConfigurationValueError

from dlt.common import logger

DATABRICKS_APPLICATION_ID = "dltHub_dlt"

@@ -15,6 +15,7 @@ class DatabricksCredentials(CredentialsConfiguration):
catalog: str = None
server_hostname: str = None
http_path: str = None
direct_load: bool = False
access_token: Optional[TSecretStrValue] = None
client_id: Optional[TSecretStrValue] = None
client_secret: Optional[TSecretStrValue] = None
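For orientation, a minimal sketch of how the new direct_load flag could be switched on from pipeline code; the factory keyword arguments are an assumption based on the usual dlt destination factories, and all host, token and path values are placeholders:

# Hypothetical usage sketch - assumes dlt.destinations.databricks accepts an
# explicit credentials mapping, as other dlt destination factories do.
import dlt

bricks = dlt.destinations.databricks(
    credentials={
        "server_hostname": "adb-1234567890123456.7.azuredatabricks.net",  # placeholder
        "http_path": "/sql/1.0/warehouses/abc123",  # placeholder
        "catalog": "main",
        "access_token": "dapi-...",  # placeholder
        "direct_load": True,  # upload local files to a Databricks volume instead of staging
    }
)

pipeline = dlt.pipeline("databricks_direct", destination=bricks, dataset_name="direct_data")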
@@ -37,10 +38,23 @@ class DatabricksCredentials(CredentialsConfiguration):

def on_resolved(self) -> None:
if not ((self.client_id and self.client_secret) or self.access_token):
raise ConfigurationValueError(
"No valid authentication method detected. Provide either 'client_id' and"
" 'client_secret' for OAuth, or 'access_token' for token-based authentication."
)
# databricks authentication: get context config
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
notebook_context = w.dbutils.notebook.entry_point.getDbutils().notebook().getContext()
self.access_token = notebook_context.apiToken().getOrElse(None)

self.server_hostname = notebook_context.browserHostName().getOrElse(None)

if not self.access_token or not self.server_hostname:
raise ConfigurationValueError(
"No valid authentication method detected. Provide either 'client_id' and"
" 'client_secret' for OAuth, or 'access_token' for token-based authentication,"
" and the server_hostname."
)

self.direct_load = True
Review comment: I do not think we need this. If we have a local file, we do a direct load; we do not need to be in a notebook context to do it. Only the default access token needs the notebook.
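Picking up that suggestion, a minimal sketch (not part of the diff) of an on_resolved that only falls back to the notebook context when no explicit credentials are configured and leaves direct_load untouched; it assumes the same imports as the file above (WorkspaceClient, ConfigurationValueError):

# Sketch only: method body on DatabricksCredentials; the notebook context is a
# fallback, explicit OAuth or token credentials always win.
def on_resolved(self) -> None:
    if (self.client_id and self.client_secret) or self.access_token:
        return  # explicit credentials provided; nothing to derive

    from databricks.sdk import WorkspaceClient

    w = WorkspaceClient()
    ctx = w.dbutils.notebook.entry_point.getDbutils().notebook().getContext()
    self.access_token = ctx.apiToken().getOrElse(None)
    self.server_hostname = self.server_hostname or ctx.browserHostName().getOrElse(None)

    if not self.access_token or not self.server_hostname:
        raise ConfigurationValueError(
            "No valid authentication method detected. Provide 'client_id'/'client_secret',"
            " an 'access_token', or run inside a Databricks notebook."
        )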

def to_connector_params(self) -> Dict[str, Any]:
conn_params = dict(

@@ -35,7 +35,7 @@
from dlt.destinations.sql_jobs import SqlMergeFollowupJob
from dlt.destinations.job_impl import ReferenceFollowupJobRequest
from dlt.destinations.utils import is_compression_disabled

from dlt.common.utils import uniq_id

SUPPORTED_BLOB_STORAGE_PROTOCOLS = AZURE_BLOB_STORAGE_PROTOCOLS + S3_PROTOCOLS + GCS_PROTOCOLS

@@ -50,126 +50,214 @@ def __init__(
self._staging_config = staging_config
self._job_client: "DatabricksClient" = None

self._sql_client = None

def run(self) -> None:
self._sql_client = self._job_client.sql_client

qualified_table_name = self._sql_client.make_qualified_table_name(self.load_table_name)
staging_credentials = self._staging_config.credentials
# extract and prepare some vars

# decide if this is a local file or a staged file
is_local_file = not ReferenceFollowupJobRequest.is_reference_job(self._file_path)
if is_local_file and self._job_client.config.credentials.direct_load:
# local file by uploading to a temporary volume on Databricks
from_clause, file_name = self._handle_local_file_upload(self._file_path)
credentials_clause = ""
orig_bucket_path = None  # not used for local file
else:
# staged file
from_clause, credentials_clause, file_name, orig_bucket_path = (
self._handle_staged_file()
)

# Determine the source format and any additional format options
source_format, format_options_clause, skip_load = self._determine_source_format(
file_name, orig_bucket_path
)

if skip_load:
# If the file is empty or otherwise un-loadable, exit early
return

statement = self._build_copy_into_statement(
qualified_table_name,
from_clause,
credentials_clause,
source_format,
format_options_clause,
)

self._sql_client.execute_sql(statement)

def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str]:
from databricks.sdk import WorkspaceClient
import time
import io

w: WorkspaceClient

credentials = self._job_client.config.credentials
if credentials.client_id and credentials.client_secret:
# oauth authentication
w = WorkspaceClient(
host=credentials.server_hostname,
client_id=credentials.client_id,
client_secret=credentials.client_secret,
)
elif credentials.access_token:
# token authentication
w = WorkspaceClient(
host=credentials.server_hostname,
token=credentials.access_token,
)

file_name = FileStorage.get_file_name_from_file_path(local_file_path)
file_format = ""
if file_name.endswith(".parquet"):
Review comment: I do not think that you need to know the file format here. Just upload the file we have; it has the proper extension. Also keep the
file_format = "parquet" | ||
elif file_name.endswith(".jsonl"): | ||
file_format = "jsonl" | ||
else: | ||
return "", file_name | ||
|
||
volume_path = f"/Volumes/{self._sql_client.database_name}/{self._sql_client.dataset_name}/{self._sql_client.volume_name}/{time.time_ns()}" | ||
Review comment: how IMO we should handle volumes:
Review comment: Why do we need
Review comment (reply): agree about the volume_name; I'll answer the path (time_ns) and file_name in another comment

volume_file_name = (  # replace file_name with a random hex code - databricks loading fails when file_name starts with - or .
Review comment: why the same name as in
f"{uniq_id()}.{file_format}"
)
volume_file_path = f"{volume_path}/{volume_file_name}"

with open(local_file_path, "rb") as f:
file_bytes = f.read()
binary_data = io.BytesIO(file_bytes)
w.files.upload(volume_file_path, binary_data, overwrite=True)

from_clause = f"FROM '{volume_path}'"

return from_clause, file_name
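Following the review suggestion above (do not inspect the file format; keep the file's own extension), a minimal sketch of what the upload helper could shrink to. The volume attributes on the SQL client are taken from the diff; _get_workspace_client is a hypothetical helper wrapping the WorkspaceClient authentication branches shown above:

# Sketch only: upload the local file as-is, keep its extension, and point the
# FROM clause at a per-job folder inside the configured volume.
import io
import time

from dlt.common.storages import FileStorage
from dlt.common.utils import uniq_id


def _handle_local_file_upload(self, local_file_path: str) -> tuple[str, str]:
    w = self._get_workspace_client()  # hypothetical helper, not in the diff
    file_name = FileStorage.get_file_name_from_file_path(local_file_path)
    # rename only to avoid names starting with "-" or "."; keep the original extension
    extension = file_name.rsplit(".", 1)[-1]
    volume_path = (
        f"/Volumes/{self._sql_client.database_name}/{self._sql_client.dataset_name}"
        f"/{self._sql_client.volume_name}/{time.time_ns()}"
    )
    volume_file_path = f"{volume_path}/{uniq_id()}.{extension}"
    with open(local_file_path, "rb") as f:
        w.files.upload(volume_file_path, io.BytesIO(f.read()), overwrite=True)
    return f"FROM '{volume_path}'", file_name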
def _handle_staged_file(self) -> tuple[str, str, str, str]:
bucket_path = orig_bucket_path = (
ReferenceFollowupJobRequest.resolve_reference(self._file_path)
if ReferenceFollowupJobRequest.is_reference_job(self._file_path)
else ""
)
file_name = (
FileStorage.get_file_name_from_file_path(bucket_path)
if bucket_path
else self._file_name
)
from_clause = ""
credentials_clause = ""
format_options_clause = ""

if bucket_path:
bucket_url = urlparse(bucket_path)
bucket_scheme = bucket_url.scheme
if not bucket_path:
raise LoadJobTerminalException(
self._file_path,
"Cannot load from local file. Databricks does not support loading from local files."
" Configure staging with an s3, azure or google storage bucket.",
)

if bucket_scheme not in SUPPORTED_BLOB_STORAGE_PROTOCOLS:
raise LoadJobTerminalException(
self._file_path,
f"Databricks cannot load data from staging bucket {bucket_path}. Only s3, azure"
" and gcs buckets are supported. Please note that gcs buckets are supported"
" only via named credential",
)
file_name = FileStorage.get_file_name_from_file_path(bucket_path)

if self._job_client.config.is_staging_external_location:
# just skip the credentials clause for external location
# https://docs.databricks.com/en/sql/language-manual/sql-ref-external-locations.html#external-location
pass
elif self._job_client.config.staging_credentials_name:
# add named credentials
credentials_clause = (
f"WITH(CREDENTIAL {self._job_client.config.staging_credentials_name} )"
)
else:
# referencing an staged files via a bucket URL requires explicit AWS credentials
if bucket_scheme == "s3":
assert isinstance(staging_credentials, AwsCredentialsWithoutDefaults)
s3_creds = staging_credentials.to_session_credentials()
credentials_clause = f"""WITH(CREDENTIAL(
staging_credentials = self._staging_config.credentials
bucket_url = urlparse(bucket_path)
bucket_scheme = bucket_url.scheme

if bucket_scheme not in SUPPORTED_BLOB_STORAGE_PROTOCOLS:
raise LoadJobTerminalException(
self._file_path,
f"Databricks cannot load data from staging bucket {bucket_path}. "
"Only s3, azure and gcs buckets are supported. "
"Please note that gcs buckets are supported only via named credential.",
)

credentials_clause = ""

if self._job_client.config.is_staging_external_location:
# skip the credentials clause
pass
elif self._job_client.config.staging_credentials_name:
# named credentials
credentials_clause = (
f"WITH(CREDENTIAL {self._job_client.config.staging_credentials_name} )"
)
else:
if bucket_scheme == "s3":
assert isinstance(staging_credentials, AwsCredentialsWithoutDefaults)
s3_creds = staging_credentials.to_session_credentials()
credentials_clause = f"""WITH(CREDENTIAL(
AWS_ACCESS_KEY='{s3_creds["aws_access_key_id"]}',
AWS_SECRET_KEY='{s3_creds["aws_secret_access_key"]}',

AWS_SESSION_TOKEN='{s3_creds["aws_session_token"]}'
))
"""
elif bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS:
assert isinstance(
staging_credentials, AzureCredentialsWithoutDefaults
), "AzureCredentialsWithoutDefaults required to pass explicit credential"
# Explicit azure credentials are needed to load from bucket without a named stage
credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))"""
bucket_path = self.ensure_databricks_abfss_url(
bucket_path,
staging_credentials.azure_storage_account_name,
staging_credentials.azure_account_host,
)
else:
raise LoadJobTerminalException(
self._file_path,
"You need to use Databricks named credential to use google storage."
" Passing explicit Google credentials is not supported by Databricks.",
)
if bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS:
))"""
elif bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS:
assert isinstance(
staging_credentials,
(
AzureCredentialsWithoutDefaults,
AzureServicePrincipalCredentialsWithoutDefaults,
),
)
staging_credentials, AzureCredentialsWithoutDefaults
), "AzureCredentialsWithoutDefaults required to pass explicit credential"
credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))"""
bucket_path = self.ensure_databricks_abfss_url(
bucket_path,
staging_credentials.azure_storage_account_name,
staging_credentials.azure_account_host,
)
else:
raise LoadJobTerminalException(
self._file_path,
"You need to use Databricks named credential to use google storage."
" Passing explicit Google credentials is not supported by Databricks.",
)

# always add FROM clause
from_clause = f"FROM '{bucket_path}'"
else:
raise LoadJobTerminalException(
self._file_path,
"Cannot load from local file. Databricks does not support loading from local files."
" Configure staging with an s3, azure or google storage bucket.",
if bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS:
assert isinstance(
staging_credentials,
(AzureCredentialsWithoutDefaults, AzureServicePrincipalCredentialsWithoutDefaults),
)
bucket_path = self.ensure_databricks_abfss_url(
bucket_path,
staging_credentials.azure_storage_account_name,
staging_credentials.azure_account_host,
)

from_clause = f"FROM '{bucket_path}'"

# decide on source format, stage_file_path will either be a local file or a bucket path
return from_clause, credentials_clause, file_name, orig_bucket_path
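For completeness, a minimal sketch of how the staging options read above (is_staging_external_location, staging_credentials_name) might be supplied when creating the destination; passing them as factory keyword arguments is an assumption, not something this diff shows:

# Sketch only: assumes the databricks factory forwards these keyword arguments
# to the client configuration used by the load job above.
import dlt

# staging bucket registered as a Databricks external location: no credentials clause emitted
bricks = dlt.destinations.databricks(is_staging_external_location=True)

# or: reference a named Databricks storage credential in the COPY INTO statement
bricks = dlt.destinations.databricks(staging_credentials_name="my_storage_credential")

pipeline = dlt.pipeline(
    "databricks_staged",
    destination=bricks,
    staging="filesystem",  # bucket location and credentials configured separately
    dataset_name="staged_data",
)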
def _determine_source_format(
self, file_name: str, orig_bucket_path: str
) -> tuple[str, str, bool]:
if file_name.endswith(".parquet"):
source_format = "PARQUET"  # Only parquet is supported
return "PARQUET", "", False

elif file_name.endswith(".jsonl"):
if not is_compression_disabled():
raise LoadJobTerminalException(
self._file_path,
"Databricks loader does not support gzip compressed JSON files. Please disable"
" compression in the data writer configuration:"
"Databricks loader does not support gzip compressed JSON files. "
"Please disable compression in the data writer configuration:"
" https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression",
)
source_format = "JSON"

format_options_clause = "FORMAT_OPTIONS('inferTimestamp'='true')"
# Databricks fails when trying to load empty json files, so we have to check the file size

# check for an empty JSON file
fs, _ = fsspec_from_config(self._staging_config)
file_size = fs.size(orig_bucket_path)
if file_size == 0:  # Empty file, do nothing
return
if orig_bucket_path is not None:
file_size = fs.size(orig_bucket_path)
if file_size == 0:
return "JSON", format_options_clause, True

return "JSON", format_options_clause, False

raise LoadJobTerminalException(
self._file_path, "Databricks loader only supports .parquet or .jsonl file extensions."
)
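Since the jsonl branch requires compression to be disabled, a small sketch of one way to do that before running the pipeline; using the environment variable follows dlt's usual configuration layout and is stated here as an assumption (the config.toml equivalent is described at the link in the error message):

# Sketch only: turn off gzip compression for written data files so .jsonl
# output can be loaded by the Databricks COPY INTO job.
import os

os.environ["NORMALIZE__DATA_WRITER__DISABLE_COMPRESSION"] = "true"
# ...then create and run the pipeline as usual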
statement = f"""COPY INTO {qualified_table_name} | ||
def _build_copy_into_statement( | ||
self, | ||
qualified_table_name: str, | ||
from_clause: str, | ||
credentials_clause: str, | ||
source_format: str, | ||
format_options_clause: str, | ||
) -> str: | ||
return f"""COPY INTO {qualified_table_name} | ||
{from_clause} | ||
{credentials_clause} | ||
FILEFORMAT = {source_format} | ||
{format_options_clause} | ||
""" | ||
self._sql_client.execute_sql(statement) | ||
""" | ||
|
||
@staticmethod | ||
def ensure_databricks_abfss_url( | ||
|
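To make the template concrete, an illustrative rendering of the statement for a parquet file uploaded to a volume; catalog, schema, table and volume path are made up, and the empty credentials and format-options clauses are left as the blank lines the template produces:

# Illustrative only: roughly what _build_copy_into_statement returns for a
# direct (volume-based) parquet load.
example_statement = """COPY INTO `main`.`direct_data`.`players`
FROM '/Volumes/main/direct_data/dlt_volume/1735689600000000000'

FILEFORMAT = PARQUET

"""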
Review comment: The code is correct, but you must handle the situation when default credentials do not exist (i.e. outside of a notebook). I get this exception in this case:
Just skip the code that assigns the values.
Review comment (reply): Makes sense. I haven't tested on the notebook yet, and you are also right about that exception context - I have to catch that exception and add some tests.
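As a follow-up to that exchange, a minimal sketch of a test covering resolution outside of a notebook, assuming the module path used by current dlt for the Databricks destination; host, warehouse path and token are placeholders:

# Sketch only: with the guard from the review applied, explicit token credentials
# should resolve without touching the notebook context.
from dlt.destinations.impl.databricks.configuration import DatabricksCredentials


def test_explicit_token_resolves_outside_notebook() -> None:
    creds = DatabricksCredentials()
    creds.server_hostname = "adb-1234567890123456.7.azuredatabricks.net"  # placeholder
    creds.http_path = "/sql/1.0/warehouses/abc123"  # placeholder
    creds.catalog = "main"
    creds.access_token = "dapi-test-token"  # placeholder
    # must not raise even though no notebook context is available
    creds.on_resolved()
    assert creds.access_token == "dapi-test-token"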