From 3548ac288d345bcd3d378abf8904e23bac9ffc77 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Fri, 15 Nov 2024 14:51:21 +0300 Subject: [PATCH] Ability to pass credentials by string --- ydb_dbapi/connections.py | 17 ++++++++++++----- ydb_dbapi/utils.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/ydb_dbapi/connections.py b/ydb_dbapi/connections.py index da3e436..8f45318 100644 --- a/ydb_dbapi/connections.py +++ b/ydb_dbapi/connections.py @@ -20,6 +20,7 @@ from .errors import NotSupportedError from .utils import handle_ydb_errors from .utils import maybe_get_current_trace_id +from .utils import prepare_credentials class IsolationLevel(str, Enum): @@ -69,13 +70,15 @@ def __init__( port: str = "", database: str = "", ydb_table_path_prefix: str = "", - credentials: ydb.AbstractCredentials | None = None, + protocol: str | None = None, + credentials: ydb.Credentials | dict | str | None = None, ydb_session_pool: SessionPool | AsyncSessionPool | None = None, **kwargs: dict, ) -> None: - self.endpoint = f"grpc://{host}:{port}" + protocol = protocol if protocol else "grpc" + self.endpoint = f"{protocol}://{host}:{port}" + self.credentials = prepare_credentials(credentials) self.database = database - self.credentials = credentials self.table_path_prefix = ydb_table_path_prefix self.connection_kwargs: dict = kwargs @@ -170,7 +173,8 @@ def __init__( port: str = "", database: str = "", ydb_table_path_prefix: str = "", - credentials: ydb.AbstractCredentials | None = None, + protocol: str | None = None, + credentials: ydb.Credentials | None = None, ydb_session_pool: SessionPool | AsyncSessionPool | None = None, **kwargs: dict, ) -> None: @@ -179,6 +183,7 @@ def __init__( port=port, database=database, ydb_table_path_prefix=ydb_table_path_prefix, + protocol=protocol, credentials=credentials, ydb_session_pool=ydb_session_pool, **kwargs, @@ -333,7 +338,8 @@ def __init__( port: str = "", database: str = "", ydb_table_path_prefix: str = "", - credentials: ydb.AbstractCredentials | None = None, + protocol: str | None = None, + credentials: ydb.Credentials | None = None, ydb_session_pool: SessionPool | AsyncSessionPool | None = None, **kwargs: dict, ) -> None: @@ -342,6 +348,7 @@ def __init__( port=port, database=database, ydb_table_path_prefix=ydb_table_path_prefix, + protocol=protocol, credentials=credentials, ydb_session_pool=ydb_session_pool, **kwargs, diff --git a/ydb_dbapi/utils.py b/ydb_dbapi/utils.py index 38f964f..226bb4e 100644 --- a/ydb_dbapi/utils.py +++ b/ydb_dbapi/utils.py @@ -2,6 +2,7 @@ import functools import importlib.util +import json from enum import Enum from inspect import iscoroutinefunction from typing import Any @@ -117,3 +118,30 @@ def maybe_get_current_trace_id() -> str | None: # Return None if OpenTelemetry is not available or trace ID is invalid return None + + +def prepare_credentials( + credentials: ydb.Credentials | dict | str | None, +) -> ydb.Credentials | None: + if not credentials: + return None + + if isinstance(credentials, ydb.Credentials): + return credentials + + if isinstance(credentials, str): + credentials = json.loads(credentials) + + if isinstance(credentials, dict): + credentials = credentials or {} + token = credentials.get("token") + if token: + return ydb.AccessTokenCredentials(token) + + service_account_json = credentials.get("service_account_json") + if service_account_json: + return ydb.iam.ServiceAccountCredentials.from_content( + json.dumps(service_account_json) + ) + + return ydb.AnonymousCredentials()