From ab4b269ff8471372ed0ea9a241e5d225495fca53 Mon Sep 17 00:00:00 2001 From: Sathiish Kumar Date: Mon, 19 Dec 2022 14:52:06 -0800 Subject: [PATCH 01/27] Change RedshiftConnectionManager to extend from SQLConnectionManager, define a _get_connect_method method to leverage Redshift python connector to retrieve the connect method --- dbt/adapters/redshift/connections.py | 199 +++++++++++++++++---------- dbt/adapters/redshift/impl.py | 3 +- setup.py | 1 + 3 files changed, 129 insertions(+), 74 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index be4d626d3..bcca18ddf 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -1,20 +1,23 @@ +import os from multiprocessing import Lock from contextlib import contextmanager -from typing import NewType +from typing import NewType, Any -from dbt.adapters.postgres import PostgresConnectionManager -from dbt.adapters.postgres import PostgresCredentials +from dbt.adapters.sql import SQLConnectionManager +from dbt.contracts.connection import AdapterResponse, Connection, Credentials from dbt.events import AdapterLogger import dbt.exceptions import dbt.flags - -import boto3 - +import redshift_connector +from dbt.exceptions import RuntimeException from dbt.dataclass_schema import FieldEncoder, dbtClassMixin, StrEnum from dataclasses import dataclass, field from typing import Optional, List +from dbt.helper_types import Port +from redshift_connector import OperationalError, DatabaseError, DataError + logger = AdapterLogger("Redshift") drop_lock: Lock = dbt.flags.MP_CONTEXT.Lock() # type: ignore @@ -38,7 +41,10 @@ class RedshiftConnectionMethod(StrEnum): @dataclass -class RedshiftCredentials(PostgresCredentials): +class RedshiftCredentials(Credentials): + host: str + user: str + port: Port method: str = RedshiftConnectionMethod.DATABASE # type: ignore password: Optional[str] = None # type: ignore cluster_id: Optional[str] = field( @@ -52,6 +58,14 @@ class RedshiftCredentials(PostgresCredentials): autocreate: bool = False db_groups: List[str] = field(default_factory=list) ra3_node: Optional[bool] = False + connect_timeout: int = 10 + role: Optional[str] = None + sslmode: Optional[str] = None + sslcert: Optional[str] = None + sslkey: Optional[str] = None + sslrootcert: Optional[str] = None + application_name: Optional[str] = "dbt" + retries: int = 1 @property def type(self): @@ -61,10 +75,59 @@ def _connection_keys(self): keys = super()._connection_keys() return keys + ("method", "cluster_id", "iam_profile", "iam_duration_seconds") + @property + def unique_field(self) -> str: + return self.host + -class RedshiftConnectionManager(PostgresConnectionManager): +class RedshiftConnectionManager(SQLConnectionManager): TYPE = "redshift" + def _get_backend_pid(self): + sql = "select pg_backend_pid()" + _, cursor = self.add_query(sql) + res = cursor.fetchone() + return res + + def cancel(self, connection: Connection): + connection_name = connection.name + try: + pid = self._get_backend_pid() + sql = "select pg_terminate_backend({})".format(pid) + _, cursor = self.add_query(sql) + res = cursor.fetchone() + logger.debug("Cancel query '{}': {}".format(connection_name, res)) + except redshift_connector.error.InterfaceError as e: + if "is closed" in str(e): + logger.debug(f"Connection {connection_name} was already closed") + return + raise + + @classmethod + def get_response(cls, cursor: Any) -> AdapterResponse: + message = str(cursor.statusmessage) + rows = cursor.rowcount + 
status_message_parts = message.split() if message is not None else [] + status_message_strings = [part for part in status_message_parts if not part.isdigit()] + code = " ".join(status_message_strings) + return AdapterResponse(_message=message, code=code, rows_affected=rows) + + @contextmanager + def exception_handler(self, sql): + try: + yield + except redshift_connector.error.Error as e: + logger.debug(f"Redshift error: {str(e)}") + self.rollback_if_open() + except Exception as e: + logger.debug("Error running SQL: {}", sql) + logger.debug("Rolling back transaction.") + self.rollback_if_open() + # Raise DBT native exceptions as is. + if isinstance(e, dbt.exceptions.Exception): + raise + raise RuntimeException(str(e)) from e + @contextmanager def fresh_transaction(self, name=None): """On entrance to this context manager, hold an exclusive lock and @@ -88,68 +151,9 @@ def fresh_transaction(self, name=None): self.commit() self.begin() - @classmethod - def fetch_cluster_credentials( - cls, db_user, db_name, cluster_id, iam_profile, duration_s, autocreate, db_groups - ): - """Fetches temporary login credentials from AWS. The specified user - must already exist in the database, or else an error will occur""" - - if iam_profile is None: - session = boto3.Session() - boto_client = session.client("redshift") - else: - logger.debug("Connecting to Redshift using 'IAM'" + f"with profile {iam_profile}") - boto_session = boto3.Session(profile_name=iam_profile) - boto_client = boto_session.client("redshift") - - try: - return boto_client.get_cluster_credentials( - DbUser=db_user, - DbName=db_name, - ClusterIdentifier=cluster_id, - DurationSeconds=duration_s, - AutoCreate=autocreate, - DbGroups=db_groups, - ) - - except boto_client.exceptions.ClientError as e: - raise dbt.exceptions.FailedToConnectException( - "Unable to get temporary Redshift cluster credentials: {}".format(e) - ) - - @classmethod - def get_tmp_iam_cluster_credentials(cls, credentials): - cluster_id = credentials.cluster_id - - # default via: - # boto3.readthedocs.io/en/latest/reference/services/redshift.html - iam_duration_s = credentials.iam_duration_seconds - - if not cluster_id: - raise dbt.exceptions.FailedToConnectException( - "'cluster_id' must be provided in profile if IAM " "authentication method selected" - ) - - cluster_creds = cls.fetch_cluster_credentials( - credentials.user, - credentials.database, - credentials.cluster_id, - credentials.iam_profile, - iam_duration_s, - credentials.autocreate, - credentials.db_groups, - ) - - # replace username and password with temporary redshift credentials - return credentials.replace( - user=cluster_creds.get("DbUser"), password=cluster_creds.get("DbPassword") - ) - - @classmethod - def get_credentials(cls, credentials): + @staticmethod + def _get_connect_method(credentials): method = credentials.method - # Support missing 'method' for backwards compatibility if method == "database" or method is None: logger.debug("Connecting to Redshift using 'database' credentials") @@ -159,13 +163,64 @@ def get_credentials(cls, credentials): raise dbt.exceptions.FailedToConnectException( "'password' field is required for 'database' credentials" ) - return credentials + + def connect(): + c = redshift_connector.connect( + host=credentials.host, + database=credentials.database, + user=credentials.user, + password=credentials.password, + port=credentials.port if credentials.port else 5439, + ) + if credentials.role: + c.cursor().execute("set role {}".format(credentials.role)) + return c + + return 
connect elif method == "iam": - logger.debug("Connecting to Redshift using 'IAM' credentials") - return cls.get_tmp_iam_cluster_credentials(credentials) + def connect(): + c = redshift_connector.connect( + iam=True, + database=credentials.database, + db_user=credentials.user, + password="", + user="", + cluster_identifier=credentials.cluster_id, + access_key_id=os.environ["AWS_ACCESS_KEY_ID"], + secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], + session_token=os.environ["AWS_SESSION_TOKEN"], + region=credentials.host.split(".")[2], + ) + if credentials.role: + c.cursor().execute("set role {}".format(credentials.role)) + return c + + return connect else: raise dbt.exceptions.FailedToConnectException( "Invalid 'method' in profile: '{}'".format(method) ) + + @classmethod + def open(cls, connection): + if connection.state == "open": + logger.debug("Connection is already open, skipping open.") + return connection + + credentials = connection.credentials + + def exponential_backoff(attempt: int): + return attempt * attempt + + retryable_exceptions = [OperationalError, DatabaseError, DataError] + + return cls.retry_connection( + connection, + connect=cls._get_connect_method(credentials), + logger=logger, + retry_limit=credentials.retries, + retry_timeout=exponential_backoff, + retryable_exceptions=retryable_exceptions, + ) diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index 206185f57..7d069938c 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -3,7 +3,6 @@ from dbt.adapters.base.impl import AdapterConfig from dbt.adapters.sql import SQLAdapter from dbt.adapters.base.meta import available -from dbt.adapters.postgres import PostgresAdapter from dbt.adapters.redshift import RedshiftConnectionManager from dbt.adapters.redshift import RedshiftColumn from dbt.adapters.redshift import RedshiftRelation @@ -22,7 +21,7 @@ class RedshiftConfig(AdapterConfig): backup: Optional[bool] = True -class RedshiftAdapter(PostgresAdapter, SQLAdapter): +class RedshiftAdapter(SQLAdapter): Relation = RedshiftRelation ConnectionManager = RedshiftConnectionManager Column = RedshiftColumn # type: ignore diff --git a/setup.py b/setup.py index 1049527de..a78a1a1e1 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,7 @@ def _get_dbt_core_version(): "dbt-postgres~={}".format(dbt_core_version), # the following are all to match snowflake-connector-python "boto3>=1.4.4,<2.0.0", + "redshift-connector", ], zip_safe=False, classifiers=[ From ff9fdfd73373c2e8295cfb36feda4f41cf9b9b82 Mon Sep 17 00:00:00 2001 From: Sathiish Kumar Date: Thu, 29 Dec 2022 16:53:37 -0800 Subject: [PATCH 02/27] Add/fix unit tests, create RedshiftConnectMethodFactory to vend connect_method --- dbt/adapters/redshift/connections.py | 209 +++++++++++++++------- tests/unit/test_redshift_adapter.py | 254 +++++++++++++-------------- 2 files changed, 266 insertions(+), 197 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index bcca18ddf..e657a3750 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -1,8 +1,9 @@ import os from multiprocessing import Lock from contextlib import contextmanager -from typing import NewType, Any +from typing import NewType +import boto3 from dbt.adapters.sql import SQLConnectionManager from dbt.contracts.connection import AdapterResponse, Connection, Credentials from dbt.events import AdapterLogger @@ -53,8 +54,8 @@ class RedshiftCredentials(Credentials): ) iam_profile: 
Optional[str] = None iam_duration_seconds: int = 900 - search_path: Optional[str] = None - keepalives_idle: int = 4 + search_path: Optional[str] = None # TODO: Not supported in redshift python connector + keepalives_idle: int = 4 # TODO: Not supported in redshift python connector autocreate: bool = False db_groups: List[str] = field(default_factory=list) ra3_node: Optional[bool] = False @@ -65,21 +66,153 @@ class RedshiftCredentials(Credentials): sslkey: Optional[str] = None sslrootcert: Optional[str] = None application_name: Optional[str] = "dbt" - retries: int = 1 + retries: int = 0 # this is in-built into redshift python connector + + _ALIASES = {"dbname": "database", "pass": "password"} @property def type(self): return "redshift" def _connection_keys(self): - keys = super()._connection_keys() - return keys + ("method", "cluster_id", "iam_profile", "iam_duration_seconds") + return "method", "cluster_id", "iam_profile", "iam_duration_seconds" @property def unique_field(self) -> str: return self.host +class RedshiftConnectMethodFactory: + credentials: RedshiftCredentials + + def __init__(self, credentials): + self.credentials = credentials + + def get_connect_method(self): + method = self.credentials.method + # Support missing 'method' for backwards compatibility + if method == RedshiftConnectionMethod.DATABASE or method is None: + logger.debug("Connecting to Redshift using 'database' credentials") + # this requirement is really annoying to encode into json schema, + # so validate it here + if self.credentials.password is None: + raise dbt.exceptions.FailedToConnectException( + "'password' field is required for 'database' credentials" + ) + + def connect(): + c = redshift_connector.connect( + host=self.credentials.host, + database=self.credentials.database, + user=self.credentials.user, + password=self.credentials.password, + port=self.credentials.port if self.credentials.port else 5439, + auto_create=self.credentials.autocreate, + db_groups=self.credentials.db_groups, + ) + if self.credentials.role: + c.cursor().execute("set role {}".format(self.credentials.role)) + return c + + return connect + + elif method == RedshiftConnectionMethod.IAM: + if not self.credentials.cluster_id: + raise dbt.exceptions.FailedToConnectException( + "Failed to use IAM method, 'cluster_id' must be provided" + ) + + if self.credentials.iam_profile is None: + return self._get_iam_connect_method_from_env_vars() + else: + return self._get_iam_connect_method_with_tmp_cluster_credentials() + else: + raise dbt.exceptions.FailedToConnectException( + "Invalid 'method' in profile: '{}'".format(method) + ) + + def _get_iam_connect_method_from_env_vars(self): + aws_credentials_env_vars = [ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + ] + + def check_if_env_vars_empty(var): + return os.environ.get(var, "") == "" + + empty_env_vars = list(filter(check_if_env_vars_empty, aws_credentials_env_vars)) + if len(empty_env_vars) > 0: + raise dbt.exceptions.FailedToConnectException( + "Failed to specify {} as environment variable(s) in shell".format(empty_env_vars) + ) + + def connect(): + c = redshift_connector.connect( + iam=True, + database=self.credentials.database, + db_user=self.credentials.user, + password="", + user="", + cluster_identifier=self.credentials.cluster_id, + access_key_id=os.environ["AWS_ACCESS_KEY_ID"], + secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], + session_token=os.environ["AWS_SESSION_TOKEN"], + region=self.credentials.host.split(".")[2], + 
auto_create=self.credentials.autocreate, + db_groups=self.credentials.db_groups, + ) + if self.credentials.role: + c.cursor().execute("set role {}".format(self.credentials.role)) + return c + + return connect + + def _get_iam_connect_method_with_tmp_cluster_credentials(self): + tmp_user, tmp_password = self._get_tmp_iam_cluster_credentials() + + def connect(): + c = redshift_connector.connect( + iam=True, + database=self.credentials.database, + db_user=self.credentials.user, + password=tmp_password, + user=tmp_user, + cluster_identifier=self.credentials.cluster_id, + region=self.credentials.host.split(".")[2], + auto_create=self.credentials.autocreate, + db_groups=self.credentials.db_groups, + ) + if self.credentials.role: + c.cursor().execute("set role {}".format(self.credentials.role)) + return c + + return connect + + def _get_tmp_iam_cluster_credentials(self): + """Fetches temporary login credentials from AWS. The specified user + must already exist in the database, or else an error will occur""" + iam_profile = self.credentials.iam_profile + logger.debug("Connecting to Redshift using 'IAM'" + f"with profile {iam_profile}") + boto_session = boto3.Session(profile_name=iam_profile) + boto_client = boto_session.client("redshift") + + try: + cluster_creds = boto_client.get_cluster_credentials( + DbUser=self.credentials.user, + DbName=self.credentials.database, + ClusterIdentifier=self.credentials.cluster_id, + DurationSeconds=self.credentials.iam_duration_seconds, + AutoCreate=self.credentials.autocreate, + DbGroups=self.credentials.db_groups, + ) + return cluster_creds.get("DbUser"), cluster_creds.get("DbPassword") + except boto_client.exceptions.ClientError as e: + raise dbt.exceptions.FailedToConnectException( + "Unable to get temporary Redshift cluster credentials: {}".format(e) + ) + + class RedshiftConnectionManager(SQLConnectionManager): TYPE = "redshift" @@ -104,13 +237,10 @@ def cancel(self, connection: Connection): raise @classmethod - def get_response(cls, cursor: Any) -> AdapterResponse: - message = str(cursor.statusmessage) + def get_response(cls, cursor: redshift_connector.Cursor) -> AdapterResponse: rows = cursor.rowcount - status_message_parts = message.split() if message is not None else [] - status_message_strings = [part for part in status_message_parts if not part.isdigit()] - code = " ".join(status_message_strings) - return AdapterResponse(_message=message, code=code, rows_affected=rows) + message = f"{rows} cursor.rowcount" + return AdapterResponse(_message=message, rows_affected=rows) @contextmanager def exception_handler(self, sql): @@ -151,58 +281,6 @@ def fresh_transaction(self, name=None): self.commit() self.begin() - @staticmethod - def _get_connect_method(credentials): - method = credentials.method - # Support missing 'method' for backwards compatibility - if method == "database" or method is None: - logger.debug("Connecting to Redshift using 'database' credentials") - # this requirement is really annoying to encode into json schema, - # so validate it here - if credentials.password is None: - raise dbt.exceptions.FailedToConnectException( - "'password' field is required for 'database' credentials" - ) - - def connect(): - c = redshift_connector.connect( - host=credentials.host, - database=credentials.database, - user=credentials.user, - password=credentials.password, - port=credentials.port if credentials.port else 5439, - ) - if credentials.role: - c.cursor().execute("set role {}".format(credentials.role)) - return c - - return connect - - elif method == 
"iam": - - def connect(): - c = redshift_connector.connect( - iam=True, - database=credentials.database, - db_user=credentials.user, - password="", - user="", - cluster_identifier=credentials.cluster_id, - access_key_id=os.environ["AWS_ACCESS_KEY_ID"], - secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], - session_token=os.environ["AWS_SESSION_TOKEN"], - region=credentials.host.split(".")[2], - ) - if credentials.role: - c.cursor().execute("set role {}".format(credentials.role)) - return c - - return connect - else: - raise dbt.exceptions.FailedToConnectException( - "Invalid 'method' in profile: '{}'".format(method) - ) - @classmethod def open(cls, connection): if connection.state == "open": @@ -210,6 +288,7 @@ def open(cls, connection): return connection credentials = connection.credentials + connect_method_factory = RedshiftConnectMethodFactory(credentials) def exponential_backoff(attempt: int): return attempt * attempt @@ -218,7 +297,7 @@ def exponential_backoff(attempt: int): return cls.retry_connection( connection, - connect=cls._get_connect_method(credentials), + connect=connect_method_factory.get_connect_method(), logger=logger, retry_limit=credentials.retries, retry_timeout=exponential_backoff, diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 1a21e4d34..b91ce90ba 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -1,9 +1,12 @@ +import os import unittest from unittest import mock -from unittest.mock import Mock +from unittest.mock import Mock, call, ANY import agate import boto3 +import dbt.exceptions +import redshift_connector from dbt.adapters.redshift import ( RedshiftAdapter, @@ -12,17 +15,10 @@ from dbt.clients import agate_helper from dbt.exceptions import FailedToConnectException +from dbt.adapters.redshift.connections import RedshiftConnectMethodFactory from .utils import config_from_parts_or_dicts, mock_connection, TestAdapterConversions, inject_adapter -@classmethod -def fetch_cluster_credentials(*args, **kwargs): - return { - 'DbUser': 'root', - 'DbPassword': 'tmp_password' - } - - class TestRedshiftAdapter(unittest.TestCase): def setUp(self): @@ -63,28 +59,100 @@ def adapter(self): inject_adapter(self._adapter, RedshiftPlugin) return self._adapter + @mock.patch("redshift_connector.connect", Mock()) def test_implicit_database_conn(self): - creds = RedshiftAdapter.ConnectionManager.get_credentials(self.config.credentials) - self.assertEqual(creds, self.config.credentials) + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + host='thishostshouldnotexist', + database='redshift', + user='root', + password='password', + port=5439, + auto_create=False, + db_groups=[] + ) + @mock.patch("redshift_connector.connect", Mock()) def test_explicit_database_conn(self): self.config.method = 'database' - creds = RedshiftAdapter.ConnectionManager.get_credentials(self.config.credentials) - self.assertEqual(creds, self.config.credentials) + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + host='thishostshouldnotexist', + database='redshift', + user='root', + password='password', + port=5439, + auto_create=False, + db_groups=[] + ) + + @mock.patch("redshift_connector.connect", Mock()) + @mock.patch.dict(os.environ, {"AWS_ACCESS_KEY_ID": "Test", "AWS_SECRET_ACCESS_KEY": "Test", + "AWS_SESSION_TOKEN": "Test"}) + def 
test_explicit_iam_conn_with_env_vars(self): + self.config.credentials = self.config.credentials.replace( + method='iam', + cluster_id='my_redshift', + iam_duration_seconds=1200, + host='thishostshouldnotexist.test.us-east-1' + ) + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + iam=True, + database='redshift', + db_user='root', + password='', + user='', + cluster_identifier='my_redshift', + access_key_id='Test', + secret_access_key='Test', + session_token='Test', + region='us-east-1', + auto_create=False, + db_groups=[] + ) - def test_explicit_iam_conn(self): + @mock.patch('redshift_connector.connect', Mock()) + @mock.patch('boto3.Session', Mock()) + def test_explicit_iam_conn_with_tmp_cluster_credentials(self): self.config.credentials = self.config.credentials.replace( method='iam', cluster_id='my_redshift', - iam_duration_seconds=1200 + iam_duration_seconds=1200, + iam_profile='test', + host='thishostshouldnotexist.test.us-east-1' ) + connection = self.adapter.acquire_connection("dummy") + connection.handle - with mock.patch.object(RedshiftAdapter.ConnectionManager, 'fetch_cluster_credentials', new=fetch_cluster_credentials): - creds = RedshiftAdapter.ConnectionManager.get_credentials(self.config.credentials) + redshift_connector.connect.assert_called_once_with( + iam=True, + database='redshift', + password=ANY, + user=ANY, + cluster_identifier='my_redshift', + region='us-east-1', + auto_create=False, + db_groups=[], + db_user='root' + ) - expected_creds = self.config.credentials.replace(password='tmp_password') - self.assertEqual(creds, expected_creds) + @mock.patch("redshift_connector.connect", Mock()) + def test_explicit_iam_conn_error_when_environment_vars_not_specified(self): + self.config.credentials = self.config.credentials.replace( + method='iam', + cluster_id='my_redshift', + iam_duration_seconds=1200, + host='thishostshouldnotexist.test.us-east-1' + ) + connection = self.adapter.acquire_connection("dummy") + with self.assertRaises(dbt.exceptions.FailedToConnectException) as context: + connection.handle + self.assertTrue("environment variable(s)" in context.exception.msg) def test_iam_conn_optionals(self): @@ -108,41 +176,49 @@ def test_iam_conn_optionals(self): config_from_parts_or_dicts(self.config, profile_cfg) + def test_default_session_is_not_used_when_iam_used(self): + boto3.DEFAULT_SESSION = Mock() + self.config.credentials = self.config.credentials.replace(method='iam') + self.config.credentials.cluster_id = 'clusterid' + self.config.credentials.iam_profile = 'test' + with mock.patch('dbt.adapters.redshift.connections.boto3.Session'): + connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) + connect_method_factory.get_connect_method() + self.assertEqual( + boto3.DEFAULT_SESSION.client.call_count, + 0, + "The redshift client should not be created using " + "the default session because the session object is not thread-safe" + ) + + def test_default_session_is_not_used_when_iam_not_used(self): + boto3.DEFAULT_SESSION = Mock() + self.config.credentials = self.config.credentials.replace(method=None) + with mock.patch('dbt.adapters.redshift.connections.boto3.Session'): + connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) + connect_method_factory.get_connect_method() + self.assertEqual( + boto3.DEFAULT_SESSION.client.call_count, 0, + "The redshift client should not be created using " + "the default session because the session object is not 
thread-safe" + ) + def test_invalid_auth_method(self): # we have to set method this way, otherwise it won't validate self.config.credentials.method = 'badmethod' - with self.assertRaises(FailedToConnectException) as context: - with mock.patch.object(RedshiftAdapter.ConnectionManager, 'fetch_cluster_credentials', new=fetch_cluster_credentials): - RedshiftAdapter.ConnectionManager.get_credentials(self.config.credentials) - + connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) + connect_method_factory.get_connect_method() self.assertTrue('badmethod' in context.exception.msg) def test_invalid_iam_no_cluster_id(self): self.config.credentials = self.config.credentials.replace(method='iam') with self.assertRaises(FailedToConnectException) as context: - with mock.patch.object(RedshiftAdapter.ConnectionManager, 'fetch_cluster_credentials', new=fetch_cluster_credentials): - RedshiftAdapter.ConnectionManager.get_credentials(self.config.credentials) + connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) + connect_method_factory.get_connect_method() self.assertTrue("'cluster_id' must be provided" in context.exception.msg) - def test_default_session_is_not_used_when_iam_used(self): - boto3.DEFAULT_SESSION = Mock() - self.config.credentials = self.config.credentials.replace(method='iam') - self.config.credentials.cluster_id = 'clusterid' - with mock.patch('dbt.adapters.redshift.connections.boto3.Session'): - RedshiftAdapter.ConnectionManager.get_credentials(self.config.credentials) - self.assertEqual(boto3.DEFAULT_SESSION.client.call_count, 0, - "The redshift client should not be created using the default session because the session object is not thread-safe") - - def test_default_session_is_not_used_when_iam_not_used(self): - boto3.DEFAULT_SESSION = Mock() - self.config.credentials = self.config.credentials.replace(method=None) - with mock.patch('dbt.adapters.redshift.connections.boto3.Session'): - RedshiftAdapter.ConnectionManager.get_credentials(self.config.credentials) - self.assertEqual(boto3.DEFAULT_SESSION.client.call_count, 0, - "The redshift client should not be created using the default session because the session object is not thread-safe") - def test_cancel_open_connections_empty(self): self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0) @@ -154,7 +230,6 @@ def test_cancel_open_connections_master(self): def test_cancel_open_connections_single(self): master = mock_connection('master') model = mock_connection('model') - model.handle.get_backend_pid.return_value = 42 key = self.adapter.connections.get_thread_identifier() self.adapter.connections.thread_connections.update({ @@ -163,100 +238,15 @@ def test_cancel_open_connections_single(self): }) with mock.patch.object(self.adapter.connections, 'add_query') as add_query: query_result = mock.MagicMock() - add_query.return_value = (None, query_result) + cursor = mock.Mock() + cursor.fetchone.return_value = 42 + add_query.side_effect = [(None, cursor), (None, query_result)] self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1) - - add_query.assert_called_once_with('select pg_terminate_backend(42)') + add_query.assert_has_calls([call('select pg_backend_pid()'), call('select pg_terminate_backend(42)')]) master.handle.get_backend_pid.assert_not_called() - @mock.patch('dbt.adapters.postgres.connections.psycopg2') - def test_default_keepalive(self, psycopg2): - connection = self.adapter.acquire_connection('dummy') - - psycopg2.connect.assert_not_called() - 
connection.handle - psycopg2.connect.assert_called_once_with( - dbname='redshift', - user='root', - host='thishostshouldnotexist', - password='password', - port=5439, - connect_timeout=10, - keepalives_idle=4, - application_name='dbt' - ) - - @mock.patch('dbt.adapters.postgres.connections.psycopg2') - def test_changed_keepalive(self, psycopg2): - self.config.credentials = self.config.credentials.replace(keepalives_idle=5) - connection = self.adapter.acquire_connection('dummy') - - psycopg2.connect.assert_not_called() - connection.handle - psycopg2.connect.assert_called_once_with( - dbname='redshift', - user='root', - host='thishostshouldnotexist', - password='password', - port=5439, - connect_timeout=10, - keepalives_idle=5, - application_name='dbt') - - @mock.patch('dbt.adapters.postgres.connections.psycopg2') - def test_search_path(self, psycopg2): - self.config.credentials = self.config.credentials.replace(search_path="test") - connection = self.adapter.acquire_connection('dummy') - - psycopg2.connect.assert_not_called() - connection.handle - psycopg2.connect.assert_called_once_with( - dbname='redshift', - user='root', - host='thishostshouldnotexist', - password='password', - port=5439, - connect_timeout=10, - options="-c search_path=test", - keepalives_idle=4, - application_name='dbt') - - @mock.patch('dbt.adapters.postgres.connections.psycopg2') - def test_search_path_with_space(self, psycopg2): - self.config.credentials = self.config.credentials.replace(search_path="test test") - connection = self.adapter.acquire_connection('dummy') - - psycopg2.connect.assert_not_called() - connection.handle - psycopg2.connect.assert_called_once_with( - dbname='redshift', - user='root', - host='thishostshouldnotexist', - password='password', - port=5439, - connect_timeout=10, - options=r"-c search_path=test\ test", - keepalives_idle=4, - application_name='dbt') - - @mock.patch('dbt.adapters.postgres.connections.psycopg2') - def test_set_zero_keepalive(self, psycopg2): - self.config.credentials = self.config.credentials.replace(keepalives_idle=0) - connection = self.adapter.acquire_connection('dummy') - - psycopg2.connect.assert_not_called() - connection.handle - psycopg2.connect.assert_called_once_with( - dbname='redshift', - user='root', - host='thishostshouldnotexist', - password='password', - port=5439, - connect_timeout=10, - application_name='dbt') - def test_dbname_verification_is_case_insensitive(self): # Override adapter settings from setUp() profile_cfg = { From fbd5731554b0c97ce5150151ba6d3749c95efd67 Mon Sep 17 00:00:00 2001 From: Sathiish Kumar Date: Tue, 3 Jan 2023 15:19:38 -0800 Subject: [PATCH 03/27] Fix _connection_keys to mimic PostgresConnectionManager --- dbt/adapters/redshift/connections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index e657a3750..3a38b6e0f 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -75,7 +75,7 @@ def type(self): return "redshift" def _connection_keys(self): - return "method", "cluster_id", "iam_profile", "iam_duration_seconds" + return "host", "port", "user", "database", "schema" @property def unique_field(self) -> str: From 4f98546e702564e1b7d908a1dec4c2f757aebba1 Mon Sep 17 00:00:00 2001 From: Sathiish Kumar Date: Mon, 9 Jan 2023 11:48:00 -0800 Subject: [PATCH 04/27] Remove unneeded functions for tmp_cluster_creds and env_var creds auth due to in-built support in Redshift Python Connector --- 
dbt/adapters/redshift/connections.py | 158 ++++++++++----------------- tests/unit/test_redshift_adapter.py | 60 +++++----- 2 files changed, 83 insertions(+), 135 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 3a38b6e0f..2ad871b1e 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -1,9 +1,7 @@ -import os from multiprocessing import Lock from contextlib import contextmanager from typing import NewType -import boto3 from dbt.adapters.sql import SQLConnectionManager from dbt.contracts.connection import AdapterResponse, Connection, Credentials from dbt.events import AdapterLogger @@ -53,20 +51,36 @@ class RedshiftCredentials(Credentials): metadata={"description": "If using IAM auth, the name of the cluster"}, ) iam_profile: Optional[str] = None - iam_duration_seconds: int = 900 - search_path: Optional[str] = None # TODO: Not supported in redshift python connector - keepalives_idle: int = 4 # TODO: Not supported in redshift python connector + iam_duration_seconds: Optional[ + int + ] = 900 # TODO:Not supported in redshift python connector, should we remove this? + search_path: Optional[ + str + ] = None # TODO:Not supported in redshift python connector, should we remove this? + keepalives_idle: int = ( + 4 # TODO:Not supported in redshift python connector, should we remove this? + ) autocreate: bool = False db_groups: List[str] = field(default_factory=list) - ra3_node: Optional[bool] = False + ra3_node: Optional[ + bool + ] = False # TODO:Need to confirm if there is special handling needed for this option connect_timeout: int = 10 role: Optional[str] = None sslmode: Optional[str] = None - sslcert: Optional[str] = None - sslkey: Optional[str] = None - sslrootcert: Optional[str] = None + sslcert: Optional[ + str + ] = None # TODO:Not supported by redshift python connector, should we remove this? + sslkey: Optional[ + str + ] = None # TODO:Not supported by redshift python connector, should we remove this? + sslrootcert: Optional[ + str + ] = None # TODO:Not supported by redshift python connector, should we remove this? application_name: Optional[str] = "dbt" - retries: int = 0 # this is in-built into redshift python connector + retries: int = ( + 0 # TODO:Retries are done by redshift python connector natively, is this required? 
+ ) _ALIASES = {"dbname": "database", "pass": "password"} @@ -90,9 +104,21 @@ def __init__(self, credentials): def get_connect_method(self): method = self.credentials.method + kwargs = { + "host": self.credentials.host, + "database": self.credentials.database, + "port": self.credentials.port if self.credentials.port else 5439, + "auto_create": self.credentials.autocreate, + "db_groups": self.credentials.db_groups, + "region": self.credentials.host.split(".")[2], + "application_name": self.credentials.application_name, + "timeout": self.credentials.connect_timeout, + } + if self.credentials.sslmode: + kwargs["sslmode"] = self.credentials.sslmode + # Support missing 'method' for backwards compatibility if method == RedshiftConnectionMethod.DATABASE or method is None: - logger.debug("Connecting to Redshift using 'database' credentials") # this requirement is really annoying to encode into json schema, # so validate it here if self.credentials.password is None: @@ -101,14 +127,9 @@ def get_connect_method(self): ) def connect(): + logger.debug("Connecting to redshift with username/password based auth...") c = redshift_connector.connect( - host=self.credentials.host, - database=self.credentials.database, - user=self.credentials.user, - password=self.credentials.password, - port=self.credentials.port if self.credentials.port else 5439, - auto_create=self.credentials.autocreate, - db_groups=self.credentials.db_groups, + user=self.credentials.user, password=self.credentials.password, **kwargs ) if self.credentials.role: c.cursor().execute("set role {}".format(self.credentials.role)) @@ -122,96 +143,27 @@ def connect(): "Failed to use IAM method, 'cluster_id' must be provided" ) - if self.credentials.iam_profile is None: - return self._get_iam_connect_method_from_env_vars() - else: - return self._get_iam_connect_method_with_tmp_cluster_credentials() + def connect(): + logger.debug("Connecting to redshift with IAM based auth...") + c = redshift_connector.connect( + iam=True, + db_user=self.credentials.user, + password="", + user="", + cluster_identifier=self.credentials.cluster_id, + profile=self.credentials.iam_profile, + **kwargs, + ) + if self.credentials.role: + c.cursor().execute("set role {}".format(self.credentials.role)) + return c + + return connect else: raise dbt.exceptions.FailedToConnectException( "Invalid 'method' in profile: '{}'".format(method) ) - def _get_iam_connect_method_from_env_vars(self): - aws_credentials_env_vars = [ - "AWS_ACCESS_KEY_ID", - "AWS_SECRET_ACCESS_KEY", - "AWS_SESSION_TOKEN", - ] - - def check_if_env_vars_empty(var): - return os.environ.get(var, "") == "" - - empty_env_vars = list(filter(check_if_env_vars_empty, aws_credentials_env_vars)) - if len(empty_env_vars) > 0: - raise dbt.exceptions.FailedToConnectException( - "Failed to specify {} as environment variable(s) in shell".format(empty_env_vars) - ) - - def connect(): - c = redshift_connector.connect( - iam=True, - database=self.credentials.database, - db_user=self.credentials.user, - password="", - user="", - cluster_identifier=self.credentials.cluster_id, - access_key_id=os.environ["AWS_ACCESS_KEY_ID"], - secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], - session_token=os.environ["AWS_SESSION_TOKEN"], - region=self.credentials.host.split(".")[2], - auto_create=self.credentials.autocreate, - db_groups=self.credentials.db_groups, - ) - if self.credentials.role: - c.cursor().execute("set role {}".format(self.credentials.role)) - return c - - return connect - - def 
_get_iam_connect_method_with_tmp_cluster_credentials(self): - tmp_user, tmp_password = self._get_tmp_iam_cluster_credentials() - - def connect(): - c = redshift_connector.connect( - iam=True, - database=self.credentials.database, - db_user=self.credentials.user, - password=tmp_password, - user=tmp_user, - cluster_identifier=self.credentials.cluster_id, - region=self.credentials.host.split(".")[2], - auto_create=self.credentials.autocreate, - db_groups=self.credentials.db_groups, - ) - if self.credentials.role: - c.cursor().execute("set role {}".format(self.credentials.role)) - return c - - return connect - - def _get_tmp_iam_cluster_credentials(self): - """Fetches temporary login credentials from AWS. The specified user - must already exist in the database, or else an error will occur""" - iam_profile = self.credentials.iam_profile - logger.debug("Connecting to Redshift using 'IAM'" + f"with profile {iam_profile}") - boto_session = boto3.Session(profile_name=iam_profile) - boto_client = boto_session.client("redshift") - - try: - cluster_creds = boto_client.get_cluster_credentials( - DbUser=self.credentials.user, - DbName=self.credentials.database, - ClusterIdentifier=self.credentials.cluster_id, - DurationSeconds=self.credentials.iam_duration_seconds, - AutoCreate=self.credentials.autocreate, - DbGroups=self.credentials.db_groups, - ) - return cluster_creds.get("DbUser"), cluster_creds.get("DbPassword") - except boto_client.exceptions.ClientError as e: - raise dbt.exceptions.FailedToConnectException( - "Unable to get temporary Redshift cluster credentials: {}".format(e) - ) - class RedshiftConnectionManager(SQLConnectionManager): TYPE = "redshift" diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index b91ce90ba..e879dc1cb 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -1,11 +1,9 @@ -import os import unittest from unittest import mock -from unittest.mock import Mock, call, ANY +from unittest.mock import Mock, call import agate import boto3 -import dbt.exceptions import redshift_connector from dbt.adapters.redshift import ( @@ -28,7 +26,7 @@ def setUp(self): 'type': 'redshift', 'dbname': 'redshift', 'user': 'root', - 'host': 'thishostshouldnotexist', + 'host': 'thishostshouldnotexist.test.us-east-1', 'pass': 'password', 'port': 5439, 'schema': 'public' @@ -64,13 +62,16 @@ def test_implicit_database_conn(self): connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( - host='thishostshouldnotexist', + host='thishostshouldnotexist.test.us-east-1', database='redshift', user='root', password='password', port=5439, auto_create=False, - db_groups=[] + db_groups=[], + application_name='dbt', + timeout=10, + region='us-east-1' ) @mock.patch("redshift_connector.connect", Mock()) @@ -80,19 +81,20 @@ def test_explicit_database_conn(self): connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( - host='thishostshouldnotexist', + host='thishostshouldnotexist.test.us-east-1', database='redshift', user='root', password='password', port=5439, auto_create=False, - db_groups=[] + db_groups=[], + region='us-east-1', + application_name='dbt', + timeout=10 ) @mock.patch("redshift_connector.connect", Mock()) - @mock.patch.dict(os.environ, {"AWS_ACCESS_KEY_ID": "Test", "AWS_SECRET_ACCESS_KEY": "Test", - "AWS_SESSION_TOKEN": "Test"}) - def test_explicit_iam_conn_with_env_vars(self): + def 
test_explicit_iam_conn_without_profile(self): self.config.credentials = self.config.credentials.replace( method='iam', cluster_id='my_redshift', @@ -103,22 +105,24 @@ def test_explicit_iam_conn_with_env_vars(self): connection.handle redshift_connector.connect.assert_called_once_with( iam=True, + host='thishostshouldnotexist.test.us-east-1', database='redshift', db_user='root', password='', user='', cluster_identifier='my_redshift', - access_key_id='Test', - secret_access_key='Test', - session_token='Test', region='us-east-1', auto_create=False, - db_groups=[] + db_groups=[], + profile=None, + application_name='dbt', + timeout=10, + port=5439 ) @mock.patch('redshift_connector.connect', Mock()) @mock.patch('boto3.Session', Mock()) - def test_explicit_iam_conn_with_tmp_cluster_credentials(self): + def test_explicit_iam_conn_with_profile(self): self.config.credentials = self.config.credentials.replace( method='iam', cluster_id='my_redshift', @@ -131,28 +135,20 @@ def test_explicit_iam_conn_with_tmp_cluster_credentials(self): redshift_connector.connect.assert_called_once_with( iam=True, + host='thishostshouldnotexist.test.us-east-1', database='redshift', - password=ANY, - user=ANY, cluster_identifier='my_redshift', region='us-east-1', auto_create=False, db_groups=[], - db_user='root' - ) - - @mock.patch("redshift_connector.connect", Mock()) - def test_explicit_iam_conn_error_when_environment_vars_not_specified(self): - self.config.credentials = self.config.credentials.replace( - method='iam', - cluster_id='my_redshift', - iam_duration_seconds=1200, - host='thishostshouldnotexist.test.us-east-1' + db_user='root', + password='', + user='', + profile='test', + application_name='dbt', + timeout=10, + port=5439 ) - connection = self.adapter.acquire_connection("dummy") - with self.assertRaises(dbt.exceptions.FailedToConnectException) as context: - connection.handle - self.assertTrue("environment variable(s)" in context.exception.msg) def test_iam_conn_optionals(self): From 5319b909f5b707d659d4a30c99e33dd125c29434 Mon Sep 17 00:00:00 2001 From: Sathiish Kumar Date: Tue, 17 Jan 2023 16:17:17 -0800 Subject: [PATCH 05/27] Resolve some TODOs --- dbt/adapters/redshift/connections.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 97229141a..0a9dc85af 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -50,20 +50,9 @@ class RedshiftCredentials(Credentials): metadata={"description": "If using IAM auth, the name of the cluster"}, ) iam_profile: Optional[str] = None - iam_duration_seconds: Optional[ - int - ] = 900 # TODO:Not supported in redshift python connector, should we remove this? - search_path: Optional[ - str - ] = None # TODO:Not supported in redshift python connector, should we remove this? - keepalives_idle: int = ( - 4 # TODO:Not supported in redshift python connector, should we remove this? - ) autocreate: bool = False db_groups: List[str] = field(default_factory=list) - ra3_node: Optional[ - bool - ] = False # TODO:Need to confirm if there is special handling needed for this option + ra3_node: Optional[bool] = False connect_timeout: int = 10 role: Optional[str] = None sslmode: Optional[str] = None @@ -77,9 +66,7 @@ class RedshiftCredentials(Credentials): str ] = None # TODO:Not supported by redshift python connector, should we remove this? 
application_name: Optional[str] = "dbt" - retries: int = ( - 0 # TODO:Retries are done by redshift python connector natively, is this required? - ) + retries: int = 1 _ALIASES = {"dbname": "database", "pass": "password"} @@ -88,7 +75,7 @@ def type(self): return "redshift" def _connection_keys(self): - return "host", "port", "user", "database", "schema" + return "host", "port", "user", "database", "schema", "method", "cluster_id", "iam_profile" @property def unique_field(self) -> str: @@ -207,7 +194,7 @@ def exception_handler(self, sql): # Raise DBT native exceptions as is. if isinstance(e, dbt.exceptions.Exception): raise - raise dbt.exceptions.DbtRuntimeError(str(e)) from e + raise RuntimeError(str(e)) from e @contextmanager def fresh_transaction(self, name=None): From 16666dbf77bf9b8a0a6bec3b6361986978863acb Mon Sep 17 00:00:00 2001 From: Sathiish Kumar Date: Wed, 18 Jan 2023 07:16:36 -0800 Subject: [PATCH 06/27] Fix references to old exceptions, add changelog --- .../Under the Hood-20230118-071542.yaml | 8 +++++ dbt/adapters/redshift/connections.py | 6 ++-- tests/unit/test_redshift_adapter.py | 36 ++----------------- 3 files changed, 14 insertions(+), 36 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20230118-071542.yaml diff --git a/.changes/unreleased/Under the Hood-20230118-071542.yaml b/.changes/unreleased/Under the Hood-20230118-071542.yaml new file mode 100644 index 000000000..afa2f05f6 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20230118-071542.yaml @@ -0,0 +1,8 @@ +kind: Under the Hood +body: Replace psycopg2 connector with Redshift python connector when connecting to + Redshift +time: 2023-01-18T07:15:42.183304-08:00 +custom: + Author: sathiish-kumar + Issue: "219" + PR: "251" diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 0a9dc85af..b8889a4c3 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -108,7 +108,7 @@ def get_connect_method(self): # this requirement is really annoying to encode into json schema, # so validate it here if self.credentials.password is None: - raise dbt.exceptions.FailedToConnectException( + raise dbt.exceptions.FailedToConnectError( "'password' field is required for 'database' credentials" ) @@ -125,7 +125,7 @@ def connect(): elif method == RedshiftConnectionMethod.IAM: if not self.credentials.cluster_id: - raise dbt.exceptions.FailedToConnectException( + raise dbt.exceptions.FailedToConnectError( "Failed to use IAM method, 'cluster_id' must be provided" ) @@ -146,7 +146,7 @@ def connect(): return connect else: - raise dbt.exceptions.FailedToConnectException( + raise dbt.exceptions.FailedToConnectError( "Invalid 'method' in profile: '{}'".format(method) ) diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index fe32d3c76..8f1b9b068 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -98,7 +98,6 @@ def test_explicit_iam_conn_without_profile(self): self.config.credentials = self.config.credentials.replace( method='iam', cluster_id='my_redshift', - iam_duration_seconds=1200, host='thishostshouldnotexist.test.us-east-1' ) connection = self.adapter.acquire_connection("dummy") @@ -126,12 +125,10 @@ def test_explicit_iam_conn_with_profile(self): self.config.credentials = self.config.credentials.replace( method='iam', cluster_id='my_redshift', - iam_duration_seconds=1200, iam_profile='test', host='thishostshouldnotexist.test.us-east-1' ) - connection = 
self.adapter.acquire_connection("dummy" - ) + connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( @@ -173,44 +170,17 @@ def test_iam_conn_optionals(self): config_from_parts_or_dicts(self.config, profile_cfg) - def test_default_session_is_not_used_when_iam_used(self): - boto3.DEFAULT_SESSION = Mock() - self.config.credentials = self.config.credentials.replace(method='iam') - self.config.credentials.cluster_id = 'clusterid' - self.config.credentials.iam_profile = 'test' - with mock.patch('dbt.adapters.redshift.connections.boto3.Session'): - connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) - connect_method_factory.get_connect_method() - self.assertEqual( - boto3.DEFAULT_SESSION.client.call_count, - 0, - "The redshift client should not be created using " - "the default session because the session object is not thread-safe" - ) - - def test_default_session_is_not_used_when_iam_not_used(self): - boto3.DEFAULT_SESSION = Mock() - self.config.credentials = self.config.credentials.replace(method=None) - with mock.patch('dbt.adapters.redshift.connections.boto3.Session'): - connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) - connect_method_factory.get_connect_method() - self.assertEqual( - boto3.DEFAULT_SESSION.client.call_count, 0, - "The redshift client should not be created using " - "the default session because the session object is not thread-safe" - ) - def test_invalid_auth_method(self): # we have to set method this way, otherwise it won't validate self.config.credentials.method = 'badmethod' - with self.assertRaises(FailedToConnectException) as context: + with self.assertRaises(FailedToConnectError) as context: connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) connect_method_factory.get_connect_method() self.assertTrue('badmethod' in context.exception.msg) def test_invalid_iam_no_cluster_id(self): self.config.credentials = self.config.credentials.replace(method='iam') - with self.assertRaises(FailedToConnectException) as context: + with self.assertRaises(FailedToConnectError) as context: connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) connect_method_factory.get_connect_method() From 30ae0b59741a4b3b7f7874e1bf99fd090ecf5a86 Mon Sep 17 00:00:00 2001 From: Sathiish Kumar Date: Mon, 23 Jan 2023 07:11:35 -0800 Subject: [PATCH 07/27] Fix errors with functional tests by overriding add_query & execute and modifying multi statement execution --- dbt/adapters/redshift/connections.py | 54 ++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index b8889a4c3..b2ace7b05 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -1,7 +1,10 @@ +import re from multiprocessing import Lock from contextlib import contextmanager -from typing import NewType +from typing import NewType, Tuple +import agate +import sqlparse from dbt.adapters.sql import SQLConnectionManager from dbt.contracts.connection import AdapterResponse, Connection, Credentials from dbt.events import AdapterLogger @@ -184,9 +187,10 @@ def get_response(cls, cursor: redshift_connector.Cursor) -> AdapterResponse: def exception_handler(self, sql): try: yield - except redshift_connector.error.Error as e: + except redshift_connector.error.DatabaseError as e: logger.debug(f"Redshift error: {str(e)}") 
self.rollback_if_open() + raise dbt.exceptions.DbtDatabaseError(str(e)) except Exception as e: logger.debug("Error running SQL: {}", sql) logger.debug("Rolling back transaction.") @@ -194,7 +198,7 @@ def exception_handler(self, sql): # Raise DBT native exceptions as is. if isinstance(e, dbt.exceptions.Exception): raise - raise RuntimeError(str(e)) from e + raise dbt.exceptions.DbtRuntimeError(str(e)) from e @contextmanager def fresh_transaction(self, name=None): @@ -241,3 +245,47 @@ def exponential_backoff(attempt: int): retry_timeout=exponential_backoff, retryable_exceptions=retryable_exceptions, ) + + def execute( + self, sql: str, auto_begin: bool = False, fetch: bool = False + ) -> Tuple[AdapterResponse, agate.Table]: + _, cursor = self.add_query(sql, auto_begin) + response = self.get_response(cursor) + if fetch: + table = self.get_result_from_cursor(cursor) + else: + table = dbt.clients.agate_helper.empty_table() + return response, table + + def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False): + + connection = None + cursor = None + + queries = sqlparse.split(sql) + + for query in queries: + # Strip off comments from the current query + without_comments = re.sub( + re.compile(r"(\".*?\"|\'.*?\')|(/\*.*?\*/|--[^\r\n]*$)", re.MULTILINE), + "", + query, + ).strip() + + if without_comments == "": + continue + + connection, cursor = super().add_query( + query, auto_begin, bindings=bindings, abridge_sql_log=abridge_sql_log + ) + + if cursor is None: + conn = self.get_thread_connection() + conn_name = conn.name if conn and conn.name else "" + raise dbt.exceptions.DbtRuntimeError(f"Tried to run invalid SQL: {sql} on {conn_name}") + + return connection, cursor + + @classmethod + def get_credentials(cls, credentials): + return credentials From bfe8678278ffb41113abad2c3ca03c74562f3887 Mon Sep 17 00:00:00 2001 From: Sathiish Kumar Date: Tue, 24 Jan 2023 11:28:34 -0800 Subject: [PATCH 08/27] Attempt to fix integration tests by adding `valid_incremental_strategies` in impl.py --- dbt/adapters/redshift/connections.py | 11 +---------- dbt/adapters/redshift/impl.py | 9 +++++++++ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index b2ace7b05..ba56440f9 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -56,18 +56,9 @@ class RedshiftCredentials(Credentials): autocreate: bool = False db_groups: List[str] = field(default_factory=list) ra3_node: Optional[bool] = False - connect_timeout: int = 10 + connect_timeout: int = 30 role: Optional[str] = None sslmode: Optional[str] = None - sslcert: Optional[ - str - ] = None # TODO:Not supported by redshift python connector, should we remove this? - sslkey: Optional[ - str - ] = None # TODO:Not supported by redshift python connector, should we remove this? - sslrootcert: Optional[ - str - ] = None # TODO:Not supported by redshift python connector, should we remove this? application_name: Optional[str] = "dbt" retries: int = 1 diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index f84d85485..6b65d6b7d 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -90,3 +90,12 @@ def _get_catalog_schemas(self, manifest): self.type(), exc.msg ) ) + + def valid_incremental_strategies(self): + """The set of standard builtin strategies which this adapter supports out-of-the-box. + Not used to validate custom strategies defined by end users. 
+ """ + return ["append", "delete+insert"] + + def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str: + return f"{add_to} + interval '{number} {interval}'" From c8a18d8b46a5548d631a563c24cc8e856afd048e Mon Sep 17 00:00:00 2001 From: Sathiish Kumar Date: Tue, 24 Jan 2023 11:32:40 -0800 Subject: [PATCH 09/27] Fix unit tests --- tests/unit/test_redshift_adapter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 8f1b9b068..223895a2e 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -70,7 +70,7 @@ def test_implicit_database_conn(self): auto_create=False, db_groups=[], application_name='dbt', - timeout=10, + timeout=30, region='us-east-1' ) @@ -90,7 +90,7 @@ def test_explicit_database_conn(self): db_groups=[], region='us-east-1', application_name='dbt', - timeout=10 + timeout=30 ) @mock.patch("redshift_connector.connect", Mock()) @@ -115,7 +115,7 @@ def test_explicit_iam_conn_without_profile(self): db_groups=[], profile=None, application_name='dbt', - timeout=10, + timeout=30, port=5439 ) @@ -144,7 +144,7 @@ def test_explicit_iam_conn_with_profile(self): user='', profile='test', application_name='dbt', - timeout=10, + timeout=30, port=5439 ) From 40e0fe5d69e730f4ed70677982cf266ed1700b92 Mon Sep 17 00:00:00 2001 From: Sathiish Kumar Date: Wed, 25 Jan 2023 06:30:09 -0800 Subject: [PATCH 10/27] Attempt to fix integration tests --- dbt/adapters/redshift/connections.py | 2 +- tests/integration/sources_test/test_sources.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index ba56440f9..0bda82372 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -171,7 +171,7 @@ def cancel(self, connection: Connection): @classmethod def get_response(cls, cursor: redshift_connector.Cursor) -> AdapterResponse: rows = cursor.rowcount - message = f"{rows} cursor.rowcount" + message = f"cursor.rowcount = {rows}" return AdapterResponse(_message=message, rows_affected=rows) @contextmanager diff --git a/tests/integration/sources_test/test_sources.py b/tests/integration/sources_test/test_sources.py index a8b7017b9..433b3037c 100644 --- a/tests/integration/sources_test/test_sources.py +++ b/tests/integration/sources_test/test_sources.py @@ -129,7 +129,7 @@ def _assert_freshness_results(self, path, state): 'warn_after': {'count': 10, 'period': 'hour'}, 'error_after': {'count': 18, 'period': 'hour'}, }, - 'adapter_response': {}, + 'adapter_response': {'_message': 'cursor.rowcount = -1', 'rows_affected': -1}, 'thread_id': AnyStringWith('Thread-'), 'execution_time': AnyFloat(), 'timing': [ From 66c1594379c2d5e09f4d79f41bfa9425922d221e Mon Sep 17 00:00:00 2001 From: jiezhec Date: Thu, 26 Jan 2023 13:28:26 -0800 Subject: [PATCH 11/27] add unit tests for execute --- tests/unit/test_redshift_adapter.py | 31 +++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 223895a2e..75fd786e2 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -4,6 +4,7 @@ import agate import boto3 +import dbt import redshift_connector from dbt.adapters.redshift import ( @@ -247,6 +248,36 @@ def test_dbname_verification_is_case_insensitive(self): self._adapter = RedshiftAdapter(self.config) 
self.adapter.verify_database('redshift') + def test_execute_with_fetch(self): + cursor = mock.Mock() + table = dbt.clients.agate_helper.empty_table() + with mock.patch.object(self.adapter.connections, 'add_query') as mock_add_query: + mock_add_query.return_value = ( + None, cursor) # when mock_add_query is called, it will always return None, cursor + with mock.patch.object(self.adapter.connections, 'get_response') as mock_get_response: + mock_get_response.return_value = None + with mock.patch.object(self.adapter.connections, + 'get_result_from_cursor') as mock_get_result_from_cursor: + mock_get_result_from_cursor.return_value = table + self.adapter.connections.execute(sql="select * from test", fetch=True) + mock_add_query.assert_called_once_with('select * from test', False) + mock_get_result_from_cursor.assert_called_once_with(cursor) + mock_get_response.assert_called_once_with(cursor) + + def test_execute_without_fetch(self): + cursor = mock.Mock() + with mock.patch.object(self.adapter.connections, 'add_query') as mock_add_query: + mock_add_query.return_value = ( + None, cursor) # when mock_add_query is called, it will always return None, cursor + with mock.patch.object(self.adapter.connections, 'get_response') as mock_get_response: + mock_get_response.return_value = None + with mock.patch.object(self.adapter.connections, + 'get_result_from_cursor') as mock_get_result_from_cursor: + self.adapter.connections.execute(sql="select * from test2", fetch=False) + mock_add_query.assert_called_once_with('select * from test2', False) + mock_get_result_from_cursor.assert_not_called() + mock_get_response.assert_called_once_with(cursor) + class TestRedshiftAdapterConversions(TestAdapterConversions): def test_convert_text_type(self): From cfad7ffb6a5bd38d4ddd917a77b6a4ca8a35ad74 Mon Sep 17 00:00:00 2001 From: jiezhec Date: Mon, 30 Jan 2023 10:25:46 -0800 Subject: [PATCH 12/27] add unit tests for add_query --- tests/unit/test_redshift_adapter.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 75fd786e2..03b80265b 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -13,7 +13,9 @@ ) from dbt.clients import agate_helper from dbt.exceptions import FailedToConnectError - +from dbt.adapters.base import BaseConnectionManager +from dbt.adapters.redshift.connections import RedshiftConnectionManager +from dbt.adapters.sql import SQLConnectionManager from dbt.adapters.redshift.connections import RedshiftConnectMethodFactory from .utils import config_from_parts_or_dicts, mock_connection, TestAdapterConversions, inject_adapter @@ -278,6 +280,20 @@ def test_execute_without_fetch(self): mock_get_result_from_cursor.assert_not_called() mock_get_response.assert_called_once_with(cursor) + def test_add_query_with_no_cursor(self): + with mock.patch.object(self.adapter.connections, 'get_thread_connection') as mock_get_thread_connection: + mock_get_thread_connection.return_value = None + with self.assertRaisesRegex(dbt.exceptions.DbtRuntimeError, + 'Tried to run invalid SQL: on '): + self.adapter.connections.add_query(sql="") + mock_get_thread_connection.assert_called_once() + + def test_add_query_success(self): + cursor = mock.Mock() + with mock.patch.object(dbt.adapters.redshift.connections.SQLConnectionManager, 'add_query') as mock_add_query: + mock_add_query.return_value = None, cursor + self.adapter.connections.add_query('select * from test3') + 
mock_add_query.assert_called_once_with('select * from test3', True, bindings=None, abridge_sql_log=False) class TestRedshiftAdapterConversions(TestAdapterConversions): def test_convert_text_type(self): From 12eb89bde0520c0ae2d059bdad5321eeefc1d62e Mon Sep 17 00:00:00 2001 From: jiezhec Date: Mon, 30 Jan 2023 17:04:01 -0800 Subject: [PATCH 13/27] make get_connection_method work with serverless --- dbt/adapters/redshift/connections.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 0bda82372..ffb021359 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -118,9 +118,10 @@ def connect(): return connect elif method == RedshiftConnectionMethod.IAM: - if not self.credentials.cluster_id: + if not self.credentials.cluster_id and "serverless" not in self.credentials.host: raise dbt.exceptions.FailedToConnectError( - "Failed to use IAM method, 'cluster_id' must be provided" + "Failed to use IAM method. 'cluster_id' must be provided for provisioned cluster. " + "'host' must be provided for serverless endpoint." ) def connect(): From d3113ca06dedf6f7e39294c2ee84b834cc80491c Mon Sep 17 00:00:00 2001 From: jiezhec Date: Tue, 31 Jan 2023 10:21:49 -0800 Subject: [PATCH 14/27] add unit tests for serverless iam connections --- tests/unit/test_redshift_adapter.py | 59 +++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 03b80265b..489e5d0c1 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -13,9 +13,6 @@ ) from dbt.clients import agate_helper from dbt.exceptions import FailedToConnectError -from dbt.adapters.base import BaseConnectionManager -from dbt.adapters.redshift.connections import RedshiftConnectionManager -from dbt.adapters.sql import SQLConnectionManager from dbt.adapters.redshift.connections import RedshiftConnectMethodFactory from .utils import config_from_parts_or_dicts, mock_connection, TestAdapterConversions, inject_adapter @@ -151,6 +148,62 @@ def test_explicit_iam_conn_with_profile(self): port=5439 ) + @mock.patch('redshift_connector.connect', Mock()) + @mock.patch('boto3.Session', Mock()) + def test_explicit_iam_serverless_with_profile(self): + self.config.credentials = self.config.credentials.replace( + method='iam', + iam_profile='test', + host='doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com' + ) + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + iam=True, + host='doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com', + database='redshift', + cluster_identifier=None, + region='us-east-2', + auto_create=False, + db_groups=[], + db_user='root', + password='', + user='', + profile='test', + application_name='dbt', + timeout=30, + port=5439 + ) + + @mock.patch('redshift_connector.connect', Mock()) + @mock.patch('boto3.Session', Mock()) + def test_serverless_iam_failure(self): + self.config.credentials = self.config.credentials.replace( + method='iam', + iam_profile='test', + host='doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com' + ) + with self.assertRaises(dbt.exceptions.FailedToConnectError) as context: + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + iam=True, + 
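To make the reworked guard concrete, a small hypothetical check showing which host shapes now get past it without a cluster_id (both host names are invented):

    provisioned = "examplecluster.abc123xyz789.us-east-1.redshift.amazonaws.com"
    serverless = "default.123456789012.us-east-2.redshift-serverless.amazonaws.com"

    for host in (provisioned, serverless):
        # mirrors: if not cluster_id and "serverless" not in host -> FailedToConnectError
        if "serverless" not in host:
            print(host, "-> still requires cluster_id")
        else:
            print(host, "-> accepted as a serverless endpoint")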
host='doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com', + database='redshift', + cluster_identifier=None, + region='us-east-2', + auto_create=False, + db_groups=[], + db_user='root', + password='', + user='', + profile='test', + application_name='dbt', + timeout=30, + port=5439 + ) + self.assertTrue("'host' must be provided" in context.exception.msg) + def test_iam_conn_optionals(self): profile_cfg = { From ccfebc8265eb313d4434968d7928162317951b54 Mon Sep 17 00:00:00 2001 From: jiezhec Date: Thu, 2 Feb 2023 14:13:30 -0800 Subject: [PATCH 15/27] support auth_profile --- dbt/adapters/redshift/connections.py | 61 ++++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 7 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index ffb021359..36ff5cef8 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -1,3 +1,4 @@ +import os import re from multiprocessing import Lock from contextlib import contextmanager @@ -39,13 +40,16 @@ def json_schema(self): class RedshiftConnectionMethod(StrEnum): DATABASE = "database" IAM = "iam" + AUTH_PROFILE = "auth_profile" + IDP = "IdP" @dataclass class RedshiftCredentials(Credentials): - host: str - user: str + host: Optional[str] + user: Optional[str] port: Port + region: Optional[str] = None method: str = RedshiftConnectionMethod.DATABASE # type: ignore password: Optional[str] = None # type: ignore cluster_id: Optional[str] = field( @@ -61,6 +65,7 @@ class RedshiftCredentials(Credentials): sslmode: Optional[str] = None application_name: Optional[str] = "dbt" retries: int = 1 + auth_profile: Optional[str] = None _ALIASES = {"dbname": "database", "pass": "password"} @@ -69,11 +74,20 @@ def type(self): return "redshift" def _connection_keys(self): - return "host", "port", "user", "database", "schema", "method", "cluster_id", "iam_profile" + return ( + "host", + "port", + "user", + "database", + "schema", + "method", + "cluster_id", + "iam_profile", + ) @property def unique_field(self) -> str: - return self.host + return self.host if self.host else self.database class RedshiftConnectMethodFactory: @@ -85,15 +99,18 @@ def __init__(self, credentials): def get_connect_method(self): method = self.credentials.method kwargs = { - "host": self.credentials.host, + "host": "", + "region": self.credentials.region, "database": self.credentials.database, "port": self.credentials.port if self.credentials.port else 5439, "auto_create": self.credentials.autocreate, "db_groups": self.credentials.db_groups, - "region": self.credentials.host.split(".")[2], "application_name": self.credentials.application_name, "timeout": self.credentials.connect_timeout, } + if method != RedshiftConnectionMethod.AUTH_PROFILE: + kwargs["host"] = self.credentials.host + kwargs["region"] = self.credentials.host.split(".")[2] if self.credentials.sslmode: kwargs["sslmode"] = self.credentials.sslmode @@ -109,7 +126,36 @@ def get_connect_method(self): def connect(): logger.debug("Connecting to redshift with username/password based auth...") c = redshift_connector.connect( - user=self.credentials.user, password=self.credentials.password, **kwargs + user=self.credentials.user, + password=self.credentials.password, + **kwargs, + ) + if self.credentials.role: + c.cursor().execute("set role {}".format(self.credentials.role)) + return c + + return connect + + elif method == RedshiftConnectionMethod.AUTH_PROFILE: + if not self.credentials.auth_profile: + raise dbt.exceptions.FailedToConnectError( + "Failed 
to use auth profile method. 'auth_profile' must be provided." + ) + if not self.credentials.region: + raise dbt.exceptions.FailedToConnectError( + "Failed to use auth profile method. 'region' must be provided." + ) + + def connect(): + logger.debug("Connecting to redshift with authentication profile...") + c = redshift_connector.connect( + iam=True, + access_key_id=os.environ["AWS_ACCESS_KEY_ID"], + secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], + session_token=os.environ["AWS_SESSION_TOKEN"], + db_user=self.credentials.user, + auth_profile=self.credentials.auth_profile, + **kwargs, ) if self.credentials.role: c.cursor().execute("set role {}".format(self.credentials.role)) @@ -140,6 +186,7 @@ def connect(): return c return connect + else: raise dbt.exceptions.FailedToConnectError( "Invalid 'method' in profile: '{}'".format(method) From 1dc3ebebc6e36d518d12a0dac301b56d42f66031 Mon Sep 17 00:00:00 2001 From: jiezhec Date: Thu, 2 Feb 2023 14:14:21 -0800 Subject: [PATCH 16/27] add unit tests for auth_profile connection method --- tests/unit/test_redshift_adapter.py | 445 ++++++++++++++++------------ 1 file changed, 263 insertions(+), 182 deletions(-) diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 489e5d0c1..5588c6aac 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -1,3 +1,4 @@ +import os import unittest from unittest import mock from unittest.mock import Mock, call @@ -14,37 +15,41 @@ from dbt.clients import agate_helper from dbt.exceptions import FailedToConnectError from dbt.adapters.redshift.connections import RedshiftConnectMethodFactory -from .utils import config_from_parts_or_dicts, mock_connection, TestAdapterConversions, inject_adapter +from .utils import ( + config_from_parts_or_dicts, + mock_connection, + TestAdapterConversions, + inject_adapter, +) class TestRedshiftAdapter(unittest.TestCase): - def setUp(self): profile_cfg = { - 'outputs': { - 'test': { - 'type': 'redshift', - 'dbname': 'redshift', - 'user': 'root', - 'host': 'thishostshouldnotexist.test.us-east-1', - 'pass': 'password', - 'port': 5439, - 'schema': 'public' + "outputs": { + "test": { + "type": "redshift", + "dbname": "redshift", + "user": "root", + "host": "thishostshouldnotexist.test.us-east-1", + "pass": "password", + "port": 5439, + "schema": "public", } }, - 'target': 'test' + "target": "test", } project_cfg = { - 'name': 'X', - 'version': '0.1', - 'profile': 'test', - 'project-root': '/tmp/dbt/does-not-exist', - 'quoting': { - 'identifier': False, - 'schema': True, + "name": "X", + "version": "0.1", + "profile": "test", + "project-root": "/tmp/dbt/does-not-exist", + "quoting": { + "identifier": False, + "schema": True, }, - 'config-version': 2, + "config-version": 2, } self.config = config_from_parts_or_dicts(project_cfg, profile_cfg) @@ -62,182 +67,223 @@ def test_implicit_database_conn(self): connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( - host='thishostshouldnotexist.test.us-east-1', - database='redshift', - user='root', - password='password', + host="thishostshouldnotexist.test.us-east-1", + database="redshift", + user="root", + password="password", port=5439, auto_create=False, db_groups=[], - application_name='dbt', + application_name="dbt", timeout=30, - region='us-east-1' + region="us-east-1", ) @mock.patch("redshift_connector.connect", Mock()) def test_explicit_database_conn(self): - self.config.method = 'database' + 
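A hypothetical profiles.yml target exercising the new auth_profile path would carry fields along these lines (sketched here as a Python dict; every value is a placeholder):

    auth_profile_target = {
        "type": "redshift",
        "method": "auth_profile",
        "auth_profile": "my-redshift-auth-profile",  # required by the check above
        "region": "us-east-1",                       # also required for this method
        "dbname": "analytics",
        "schema": "public",
        "port": 5439,
        # AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_SESSION_TOKEN are read
        # from the environment by the connect() closure above, not from the profile
    }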
self.config.method = "database" connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( - host='thishostshouldnotexist.test.us-east-1', - database='redshift', - user='root', - password='password', + host="thishostshouldnotexist.test.us-east-1", + database="redshift", + user="root", + password="password", port=5439, auto_create=False, db_groups=[], - region='us-east-1', - application_name='dbt', - timeout=30 + region="us-east-1", + application_name="dbt", + timeout=30, ) @mock.patch("redshift_connector.connect", Mock()) def test_explicit_iam_conn_without_profile(self): self.config.credentials = self.config.credentials.replace( - method='iam', - cluster_id='my_redshift', - host='thishostshouldnotexist.test.us-east-1' + method="iam", + cluster_id="my_redshift", + host="thishostshouldnotexist.test.us-east-1", ) connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( iam=True, - host='thishostshouldnotexist.test.us-east-1', - database='redshift', - db_user='root', - password='', - user='', - cluster_identifier='my_redshift', - region='us-east-1', + host="thishostshouldnotexist.test.us-east-1", + database="redshift", + db_user="root", + password="", + user="", + cluster_identifier="my_redshift", + region="us-east-1", auto_create=False, db_groups=[], profile=None, - application_name='dbt', + application_name="dbt", timeout=30, - port=5439 + port=5439, ) - @mock.patch('redshift_connector.connect', Mock()) - @mock.patch('boto3.Session', Mock()) + @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("boto3.Session", Mock()) def test_explicit_iam_conn_with_profile(self): self.config.credentials = self.config.credentials.replace( - method='iam', - cluster_id='my_redshift', - iam_profile='test', - host='thishostshouldnotexist.test.us-east-1' + method="iam", + cluster_id="my_redshift", + iam_profile="test", + host="thishostshouldnotexist.test.us-east-1", ) connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( iam=True, - host='thishostshouldnotexist.test.us-east-1', - database='redshift', - cluster_identifier='my_redshift', - region='us-east-1', + host="thishostshouldnotexist.test.us-east-1", + database="redshift", + cluster_identifier="my_redshift", + region="us-east-1", auto_create=False, db_groups=[], - db_user='root', - password='', - user='', - profile='test', - application_name='dbt', + db_user="root", + password="", + user="", + profile="test", + application_name="dbt", timeout=30, - port=5439 + port=5439, ) - @mock.patch('redshift_connector.connect', Mock()) - @mock.patch('boto3.Session', Mock()) + @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("boto3.Session", Mock()) def test_explicit_iam_serverless_with_profile(self): self.config.credentials = self.config.credentials.replace( - method='iam', - iam_profile='test', - host='doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com' + method="iam", + iam_profile="test", + host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com", ) connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( iam=True, - host='doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com', - database='redshift', + host="doesnotexist.1233.us-east-2.redshift-serverless.amazonaws.com", + database="redshift", cluster_identifier=None, - region='us-east-2', + 
region="us-east-2", auto_create=False, db_groups=[], - db_user='root', - password='', - user='', - profile='test', - application_name='dbt', + db_user="root", + password="", + user="", + profile="test", + application_name="dbt", timeout=30, - port=5439 + port=5439, ) - @mock.patch('redshift_connector.connect', Mock()) - @mock.patch('boto3.Session', Mock()) + @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("boto3.Session", Mock()) def test_serverless_iam_failure(self): self.config.credentials = self.config.credentials.replace( - method='iam', - iam_profile='test', - host='doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com' + method="iam", + iam_profile="test", + host="doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com", ) with self.assertRaises(dbt.exceptions.FailedToConnectError) as context: connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( iam=True, - host='doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com', - database='redshift', + host="doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com", + database="redshift", cluster_identifier=None, - region='us-east-2', + region="us-east-2", auto_create=False, db_groups=[], - db_user='root', - password='', - user='', - profile='test', - application_name='dbt', + db_user="root", + password="", + user="", + profile="test", + application_name="dbt", timeout=30, - port=5439 - ) + port=5439, + ) self.assertTrue("'host' must be provided" in context.exception.msg) + @mock.patch("redshift_connector.connect", Mock()) + @mock.patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "someid", + "AWS_SECRET_ACCESS_KEY": "somekey", + "AWS_SESSION_TOKEN": "somekey", + }, + ) + @mock.patch("boto3.Session", Mock()) + def test_auth_profile_connect_success(self): + self.config.credentials = self.config.credentials.replace( + method="auth_profile", + auth_profile="testprofile", + database="", + user="", + region="us-east-1", + ) + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + iam=True, + access_key_id=os.environ["AWS_ACCESS_KEY_ID"], + secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], + session_token=os.environ["AWS_SESSION_TOKEN"], + db_user="", + auth_profile="testprofile", + port=5439, + auto_create=False, + db_groups=[], + application_name="dbt", + database="", + region="us-east-1", + host="", + timeout=30, + ) + def test_iam_conn_optionals(self): profile_cfg = { - 'outputs': { - 'test': { - 'type': 'redshift', - 'dbname': 'redshift', - 'user': 'root', - 'host': 'thishostshouldnotexist', - 'port': 5439, - 'schema': 'public', - 'method': 'iam', - 'cluster_id': 'my_redshift', - 'db_groups': ["my_dbgroup"], - 'autocreate': True, + "outputs": { + "test": { + "type": "redshift", + "dbname": "redshift", + "user": "root", + "host": "thishostshouldnotexist", + "port": 5439, + "schema": "public", + "method": "iam", + "cluster_id": "my_redshift", + "db_groups": ["my_dbgroup"], + "autocreate": True, } }, - 'target': 'test' + "target": "test", } config_from_parts_or_dicts(self.config, profile_cfg) def test_invalid_auth_method(self): # we have to set method this way, otherwise it won't validate - self.config.credentials.method = 'badmethod' + self.config.credentials.method = "badmethod" with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) + connect_method_factory = 
RedshiftConnectMethodFactory( + self.config.credentials + ) connect_method_factory.get_connect_method() - self.assertTrue('badmethod' in context.exception.msg) + self.assertTrue("badmethod" in context.exception.msg) def test_invalid_iam_no_cluster_id(self): - self.config.credentials = self.config.credentials.replace(method='iam') + self.config.credentials = self.config.credentials.replace(method="iam") with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) + connect_method_factory = RedshiftConnectMethodFactory( + self.config.credentials + ) connect_method_factory.get_connect_method() self.assertTrue("'cluster_id' must be provided" in context.exception.msg) @@ -247,171 +293,206 @@ def test_cancel_open_connections_empty(self): def test_cancel_open_connections_master(self): key = self.adapter.connections.get_thread_identifier() - self.adapter.connections.thread_connections[key] = mock_connection('master') + self.adapter.connections.thread_connections[key] = mock_connection("master") self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0) def test_cancel_open_connections_single(self): - master = mock_connection('master') - model = mock_connection('model') + master = mock_connection("master") + model = mock_connection("model") key = self.adapter.connections.get_thread_identifier() - self.adapter.connections.thread_connections.update({ - key: master, - 1: model, - }) - with mock.patch.object(self.adapter.connections, 'add_query') as add_query: + self.adapter.connections.thread_connections.update( + { + key: master, + 1: model, + } + ) + with mock.patch.object(self.adapter.connections, "add_query") as add_query: query_result = mock.MagicMock() cursor = mock.Mock() cursor.fetchone.return_value = 42 add_query.side_effect = [(None, cursor), (None, query_result)] self.assertEqual(len(list(self.adapter.cancel_open_connections())), 1) - add_query.assert_has_calls([call('select pg_backend_pid()'), call('select pg_terminate_backend(42)')]) + add_query.assert_has_calls( + [ + call("select pg_backend_pid()"), + call("select pg_terminate_backend(42)"), + ] + ) master.handle.get_backend_pid.assert_not_called() def test_dbname_verification_is_case_insensitive(self): # Override adapter settings from setUp() profile_cfg = { - 'outputs': { - 'test': { - 'type': 'redshift', - 'dbname': 'Redshift', - 'user': 'root', - 'host': 'thishostshouldnotexist', - 'pass': 'password', - 'port': 5439, - 'schema': 'public' + "outputs": { + "test": { + "type": "redshift", + "dbname": "Redshift", + "user": "root", + "host": "thishostshouldnotexist", + "pass": "password", + "port": 5439, + "schema": "public", } }, - 'target': 'test' + "target": "test", } project_cfg = { - 'name': 'X', - 'version': '0.1', - 'profile': 'test', - 'project-root': '/tmp/dbt/does-not-exist', - 'quoting': { - 'identifier': False, - 'schema': True, + "name": "X", + "version": "0.1", + "profile": "test", + "project-root": "/tmp/dbt/does-not-exist", + "quoting": { + "identifier": False, + "schema": True, }, - 'config-version': 2, + "config-version": 2, } self.config = config_from_parts_or_dicts(project_cfg, profile_cfg) self.adapter.cleanup_connections() self._adapter = RedshiftAdapter(self.config) - self.adapter.verify_database('redshift') + self.adapter.verify_database("redshift") def test_execute_with_fetch(self): cursor = mock.Mock() table = dbt.clients.agate_helper.empty_table() - with mock.patch.object(self.adapter.connections, 'add_query') as 
mock_add_query: + with mock.patch.object(self.adapter.connections, "add_query") as mock_add_query: mock_add_query.return_value = ( - None, cursor) # when mock_add_query is called, it will always return None, cursor - with mock.patch.object(self.adapter.connections, 'get_response') as mock_get_response: + None, + cursor, + ) # when mock_add_query is called, it will always return None, cursor + with mock.patch.object( + self.adapter.connections, "get_response" + ) as mock_get_response: mock_get_response.return_value = None - with mock.patch.object(self.adapter.connections, - 'get_result_from_cursor') as mock_get_result_from_cursor: + with mock.patch.object( + self.adapter.connections, "get_result_from_cursor" + ) as mock_get_result_from_cursor: mock_get_result_from_cursor.return_value = table - self.adapter.connections.execute(sql="select * from test", fetch=True) - mock_add_query.assert_called_once_with('select * from test', False) + self.adapter.connections.execute( + sql="select * from test", fetch=True + ) + mock_add_query.assert_called_once_with("select * from test", False) mock_get_result_from_cursor.assert_called_once_with(cursor) mock_get_response.assert_called_once_with(cursor) def test_execute_without_fetch(self): cursor = mock.Mock() - with mock.patch.object(self.adapter.connections, 'add_query') as mock_add_query: + with mock.patch.object(self.adapter.connections, "add_query") as mock_add_query: mock_add_query.return_value = ( - None, cursor) # when mock_add_query is called, it will always return None, cursor - with mock.patch.object(self.adapter.connections, 'get_response') as mock_get_response: + None, + cursor, + ) # when mock_add_query is called, it will always return None, cursor + with mock.patch.object( + self.adapter.connections, "get_response" + ) as mock_get_response: mock_get_response.return_value = None - with mock.patch.object(self.adapter.connections, - 'get_result_from_cursor') as mock_get_result_from_cursor: - self.adapter.connections.execute(sql="select * from test2", fetch=False) - mock_add_query.assert_called_once_with('select * from test2', False) + with mock.patch.object( + self.adapter.connections, "get_result_from_cursor" + ) as mock_get_result_from_cursor: + self.adapter.connections.execute( + sql="select * from test2", fetch=False + ) + mock_add_query.assert_called_once_with("select * from test2", False) mock_get_result_from_cursor.assert_not_called() mock_get_response.assert_called_once_with(cursor) def test_add_query_with_no_cursor(self): - with mock.patch.object(self.adapter.connections, 'get_thread_connection') as mock_get_thread_connection: + with mock.patch.object( + self.adapter.connections, "get_thread_connection" + ) as mock_get_thread_connection: mock_get_thread_connection.return_value = None - with self.assertRaisesRegex(dbt.exceptions.DbtRuntimeError, - 'Tried to run invalid SQL: on '): + with self.assertRaisesRegex( + dbt.exceptions.DbtRuntimeError, "Tried to run invalid SQL: on " + ): self.adapter.connections.add_query(sql="") mock_get_thread_connection.assert_called_once() def test_add_query_success(self): cursor = mock.Mock() - with mock.patch.object(dbt.adapters.redshift.connections.SQLConnectionManager, 'add_query') as mock_add_query: + with mock.patch.object( + dbt.adapters.redshift.connections.SQLConnectionManager, "add_query" + ) as mock_add_query: mock_add_query.return_value = None, cursor - self.adapter.connections.add_query('select * from test3') - mock_add_query.assert_called_once_with('select * from test3', True, bindings=None, 
abridge_sql_log=False) + self.adapter.connections.add_query("select * from test3") + mock_add_query.assert_called_once_with( + "select * from test3", True, bindings=None, abridge_sql_log=False + ) + class TestRedshiftAdapterConversions(TestAdapterConversions): def test_convert_text_type(self): rows = [ - ['', 'a1', 'stringval1'], - ['', 'a2', 'stringvalasdfasdfasdfa'], - ['', 'a3', 'stringval3'], + ["", "a1", "stringval1"], + ["", "a2", "stringvalasdfasdfasdfa"], + ["", "a3", "stringval3"], ] agate_table = self._make_table_of(rows, agate.Text) - expected = ['varchar(64)', 'varchar(2)', 'varchar(22)'] + expected = ["varchar(64)", "varchar(2)", "varchar(22)"] for col_idx, expect in enumerate(expected): assert RedshiftAdapter.convert_text_type(agate_table, col_idx) == expect def test_convert_number_type(self): rows = [ - ['', '23.98', '-1'], - ['', '12.78', '-2'], - ['', '79.41', '-3'], + ["", "23.98", "-1"], + ["", "12.78", "-2"], + ["", "79.41", "-3"], ] agate_table = self._make_table_of(rows, agate.Number) - expected = ['integer', 'float8', 'integer'] + expected = ["integer", "float8", "integer"] for col_idx, expect in enumerate(expected): assert RedshiftAdapter.convert_number_type(agate_table, col_idx) == expect def test_convert_boolean_type(self): rows = [ - ['', 'false', 'true'], - ['', 'false', 'false'], - ['', 'false', 'true'], + ["", "false", "true"], + ["", "false", "false"], + ["", "false", "true"], ] agate_table = self._make_table_of(rows, agate.Boolean) - expected = ['boolean', 'boolean', 'boolean'] + expected = ["boolean", "boolean", "boolean"] for col_idx, expect in enumerate(expected): assert RedshiftAdapter.convert_boolean_type(agate_table, col_idx) == expect def test_convert_datetime_type(self): rows = [ - ['', '20190101T01:01:01Z', '2019-01-01 01:01:01'], - ['', '20190102T01:01:01Z', '2019-01-01 01:01:01'], - ['', '20190103T01:01:01Z', '2019-01-01 01:01:01'], + ["", "20190101T01:01:01Z", "2019-01-01 01:01:01"], + ["", "20190102T01:01:01Z", "2019-01-01 01:01:01"], + ["", "20190103T01:01:01Z", "2019-01-01 01:01:01"], + ] + agate_table = self._make_table_of( + rows, [agate.DateTime, agate_helper.ISODateTime, agate.DateTime] + ) + expected = [ + "timestamp without time zone", + "timestamp without time zone", + "timestamp without time zone", ] - agate_table = self._make_table_of(rows, [agate.DateTime, agate_helper.ISODateTime, agate.DateTime]) - expected = ['timestamp without time zone', 'timestamp without time zone', 'timestamp without time zone'] for col_idx, expect in enumerate(expected): assert RedshiftAdapter.convert_datetime_type(agate_table, col_idx) == expect def test_convert_date_type(self): rows = [ - ['', '2019-01-01', '2019-01-04'], - ['', '2019-01-02', '2019-01-04'], - ['', '2019-01-03', '2019-01-04'], + ["", "2019-01-01", "2019-01-04"], + ["", "2019-01-02", "2019-01-04"], + ["", "2019-01-03", "2019-01-04"], ] agate_table = self._make_table_of(rows, agate.Date) - expected = ['date', 'date', 'date'] + expected = ["date", "date", "date"] for col_idx, expect in enumerate(expected): assert RedshiftAdapter.convert_date_type(agate_table, col_idx) == expect def test_convert_time_type(self): # dbt's default type testers actually don't have a TimeDelta at all. 
rows = [ - ['', '120s', '10s'], - ['', '3m', '11s'], - ['', '1h', '12s'], + ["", "120s", "10s"], + ["", "3m", "11s"], + ["", "1h", "12s"], ] agate_table = self._make_table_of(rows, agate.TimeDelta) - expected = ['varchar(24)', 'varchar(24)', 'varchar(24)'] + expected = ["varchar(24)", "varchar(24)", "varchar(24)"] for col_idx, expect in enumerate(expected): assert RedshiftAdapter.convert_time_type(agate_table, col_idx) == expect From b43d20c48795fd8b141efba8d9a8b347cd3d3883 Mon Sep 17 00:00:00 2001 From: jiezhec Date: Thu, 2 Feb 2023 21:57:34 -0800 Subject: [PATCH 17/27] add support for auth_profile --- dbt/adapters/redshift/connections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 36ff5cef8..286e56956 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -105,8 +105,8 @@ def get_connect_method(self): "port": self.credentials.port if self.credentials.port else 5439, "auto_create": self.credentials.autocreate, "db_groups": self.credentials.db_groups, - "application_name": self.credentials.application_name, "timeout": self.credentials.connect_timeout, + "application_name": self.credentials.application_name, } if method != RedshiftConnectionMethod.AUTH_PROFILE: kwargs["host"] = self.credentials.host From 93947e15c95eafd5b7931099ffd28830a916dd71 Mon Sep 17 00:00:00 2001 From: jiezhec Date: Mon, 6 Feb 2023 08:52:34 -0800 Subject: [PATCH 18/27] add support for idp credentials --- dbt/adapters/redshift/connections.py | 89 +++++++++++++++++++++- tests/unit/test_redshift_adapter.py | 106 ++++++++++++++++++++++----- 2 files changed, 175 insertions(+), 20 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 286e56956..ce30cf275 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -46,10 +46,10 @@ class RedshiftConnectionMethod(StrEnum): @dataclass class RedshiftCredentials(Credentials): - host: Optional[str] - user: Optional[str] port: Port region: Optional[str] = None + host: Optional[str] = None + user: Optional[str] = None method: str = RedshiftConnectionMethod.DATABASE # type: ignore password: Optional[str] = None # type: ignore cluster_id: Optional[str] = field( @@ -66,6 +66,16 @@ class RedshiftCredentials(Credentials): application_name: Optional[str] = "dbt" retries: int = 1 auth_profile: Optional[str] = None + # Azure identity provider plugin + credentials_provider: Optional[str] = None + idp_tenant: Optional[str] = None + client_id: Optional[str] = None + client_secret: Optional[str] = None + preferred_role: Optional[str] = None + # Okta identity provider plugin + idp_host: Optional[str] = None + app_id: Optional[str] = None + app_name: Optional[str] = None _ALIASES = {"dbname": "database", "pass": "password"} @@ -108,7 +118,7 @@ def get_connect_method(self): "timeout": self.credentials.connect_timeout, "application_name": self.credentials.application_name, } - if method != RedshiftConnectionMethod.AUTH_PROFILE: + if method == RedshiftConnectionMethod.IAM or method == RedshiftConnectionMethod.DATABASE: kwargs["host"] = self.credentials.host kwargs["region"] = self.credentials.host.split(".")[2] if self.credentials.sslmode: @@ -162,7 +172,80 @@ def connect(): return c return connect + elif method == RedshiftConnectionMethod.IDP: + if not self.credentials.credentials_provider: + raise dbt.exceptions.FailedToConnectError( + "'credentials_provider' field 
is required for 'IdP' credentials" + ) + + if not self.credentials.region: + raise dbt.exceptions.FailedToConnectError( + "'region' field is required for 'IdP' credentials" + ) + + if not self.credentials.password or not self.credentials.user: + raise dbt.exceptions.FailedToConnectError( + "'password' and 'user' fields are required for 'IdP' credentials" + ) + if self.credentials.credentials_provider == "AzureCredentialsProvider": + if ( + not self.credentials.idp_tenant + or not self.credentials.client_id + or not self.credentials.client_secret + or not self.credentials.preferred_role + ): + raise dbt.exceptions.FailedToConnectError( + "'idp_tenant', 'client_id', 'client_secret', and 'preferred_role' are required for" + " Azure Credentials Provider" + ) + + def connect(): + logger.debug("Connecting to redshift with Azure Credentials Provider...") + c = redshift_connector.connect( + iam=True, + region=self.credentials.region, + database=self.credentials.database, + cluster_identifier=self.credentials.cluster_id, + credentials_provider="AzureCredentialsProvider", + user=self.credentials.user, + password=self.credentials.password, + idp_tenant=self.credentials.idp_tenant, + client_id=self.credentials.client_id, + client_secret=self.credentials.client_secret, + preferred_role=self.credentials.preferred_role, + ) + return c + + return connect + elif self.credentials.credentials_provider == "OktaCredentialsProvider": + if ( + not self.credentials.idp_host + or not self.credentials.app_id + or not self.credentials.app_name + ): + raise dbt.exceptions.FailedToConnectError( + "'idp_host', 'app_id', 'app_name' are required for" + " Okta Credentials Provider" + ) + + def connect(): + logger.debug("Connecting to redshift with Okta Credentials Provider...") + c = redshift_connector.connect( + iam=True, + region=self.credentials.region, + database=self.credentials.database, + cluster_identifier=self.credentials.cluster_id, + credentials_provider="OktaCredentialsProvider", + user=self.credentials.user, + password=self.credentials.password, + idp_host=self.credentials.idp_host, + app_id=self.credentials.app_id, + app_name=self.credentials.app_name, + ) + return c + + return connect elif method == RedshiftConnectionMethod.IAM: if not self.credentials.cluster_id and "serverless" not in self.credentials.host: raise dbt.exceptions.FailedToConnectError( diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 5588c6aac..e2291cb0d 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -181,7 +181,6 @@ def test_explicit_iam_serverless_with_profile(self): ) @mock.patch("redshift_connector.connect", Mock()) - @mock.patch("boto3.Session", Mock()) def test_serverless_iam_failure(self): self.config.credentials = self.config.credentials.replace( method="iam", @@ -191,24 +190,83 @@ def test_serverless_iam_failure(self): with self.assertRaises(dbt.exceptions.FailedToConnectError) as context: connection = self.adapter.acquire_connection("dummy") connection.handle - redshift_connector.connect.assert_called_once_with( - iam=True, - host="doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com", - database="redshift", - cluster_identifier=None, - region="us-east-2", - auto_create=False, - db_groups=[], - db_user="root", - password="", - user="", - profile="test", - application_name="dbt", - timeout=30, - port=5439, - ) self.assertTrue("'host' must be provided" in context.exception.msg) + @mock.patch("redshift_connector.connect", Mock()) + 
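For orientation, the field combinations the two IdP branches above insist on look roughly like this when expressed as profile targets (hypothetical sketches as Python dicts; every value is a placeholder):

    azure_target = {
        "type": "redshift",
        "method": "IdP",
        "credentials_provider": "AzureCredentialsProvider",
        "cluster_id": "my-cluster",
        "region": "us-east-1",
        "dbname": "dev",
        "schema": "public",
        "user": "person@example.onmicrosoft.com",
        "pass": "<password>",
        # all four of these must be set for the Azure branch
        "idp_tenant": "<tenant-id>",
        "client_id": "<app-client-id>",
        "client_secret": "<app-client-secret>",
        "preferred_role": "arn:aws:iam::123456789012:role/analyst",
    }

    okta_target = {
        "type": "redshift",
        "method": "IdP",
        "credentials_provider": "OktaCredentialsProvider",
        "cluster_id": "my-cluster",
        "region": "us-east-1",
        "dbname": "dev",
        "schema": "public",
        "user": "person@example.com",
        "pass": "<password>",
        # all three of these must be set for the Okta branch
        "idp_host": "example.okta.com",
        "app_id": "<okta-app-id>",
        "app_name": "<okta-app-name>",
    }

Later commits in this series prefix the provider-specific fields (okta_idp_host, azure_client_id, and so on), so the prefixed spellings are the ones that ultimately apply.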
@mock.patch("boto3.Session", Mock()) + def test_azure_identity_plugin(self): + self.config.credentials = self.config.credentials.replace( + method="IdP", + database='dev', + cluster_id='my-testing-cluster', + credentials_provider='AzureCredentialsProvider', + user='someuser@myazure.org', + password='somepassword', + idp_tenant='my_idp_tenant', + client_id='my_client_id', + client_secret='my_client_secret', + region = 'us-east-1', + preferred_role='arn:aws:iam:123:role/MyFirstDinnerRoll' + ) + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + iam=True, + #host="doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com",\ + database='dev', + cluster_identifier='my-testing-cluster', + credentials_provider='AzureCredentialsProvider', + user='someuser@myazure.org', + password='somepassword', + idp_tenant='my_idp_tenant', + client_id='my_client_id', + client_secret='my_client_secret', + preferred_role='arn:aws:iam:123:role/MyFirstDinnerRoll', + region="us-east-1", + ) + @mock.patch("redshift_connector.connect", Mock()) + def test_idp_identity_no_region(self): + self.config.credentials = self.config.credentials.replace( + method="IdP", + ) + with self.assertRaises(FailedToConnectError) as context: + connect_method_factory = RedshiftConnectMethodFactory( + self.config.credentials + ) + connect_method_factory.get_connect_method() + self.assertTrue("credentials_provider" in context.exception.msg) + + @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("boto3.Session", Mock()) + def test_okta_identity_plugin(self): + self.config.credentials = self.config.credentials.replace( + method="IdP", + database='dev', + cluster_id='my-testing-cluster', + credentials_provider='OktaCredentialsProvider', + user='someuser@myazure.org', + password='somepassword', + idp_host='my_idp_host', + app_id='my_first_appetizer', + app_name='dinner_party', + region='us-east-1' + + ) + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + iam=True, + database='dev', + cluster_identifier='my-testing-cluster', + credentials_provider='OktaCredentialsProvider', + user='someuser@myazure.org', + password='somepassword', + idp_host='my_idp_host', + app_id='my_first_appetizer', + app_name='dinner_party', + region='us-east-1' + ) + @mock.patch("redshift_connector.connect", Mock()) @mock.patch.dict( os.environ, @@ -246,6 +304,20 @@ def test_auth_profile_connect_success(self): timeout=30, ) + + @mock.patch("redshift_connector.connect", Mock()) + def test_auth_profile_no_profile(self): + self.config.credentials = self.config.credentials.replace( + method="auth_profile", + auth_profile="", + ) + with self.assertRaises(FailedToConnectError) as context: + connect_method_factory = RedshiftConnectMethodFactory( + self.config.credentials + ) + connect_method_factory.get_connect_method() + self.assertTrue("'auth_profile' must be provided" in context.exception.msg) + def test_iam_conn_optionals(self): profile_cfg = { From fb6026834eb2c77879c199670b465a84544adc2b Mon Sep 17 00:00:00 2001 From: jiezhec Date: Mon, 6 Feb 2023 11:25:12 -0800 Subject: [PATCH 19/27] add CHANGELOG --- .changes/unreleased/Features-20230206-112433.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changes/unreleased/Features-20230206-112433.yaml diff --git a/.changes/unreleased/Features-20230206-112433.yaml b/.changes/unreleased/Features-20230206-112433.yaml new file mode 100644 index 
000000000..350147dae --- /dev/null +++ b/.changes/unreleased/Features-20230206-112433.yaml @@ -0,0 +1,5 @@ +kind: Features +time: 2023-02-06T11:24:33.926088-08:00 +custom: + Author: jiezhen-chen + Issue: "6232" From c710803a37b8fda45420251619be041a96a5a09d Mon Sep 17 00:00:00 2001 From: jiezhec Date: Mon, 6 Feb 2023 12:03:28 -0800 Subject: [PATCH 20/27] change host default to None --- dbt/adapters/redshift/connections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index ce30cf275..22a429b98 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -109,7 +109,7 @@ def __init__(self, credentials): def get_connect_method(self): method = self.credentials.method kwargs = { - "host": "", + "host": None, "region": self.credentials.region, "database": self.credentials.database, "port": self.credentials.port if self.credentials.port else 5439, From 17fcb0316c665752c51c55cfbf5e030069306ec0 Mon Sep 17 00:00:00 2001 From: jiezhec Date: Tue, 7 Feb 2023 15:29:29 -0800 Subject: [PATCH 21/27] add unit tests and improve error messages --- dbt/adapters/redshift/connections.py | 15 +++++++------- tests/unit/test_redshift_adapter.py | 29 ++++++++++++++++++++++++---- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 22a429b98..b9c4ab7fc 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -109,7 +109,7 @@ def __init__(self, credentials): def get_connect_method(self): method = self.credentials.method kwargs = { - "host": None, + "host": "", "region": self.credentials.region, "database": self.credentials.database, "port": self.credentials.port if self.credentials.port else 5439, @@ -175,17 +175,17 @@ def connect(): elif method == RedshiftConnectionMethod.IDP: if not self.credentials.credentials_provider: raise dbt.exceptions.FailedToConnectError( - "'credentials_provider' field is required for 'IdP' credentials" + "Failed to use IdP credentials. 'credentials_provider' must be provided." ) if not self.credentials.region: raise dbt.exceptions.FailedToConnectError( - "'region' field is required for 'IdP' credentials" + "Failed to use IdP credentials. 'region' must be provided." ) if not self.credentials.password or not self.credentials.user: raise dbt.exceptions.FailedToConnectError( - "'password' and 'user' fields are required for 'IdP' credentials" + "Failed to use IdP credentials. 'password' and 'user' must be provided." ) if self.credentials.credentials_provider == "AzureCredentialsProvider": @@ -196,8 +196,8 @@ def connect(): or not self.credentials.preferred_role ): raise dbt.exceptions.FailedToConnectError( - "'idp_tenant', 'client_id', 'client_secret', and 'preferred_role' are required for" - " Azure Credentials Provider" + "Failed to use Azure credential. 'idp_tenant', 'client_id', 'client_secret', " + "and 'preferred_role' must be provided" ) def connect(): @@ -225,8 +225,7 @@ def connect(): or not self.credentials.app_name ): raise dbt.exceptions.FailedToConnectError( - "'idp_host', 'app_id', 'app_name' are required for" - " Okta Credentials Provider" + "Failed to use Okta credential. 'idp_host', 'app_id', 'app_name' must be provided." 
) def connect(): diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index e2291cb0d..e6e195503 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -212,7 +212,6 @@ def test_azure_identity_plugin(self): connection.handle redshift_connector.connect.assert_called_once_with( iam=True, - #host="doesnotexist.1233.us-east-2.redshift-srvrlss.amazonaws.com",\ database='dev', cluster_identifier='my-testing-cluster', credentials_provider='AzureCredentialsProvider', @@ -224,8 +223,32 @@ def test_azure_identity_plugin(self): preferred_role='arn:aws:iam:123:role/MyFirstDinnerRoll', region="us-east-1", ) + + @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("boto3.Session", Mock()) + def test_azure_no_region(self): + self.config.credentials = self.config.credentials.replace( + method="IdP", + database='dev', + cluster_id='my-testing-cluster', + credentials_provider='AzureCredentialsProvider', + user='someuser@myazure.org', + password='somepassword', + idp_tenant='my_idp_tenant', + client_id='my_client_id', + client_secret='my_client_secret', + preferred_role='arn:aws:iam:123:role/MyFirstDinnerRoll' + ) + with self.assertRaises(FailedToConnectError) as context: + connect_method_factory = RedshiftConnectMethodFactory( + self.config.credentials + ) + connect_method_factory.get_connect_method() + self.assertTrue("'region' must be provided" in context.exception.msg) + + @mock.patch("redshift_connector.connect", Mock()) - def test_idp_identity_no_region(self): + def test_idp_identity_no_provider(self): self.config.credentials = self.config.credentials.replace( method="IdP", ) @@ -250,7 +273,6 @@ def test_okta_identity_plugin(self): app_id='my_first_appetizer', app_name='dinner_party', region='us-east-1' - ) connection = self.adapter.acquire_connection("dummy") connection.handle @@ -304,7 +326,6 @@ def test_auth_profile_connect_success(self): timeout=30, ) - @mock.patch("redshift_connector.connect", Mock()) def test_auth_profile_no_profile(self): self.config.credentials = self.config.credentials.replace( From b1b09cebb7104ff2f2b76b6438e62c962ab18af8 Mon Sep 17 00:00:00 2001 From: jiezhec Date: Wed, 22 Feb 2023 10:17:23 -0800 Subject: [PATCH 22/27] remove application_name and add okta prefix --- dbt/adapters/redshift/connections.py | 23 +++++++++--------- tests/unit/test_redshift_adapter.py | 36 +++++++++++----------------- 2 files changed, 25 insertions(+), 34 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index b9c4ab7fc..54d4ffe4d 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -63,7 +63,6 @@ class RedshiftCredentials(Credentials): connect_timeout: int = 30 role: Optional[str] = None sslmode: Optional[str] = None - application_name: Optional[str] = "dbt" retries: int = 1 auth_profile: Optional[str] = None # Azure identity provider plugin @@ -73,9 +72,9 @@ class RedshiftCredentials(Credentials): client_secret: Optional[str] = None preferred_role: Optional[str] = None # Okta identity provider plugin - idp_host: Optional[str] = None - app_id: Optional[str] = None - app_name: Optional[str] = None + okta_idp_host: Optional[str] = None + okta_app_id: Optional[str] = None + okta_app_name: Optional[str] = None _ALIASES = {"dbname": "database", "pass": "password"} @@ -116,7 +115,7 @@ def get_connect_method(self): "auto_create": self.credentials.autocreate, "db_groups": self.credentials.db_groups, "timeout": 
self.credentials.connect_timeout, - "application_name": self.credentials.application_name, + "application_name": str("dbt"), } if method == RedshiftConnectionMethod.IAM or method == RedshiftConnectionMethod.DATABASE: kwargs["host"] = self.credentials.host @@ -220,12 +219,12 @@ def connect(): return connect elif self.credentials.credentials_provider == "OktaCredentialsProvider": if ( - not self.credentials.idp_host - or not self.credentials.app_id - or not self.credentials.app_name + not self.credentials.okta_idp_host + or not self.credentials.okta_app_id + or not self.credentials.okta_app_name ): raise dbt.exceptions.FailedToConnectError( - "Failed to use Okta credential. 'idp_host', 'app_id', 'app_name' must be provided." + "Failed to use Okta credential. 'okta_idp_host', 'okta_app_id', 'okta_app_name' must be provided." ) def connect(): @@ -238,9 +237,9 @@ def connect(): credentials_provider="OktaCredentialsProvider", user=self.credentials.user, password=self.credentials.password, - idp_host=self.credentials.idp_host, - app_id=self.credentials.app_id, - app_name=self.credentials.app_name, + idp_host=self.credentials.okta_idp_host, + app_id=self.credentials.okta_app_id, + app_name=self.credentials.okta_app_name, ) return c diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index e6e195503..c1051332d 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -74,7 +74,6 @@ def test_implicit_database_conn(self): port=5439, auto_create=False, db_groups=[], - application_name="dbt", timeout=30, region="us-east-1", ) @@ -94,7 +93,6 @@ def test_explicit_database_conn(self): auto_create=False, db_groups=[], region="us-east-1", - application_name="dbt", timeout=30, ) @@ -119,7 +117,6 @@ def test_explicit_iam_conn_without_profile(self): auto_create=False, db_groups=[], profile=None, - application_name="dbt", timeout=30, port=5439, ) @@ -148,7 +145,6 @@ def test_explicit_iam_conn_with_profile(self): password="", user="", profile="test", - application_name="dbt", timeout=30, port=5439, ) @@ -175,7 +171,6 @@ def test_explicit_iam_serverless_with_profile(self): password="", user="", profile="test", - application_name="dbt", timeout=30, port=5439, ) @@ -205,7 +200,7 @@ def test_azure_identity_plugin(self): idp_tenant='my_idp_tenant', client_id='my_client_id', client_secret='my_client_secret', - region = 'us-east-1', + region='us-east-1', preferred_role='arn:aws:iam:123:role/MyFirstDinnerRoll' ) connection = self.adapter.acquire_connection("dummy") @@ -246,7 +241,6 @@ def test_azure_no_region(self): connect_method_factory.get_connect_method() self.assertTrue("'region' must be provided" in context.exception.msg) - @mock.patch("redshift_connector.connect", Mock()) def test_idp_identity_no_provider(self): self.config.credentials = self.config.credentials.replace( @@ -269,9 +263,9 @@ def test_okta_identity_plugin(self): credentials_provider='OktaCredentialsProvider', user='someuser@myazure.org', password='somepassword', - idp_host='my_idp_host', - app_id='my_first_appetizer', - app_name='dinner_party', + okta_idp_host='my_idp_host', + okta_app_id='my_first_appetizer', + okta_app_name='dinner_party', region='us-east-1' ) connection = self.adapter.acquire_connection("dummy") @@ -283,9 +277,9 @@ def test_okta_identity_plugin(self): credentials_provider='OktaCredentialsProvider', user='someuser@myazure.org', password='somepassword', - idp_host='my_idp_host', - app_id='my_first_appetizer', - app_name='dinner_party', + 
okta_idp_host='my_idp_host', + okta_app_id='my_first_appetizer', + okta_app_name='dinner_party', region='us-east-1' ) @@ -319,7 +313,6 @@ def test_auth_profile_connect_success(self): port=5439, auto_create=False, db_groups=[], - application_name="dbt", database="", region="us-east-1", host="", @@ -340,7 +333,6 @@ def test_auth_profile_no_profile(self): self.assertTrue("'auth_profile' must be provided" in context.exception.msg) def test_iam_conn_optionals(self): - profile_cfg = { "outputs": { "test": { @@ -458,11 +450,11 @@ def test_execute_with_fetch(self): cursor, ) # when mock_add_query is called, it will always return None, cursor with mock.patch.object( - self.adapter.connections, "get_response" + self.adapter.connections, "get_response" ) as mock_get_response: mock_get_response.return_value = None with mock.patch.object( - self.adapter.connections, "get_result_from_cursor" + self.adapter.connections, "get_result_from_cursor" ) as mock_get_result_from_cursor: mock_get_result_from_cursor.return_value = table self.adapter.connections.execute( @@ -480,11 +472,11 @@ def test_execute_without_fetch(self): cursor, ) # when mock_add_query is called, it will always return None, cursor with mock.patch.object( - self.adapter.connections, "get_response" + self.adapter.connections, "get_response" ) as mock_get_response: mock_get_response.return_value = None with mock.patch.object( - self.adapter.connections, "get_result_from_cursor" + self.adapter.connections, "get_result_from_cursor" ) as mock_get_result_from_cursor: self.adapter.connections.execute( sql="select * from test2", fetch=False @@ -495,11 +487,11 @@ def test_execute_without_fetch(self): def test_add_query_with_no_cursor(self): with mock.patch.object( - self.adapter.connections, "get_thread_connection" + self.adapter.connections, "get_thread_connection" ) as mock_get_thread_connection: mock_get_thread_connection.return_value = None with self.assertRaisesRegex( - dbt.exceptions.DbtRuntimeError, "Tried to run invalid SQL: on " + dbt.exceptions.DbtRuntimeError, "Tried to run invalid SQL: on " ): self.adapter.connections.add_query(sql="") mock_get_thread_connection.assert_called_once() @@ -507,7 +499,7 @@ def test_add_query_with_no_cursor(self): def test_add_query_success(self): cursor = mock.Mock() with mock.patch.object( - dbt.adapters.redshift.connections.SQLConnectionManager, "add_query" + dbt.adapters.redshift.connections.SQLConnectionManager, "add_query" ) as mock_add_query: mock_add_query.return_value = None, cursor self.adapter.connections.add_query("select * from test3") From 674d754a717e24f8e8cb21453fa1e1f0d1bb2a0e Mon Sep 17 00:00:00 2001 From: jiezhec Date: Wed, 22 Feb 2023 11:08:06 -0800 Subject: [PATCH 23/27] add azure_ and okta_ prefix, make application_name unconfigurable to user --- dbt/adapters/redshift/connections.py | 30 ++++++++++++++-------------- tests/unit/test_redshift_adapter.py | 28 ++++++++++++++++---------- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 54d4ffe4d..01ebe5481 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -67,10 +67,10 @@ class RedshiftCredentials(Credentials): auth_profile: Optional[str] = None # Azure identity provider plugin credentials_provider: Optional[str] = None - idp_tenant: Optional[str] = None - client_id: Optional[str] = None - client_secret: Optional[str] = None - preferred_role: Optional[str] = None + azure_idp_tenant: Optional[str] = None 
+ azure_client_id: Optional[str] = None + azure_client_secret: Optional[str] = None + azure_preferred_role: Optional[str] = None # Okta identity provider plugin okta_idp_host: Optional[str] = None okta_app_id: Optional[str] = None @@ -115,7 +115,7 @@ def get_connect_method(self): "auto_create": self.credentials.autocreate, "db_groups": self.credentials.db_groups, "timeout": self.credentials.connect_timeout, - "application_name": str("dbt"), + "application_name": "dbt", } if method == RedshiftConnectionMethod.IAM or method == RedshiftConnectionMethod.DATABASE: kwargs["host"] = self.credentials.host @@ -189,14 +189,14 @@ def connect(): if self.credentials.credentials_provider == "AzureCredentialsProvider": if ( - not self.credentials.idp_tenant - or not self.credentials.client_id - or not self.credentials.client_secret - or not self.credentials.preferred_role + not self.credentials.azure_idp_tenant + or not self.credentials.azure_client_id + or not self.credentials.azure_client_secret + or not self.credentials.azure_preferred_role ): raise dbt.exceptions.FailedToConnectError( - "Failed to use Azure credential. 'idp_tenant', 'client_id', 'client_secret', " - "and 'preferred_role' must be provided" + "Failed to use Azure credential. 'azure_idp_tenant', 'azure_client_id', 'azure_client_secret', " + "and 'azure_preferred_role' must be provided" ) def connect(): @@ -209,10 +209,10 @@ def connect(): credentials_provider="AzureCredentialsProvider", user=self.credentials.user, password=self.credentials.password, - idp_tenant=self.credentials.idp_tenant, - client_id=self.credentials.client_id, - client_secret=self.credentials.client_secret, - preferred_role=self.credentials.preferred_role, + idp_tenant=self.credentials.azure_idp_tenant, + client_id=self.credentials.azure_client_id, + client_secret=self.credentials.azure_client_secret, + preferred_role=self.credentials.azure_preferred_role, ) return c diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index c1051332d..316e78ef0 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -75,6 +75,7 @@ def test_implicit_database_conn(self): auto_create=False, db_groups=[], timeout=30, + application_name="dbt", region="us-east-1", ) @@ -93,6 +94,7 @@ def test_explicit_database_conn(self): auto_create=False, db_groups=[], region="us-east-1", + application_name="dbt", timeout=30, ) @@ -118,6 +120,7 @@ def test_explicit_iam_conn_without_profile(self): db_groups=[], profile=None, timeout=30, + application_name="dbt", port=5439, ) @@ -145,6 +148,7 @@ def test_explicit_iam_conn_with_profile(self): password="", user="", profile="test", + application_name="dbt", timeout=30, port=5439, ) @@ -171,6 +175,7 @@ def test_explicit_iam_serverless_with_profile(self): password="", user="", profile="test", + application_name="dbt", timeout=30, port=5439, ) @@ -197,11 +202,11 @@ def test_azure_identity_plugin(self): credentials_provider='AzureCredentialsProvider', user='someuser@myazure.org', password='somepassword', - idp_tenant='my_idp_tenant', - client_id='my_client_id', - client_secret='my_client_secret', + azure_idp_tenant='my_idp_tenant', + azure_client_id='my_client_id', + azure_client_secret='my_client_secret', region='us-east-1', - preferred_role='arn:aws:iam:123:role/MyFirstDinnerRoll' + azure_preferred_role='arn:aws:iam:123:role/MyFirstDinnerRoll' ) connection = self.adapter.acquire_connection("dummy") connection.handle @@ -229,10 +234,10 @@ def test_azure_no_region(self): 
credentials_provider='AzureCredentialsProvider', user='someuser@myazure.org', password='somepassword', - idp_tenant='my_idp_tenant', - client_id='my_client_id', - client_secret='my_client_secret', - preferred_role='arn:aws:iam:123:role/MyFirstDinnerRoll' + azure_idp_tenant='my_idp_tenant', + azure_client_id='my_client_id', + azure_client_secret='my_client_secret', + azure_preferred_role='arn:aws:iam:123:role/MyFirstDinnerRoll' ) with self.assertRaises(FailedToConnectError) as context: connect_method_factory = RedshiftConnectMethodFactory( @@ -277,9 +282,9 @@ def test_okta_identity_plugin(self): credentials_provider='OktaCredentialsProvider', user='someuser@myazure.org', password='somepassword', - okta_idp_host='my_idp_host', - okta_app_id='my_first_appetizer', - okta_app_name='dinner_party', + idp_host='my_idp_host', + app_id='my_first_appetizer', + app_name='dinner_party', region='us-east-1' ) @@ -315,6 +320,7 @@ def test_auth_profile_connect_success(self): db_groups=[], database="", region="us-east-1", + application_name="dbt", host="", timeout=30, ) From c901bb757e89de1ba198364da121c3d318612e0f Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Tue, 21 Mar 2023 10:14:02 -0700 Subject: [PATCH 24/27] error message enhancement --- dbt/adapters/redshift/connections.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 82e1db271..e00da2528 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -225,7 +225,8 @@ def connect(): or not self.credentials.okta_app_name ): raise dbt.exceptions.FailedToConnectError( - "Failed to use Okta credential. 'okta_idp_host', 'okta_app_id', 'okta_app_name' must be provided." + "Failed to use Okta credential. 'okta_idp_host', 'okta_app_id', and " + "'okta_app_name' must be provided." 
) def connect(): From 0d89af6a49e4982381a6e486bc83e2a0b9869f1d Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Tue, 21 Mar 2023 11:05:01 -0700 Subject: [PATCH 25/27] error handling for region mismatch, make credentials providers case insensitive --- dbt/adapters/redshift/connections.py | 35 +++++++++---- tests/unit/test_redshift_adapter.py | 75 +++++++++++++++++++++++++++- 2 files changed, 97 insertions(+), 13 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index e00da2528..e0a568aed 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -2,7 +2,7 @@ import re from multiprocessing import Lock from contextlib import contextmanager -from typing import NewType, Tuple, Union +from typing import NewType, Tuple import agate import sqlparse @@ -19,7 +19,6 @@ from dbt.helper_types import Port from redshift_connector import OperationalError, DatabaseError, DataError -from redshift_connector.utils.oids import get_datatype_name logger = AdapterLogger("Redshift") @@ -109,7 +108,7 @@ def __init__(self, credentials): def get_connect_method(self): method = self.credentials.method kwargs = { - "host": "", + "host": None, "region": self.credentials.region, "database": self.credentials.database, "port": self.credentials.port if self.credentials.port else 5439, @@ -123,6 +122,11 @@ def get_connect_method(self): kwargs["region"] = self.credentials.host.split(".")[2] if self.credentials.sslmode: kwargs["sslmode"] = self.credentials.sslmode + if self.credentials.host and self.credentials.region: + if self.credentials.host.split(".")[2] != self.credentials.region: + raise dbt.exceptions.FailedToConnectError( + "'region' provided in profiles.yml does not match with region of 'host'." + ) # Support missing 'method' for backwards compatibility if method == RedshiftConnectionMethod.DATABASE or method is None: @@ -188,7 +192,19 @@ def connect(): "Failed to use IdP credentials. 'password' and 'user' must be provided." ) - if self.credentials.credentials_provider == "AzureCredentialsProvider": + if ( + (self.credentials.credentials_provider.lower() != "azurecredentialsprovider") + and (self.credentials.credentials_provider.lower() != "azure") + and (self.credentials.credentials_provider.lower() != "okta") + and (self.credentials.credentials_provider.lower() != "oktacredentialsprovider") + ): + raise dbt.exceptions.FailedToConnectError( + "Unrecognized credentials provider. Enter 'azure' or 'okta' for 'credentials_provider'." + ) + + if (self.credentials.credentials_provider.lower() == "azurecredentialsprovider") or ( + self.credentials.credentials_provider.lower() == "azure" + ): if ( not self.credentials.azure_idp_tenant or not self.credentials.azure_client_id @@ -218,15 +234,16 @@ def connect(): return c return connect - elif self.credentials.credentials_provider == "OktaCredentialsProvider": + elif (self.credentials.credentials_provider.lower() == "oktacredentialsprovider") or ( + self.credentials.credentials_provider.lower() == "okta" + ): if ( not self.credentials.okta_idp_host or not self.credentials.okta_app_id or not self.credentials.okta_app_name ): raise dbt.exceptions.FailedToConnectError( - "Failed to use Okta credential. 'okta_idp_host', 'okta_app_id', and " - "'okta_app_name' must be provided." + "Failed to use Okta credential. 'okta_idp_host', 'okta_app_id', 'okta_app_name' must be provided." 
) def connect(): @@ -410,7 +427,3 @@ def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False): @classmethod def get_credentials(cls, credentials): return credentials - - @classmethod - def data_type_code_to_name(cls, type_code: Union[int, str]) -> str: - return get_datatype_name(type_code) diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 316e78ef0..f364c51aa 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -79,6 +79,19 @@ def test_implicit_database_conn(self): region="us-east-1", ) + @mock.patch("redshift_connector.connect", Mock()) + def test_region_not_match(self): + self.config.credentials = self.config.credentials.replace( + method="iam", + region="someregion", + ) + with self.assertRaises(dbt.exceptions.FailedToConnectError) as context: + connect_method_factory = RedshiftConnectMethodFactory( + self.config.credentials + ) + connect_method_factory.get_connect_method() + self.assertTrue("does not match with region" in context.exception.msg) + @mock.patch("redshift_connector.connect", Mock()) def test_explicit_database_conn(self): self.config.method = "database" @@ -206,7 +219,8 @@ def test_azure_identity_plugin(self): azure_client_id='my_client_id', azure_client_secret='my_client_secret', region='us-east-1', - azure_preferred_role='arn:aws:iam:123:role/MyFirstDinnerRoll' + azure_preferred_role='arn:aws:iam:123:role/MyFirstDinnerRoll', + host=None ) connection = self.adapter.acquire_connection("dummy") connection.handle @@ -271,8 +285,63 @@ def test_okta_identity_plugin(self): okta_idp_host='my_idp_host', okta_app_id='my_first_appetizer', okta_app_name='dinner_party', + region='us-east-1', + host=None + ) + connection = self.adapter.acquire_connection("dummy") + connection.handle + redshift_connector.connect.assert_called_once_with( + iam=True, + database='dev', + cluster_identifier='my-testing-cluster', + credentials_provider='OktaCredentialsProvider', + user='someuser@myazure.org', + password='somepassword', + idp_host='my_idp_host', + app_id='my_first_appetizer', + app_name='dinner_party', region='us-east-1' ) + + @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("boto3.Session", Mock()) + def test_idp_spelling_error(self): + self.config.credentials = self.config.credentials.replace( + method="IdP", + database='dev', + cluster_id='my-testing-cluster', + credentials_provider='oktaar', + user='someuser@myazure.org', + password='somepassword', + okta_idp_host='my_idp_host', + okta_app_id='my_first_appetizer', + okta_app_name='dinner_party', + region='us-east-1', + host=None + ) + with self.assertRaises(FailedToConnectError) as context: + connect_method_factory = RedshiftConnectMethodFactory( + self.config.credentials + ) + connect_method_factory.get_connect_method() + self.assertTrue("Unrecognized" in context.exception.msg) + + @mock.patch("redshift_connector.connect", Mock()) + @mock.patch("boto3.Session", Mock()) + def test_okta_case_sensitivity(self): + self.config.credentials = self.config.credentials.replace( + method="IdP", + database='dev', + cluster_id='my-testing-cluster', + credentials_provider='oKtA', + user='someuser@myazure.org', + password='somepassword', + okta_idp_host='my_idp_host', + okta_app_id='my_first_appetizer', + okta_app_name='dinner_party', + region='us-east-1', + host=None + ) connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( @@ -288,6 +357,7 @@ def 
test_okta_identity_plugin(self): region='us-east-1' ) + @mock.patch("redshift_connector.connect", Mock()) @mock.patch.dict( os.environ, @@ -305,6 +375,7 @@ def test_auth_profile_connect_success(self): database="", user="", region="us-east-1", + host=None ) connection = self.adapter.acquire_connection("dummy") connection.handle @@ -321,7 +392,7 @@ def test_auth_profile_connect_success(self): database="", region="us-east-1", application_name="dbt", - host="", + host=None, timeout=30, ) From 6eb89e4e26c56713823daa2aa253900e4fecc44c Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Mon, 17 Apr 2023 13:59:10 -0700 Subject: [PATCH 26/27] remove unused code --- dbt/adapters/redshift/connections.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/dbt/adapters/redshift/connections.py b/dbt/adapters/redshift/connections.py index 53da40daa..aebcbdac9 100644 --- a/dbt/adapters/redshift/connections.py +++ b/dbt/adapters/redshift/connections.py @@ -2,14 +2,13 @@ import re from multiprocessing import Lock from contextlib import contextmanager -from typing import NewType, Tuple, Union, Optional, List +from typing import NewType, Tuple, Optional, List from dataclasses import dataclass, field import agate import sqlparse import redshift_connector -from redshift_connector.utils.oids import get_datatype_name from dbt.adapters.sql import SQLConnectionManager from dbt.contracts.connection import AdapterResponse, Connection, Credentials @@ -18,8 +17,6 @@ import dbt.flags from dbt.dataclass_schema import FieldEncoder, dbtClassMixin, StrEnum from dbt.helper_types import Port -from redshift_connector import OperationalError, DatabaseError, DataError - logger = AdapterLogger("Redshift") @@ -283,6 +280,7 @@ def connect(): if self.credentials.role: c.cursor().execute("set role {}".format(self.credentials.role)) return c + else: raise dbt.exceptions.FailedToConnectError( "Invalid 'method' in profile: '{}'".format(method) From de7e0e78a109fb57b4eb3e7bb4d58f7449b9ac6e Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Mon, 17 Apr 2023 14:10:40 -0700 Subject: [PATCH 27/27] add pre-commit hook changes --- tests/unit/test_redshift_adapter.py | 186 +++++++++++++--------------- 1 file changed, 85 insertions(+), 101 deletions(-) diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index 2a683340c..d011b0367 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -85,9 +85,7 @@ def test_region_not_match(self): region="someregion", ) with self.assertRaises(dbt.exceptions.FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory( - self.config.credentials - ) + connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) connect_method_factory.get_connect_method() self.assertTrue("does not match with region" in context.exception.msg) @@ -207,31 +205,31 @@ def test_serverless_iam_failure(self): def test_azure_identity_plugin(self): self.config.credentials = self.config.credentials.replace( method="IdP", - database='dev', - cluster_id='my-testing-cluster', - credentials_provider='AzureCredentialsProvider', - user='someuser@myazure.org', - password='somepassword', - azure_idp_tenant='my_idp_tenant', - azure_client_id='my_client_id', - azure_client_secret='my_client_secret', - region='us-east-1', - azure_preferred_role='arn:aws:iam:123:role/MyFirstDinnerRoll', - host=None + database="dev", + cluster_id="my-testing-cluster", + credentials_provider="AzureCredentialsProvider", + 
user="someuser@myazure.org", + password="somepassword", + azure_idp_tenant="my_idp_tenant", + azure_client_id="my_client_id", + azure_client_secret="my_client_secret", + region="us-east-1", + azure_preferred_role="arn:aws:iam:123:role/MyFirstDinnerRoll", + host=None, ) connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( iam=True, - database='dev', - cluster_identifier='my-testing-cluster', - credentials_provider='AzureCredentialsProvider', - user='someuser@myazure.org', - password='somepassword', - idp_tenant='my_idp_tenant', - client_id='my_client_id', - client_secret='my_client_secret', - preferred_role='arn:aws:iam:123:role/MyFirstDinnerRoll', + database="dev", + cluster_identifier="my-testing-cluster", + credentials_provider="AzureCredentialsProvider", + user="someuser@myazure.org", + password="somepassword", + idp_tenant="my_idp_tenant", + client_id="my_client_id", + client_secret="my_client_secret", + preferred_role="arn:aws:iam:123:role/MyFirstDinnerRoll", region="us-east-1", ) @@ -240,20 +238,18 @@ def test_azure_identity_plugin(self): def test_azure_no_region(self): self.config.credentials = self.config.credentials.replace( method="IdP", - database='dev', - cluster_id='my-testing-cluster', - credentials_provider='AzureCredentialsProvider', - user='someuser@myazure.org', - password='somepassword', - azure_idp_tenant='my_idp_tenant', - azure_client_id='my_client_id', - azure_client_secret='my_client_secret', - azure_preferred_role='arn:aws:iam:123:role/MyFirstDinnerRoll' + database="dev", + cluster_id="my-testing-cluster", + credentials_provider="AzureCredentialsProvider", + user="someuser@myazure.org", + password="somepassword", + azure_idp_tenant="my_idp_tenant", + azure_client_id="my_client_id", + azure_client_secret="my_client_secret", + azure_preferred_role="arn:aws:iam:123:role/MyFirstDinnerRoll", ) with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory( - self.config.credentials - ) + connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) connect_method_factory.get_connect_method() self.assertTrue("'region' must be provided" in context.exception.msg) @@ -263,9 +259,7 @@ def test_idp_identity_no_provider(self): method="IdP", ) with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory( - self.config.credentials - ) + connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) connect_method_factory.get_connect_method() self.assertTrue("credentials_provider" in context.exception.msg) @@ -274,30 +268,30 @@ def test_idp_identity_no_provider(self): def test_okta_identity_plugin(self): self.config.credentials = self.config.credentials.replace( method="IdP", - database='dev', - cluster_id='my-testing-cluster', - credentials_provider='OktaCredentialsProvider', - user='someuser@myazure.org', - password='somepassword', - okta_idp_host='my_idp_host', - okta_app_id='my_first_appetizer', - okta_app_name='dinner_party', - region='us-east-1', - host=None + database="dev", + cluster_id="my-testing-cluster", + credentials_provider="OktaCredentialsProvider", + user="someuser@myazure.org", + password="somepassword", + okta_idp_host="my_idp_host", + okta_app_id="my_first_appetizer", + okta_app_name="dinner_party", + region="us-east-1", + host=None, ) connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( 
iam=True, - database='dev', - cluster_identifier='my-testing-cluster', - credentials_provider='OktaCredentialsProvider', - user='someuser@myazure.org', - password='somepassword', - idp_host='my_idp_host', - app_id='my_first_appetizer', - app_name='dinner_party', - region='us-east-1' + database="dev", + cluster_identifier="my-testing-cluster", + credentials_provider="OktaCredentialsProvider", + user="someuser@myazure.org", + password="somepassword", + idp_host="my_idp_host", + app_id="my_first_appetizer", + app_name="dinner_party", + region="us-east-1", ) @mock.patch("redshift_connector.connect", Mock()) @@ -305,21 +299,19 @@ def test_okta_identity_plugin(self): def test_idp_spelling_error(self): self.config.credentials = self.config.credentials.replace( method="IdP", - database='dev', - cluster_id='my-testing-cluster', - credentials_provider='oktaar', - user='someuser@myazure.org', - password='somepassword', - okta_idp_host='my_idp_host', - okta_app_id='my_first_appetizer', - okta_app_name='dinner_party', - region='us-east-1', - host=None + database="dev", + cluster_id="my-testing-cluster", + credentials_provider="oktaar", + user="someuser@myazure.org", + password="somepassword", + okta_idp_host="my_idp_host", + okta_app_id="my_first_appetizer", + okta_app_name="dinner_party", + region="us-east-1", + host=None, ) with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory( - self.config.credentials - ) + connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) connect_method_factory.get_connect_method() self.assertTrue("Unrecognized" in context.exception.msg) @@ -328,33 +320,32 @@ def test_idp_spelling_error(self): def test_okta_case_sensitivity(self): self.config.credentials = self.config.credentials.replace( method="IdP", - database='dev', - cluster_id='my-testing-cluster', - credentials_provider='oKtA', - user='someuser@myazure.org', - password='somepassword', - okta_idp_host='my_idp_host', - okta_app_id='my_first_appetizer', - okta_app_name='dinner_party', - region='us-east-1', - host=None + database="dev", + cluster_id="my-testing-cluster", + credentials_provider="oKtA", + user="someuser@myazure.org", + password="somepassword", + okta_idp_host="my_idp_host", + okta_app_id="my_first_appetizer", + okta_app_name="dinner_party", + region="us-east-1", + host=None, ) connection = self.adapter.acquire_connection("dummy") connection.handle redshift_connector.connect.assert_called_once_with( iam=True, - database='dev', - cluster_identifier='my-testing-cluster', - credentials_provider='OktaCredentialsProvider', - user='someuser@myazure.org', - password='somepassword', - idp_host='my_idp_host', - app_id='my_first_appetizer', - app_name='dinner_party', - region='us-east-1' + database="dev", + cluster_identifier="my-testing-cluster", + credentials_provider="OktaCredentialsProvider", + user="someuser@myazure.org", + password="somepassword", + idp_host="my_idp_host", + app_id="my_first_appetizer", + app_name="dinner_party", + region="us-east-1", ) - @mock.patch("redshift_connector.connect", Mock()) @mock.patch.dict( os.environ, @@ -372,7 +363,7 @@ def test_auth_profile_connect_success(self): database="", user="", region="us-east-1", - host=None + host=None, ) connection = self.adapter.acquire_connection("dummy") connection.handle @@ -400,13 +391,10 @@ def test_auth_profile_no_profile(self): auth_profile="", ) with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory( - 
self.config.credentials - ) + connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) connect_method_factory.get_connect_method() self.assertTrue("'auth_profile' must be provided" in context.exception.msg) - def test_iam_conn_optionals(self): profile_cfg = { "outputs": { @@ -432,18 +420,14 @@ def test_invalid_auth_method(self): # we have to set method this way, otherwise it won't validate self.config.credentials.method = "badmethod" with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory( - self.config.credentials - ) + connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) connect_method_factory.get_connect_method() self.assertTrue("badmethod" in context.exception.msg) def test_invalid_iam_no_cluster_id(self): self.config.credentials = self.config.credentials.replace(method="iam") with self.assertRaises(FailedToConnectError) as context: - connect_method_factory = RedshiftConnectMethodFactory( - self.config.credentials - ) + connect_method_factory = RedshiftConnectMethodFactory(self.config.credentials) connect_method_factory.get_connect_method() self.assertTrue("'cluster_id' must be provided" in context.exception.msg)
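
As a quick illustration of the IdP flow this series lands, the following is a minimal sketch that exercises the new Okta fields directly, mirroring the unit tests above. It assumes RedshiftCredentials and RedshiftConnectMethodFactory are importable from dbt.adapters.redshift.connections as the diffs show; the user, cluster, and Okta values are placeholders, and in normal dbt usage these fields come from profiles.yml rather than being constructed by hand.

    # Illustrative sketch only -- identifiers mirror the diffs above, values are placeholders.
    from dbt.adapters.redshift.connections import (
        RedshiftConnectMethodFactory,
        RedshiftCredentials,
    )

    credentials = RedshiftCredentials(
        database="dev",
        schema="analytics",           # required by the dbt Credentials base class
        host=None,                    # no host for IdP auth; 'region' is used instead
        user="someuser@example.org",
        password="somepassword",
        port=5439,
        method="IdP",
        cluster_id="my-testing-cluster",
        region="us-east-1",
        credentials_provider="okta",  # case-insensitive after PATCH 25/27
        okta_idp_host="my_idp_host",
        okta_app_id="my_app_id",
        okta_app_name="my_app_name",
    )

    # get_connect_method() validates the fields (missing okta_* values or an
    # unrecognized credentials_provider raise FailedToConnectError) and returns
    # a zero-argument callable that wraps redshift_connector.connect(...).
    connect = RedshiftConnectMethodFactory(credentials).get_connect_method()
    handle = connect()  # opens a real connection; the unit tests mock this call out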