Introduce PGDestConfig for the PostgreSQL destination and an opt-out switch
for automatic table alteration.

- src/args.py: new --allow_alteration option, parsed as a real boolean.
- src/destinations/postgres.py: PGDestConfig bundles allow_alter, if_exists and
  index_columns; a missing unique constraint is added automatically (and
  committed) when allow_alter is set.
- tests: construct PostgresDestination through PGDestConfig.

diff --git a/src/args.py b/src/args.py
index dc5f989..55bb760 100644
--- a/src/args.py
+++ b/src/args.py
@@ -15,6 +15,7 @@ class Args:
 
     config: Path
     jobs: list[str] | None
+    allow_alteration: bool
 
     @classmethod
     def from_command_line(cls) -> Args:
@@ -35,8 +36,16 @@ def from_command_line(cls) -> Args:
             default=None,
             help="Names of specific jobs to run (default: run all jobs)",
         )
+        parser.add_argument(
+            "--allow_alteration",
+            # type=bool would treat every non-empty string (even "False") as True.
+            type=lambda raw: raw.strip().lower() in {"1", "true", "yes", "on"},
+            default=True,
+            help="Allow table alteration based on failed validation (default: True)",
+        )
         args = parser.parse_args()
         return cls(
             config=args.config,
             jobs=args.jobs if args.jobs else None,  # Convert empty list to None
+            allow_alteration=args.allow_alteration,
         )
diff --git a/src/config.py b/src/config.py
index ddb87c3..777fcf7 100644
--- a/src/config.py
+++ b/src/config.py
@@ -13,7 +13,7 @@
 from dune_client.query import QueryBase
 
 from src.destinations.dune import DuneDestination
-from src.destinations.postgres import PostgresDestination
+from src.destinations.postgres import PGDestConfig, PostgresDestination
 from src.interfaces import Destination, Source
 from src.job import Database, Job
 from src.sources.dune import DuneSource, parse_query_parameters
@@ -269,7 +269,6 @@ def _build_destination(
                 return PostgresDestination(
                     db_url=dest.key,
                     table_name=dest_config["table_name"],
-                    if_exists=dest_config.get("if_exists", "append"),
-                    index_columns=dest_config.get("index_columns", []),
+                    config=PGDestConfig.from_dict(dest_config),
                 )
         raise ValueError(f"Unsupported destination_db type: {dest}")
diff --git a/src/destinations/postgres.py b/src/destinations/postgres.py
index 24e18fe..41fae4c 100644
--- a/src/destinations/postgres.py
+++ b/src/destinations/postgres.py
@@ -1,9 +1,12 @@
 """Destination logic for PostgreSQL."""
 
-from typing import Literal
+from __future__ import annotations
+
+import dataclasses
+from typing import Any, Literal
 
 import sqlalchemy
-from sqlalchemy import MetaData, Table, create_engine, inspect
+from sqlalchemy import DDL, MetaData, Table, create_engine, inspect
 from sqlalchemy.dialects.postgresql import insert
 
 from src.interfaces import Destination, TypedDataFrame
@@ -12,6 +15,34 @@
 TableExistsPolicy = Literal["append", "replace", "upsert", "insert_ignore"]
 
 
+@dataclasses.dataclass
+class PGDestConfig:
+    """Configuration Parameters for PostgreSQL as Destination."""
+
+    allow_alter: bool = True
+    # TODO(bh2smith): allow_drop?
+    if_exists: TableExistsPolicy = "append"
+    index_columns: list[str] | None = None
+
+    @classmethod
+    def from_dict(cls, config: dict[str, Any]) -> PGDestConfig:
+        """Construct PGDestConfig from a dictionary."""
+        return cls(
+            allow_alter=config.get("allow_alter", True),
+            if_exists=config.get("if_exists", "append"),
+            index_columns=config.get("index_columns", []),
+        )
+
+    @classmethod
+    def default(cls) -> PGDestConfig:
+        """Construct PGDestConfig populated with default values."""
+        return cls(
+            allow_alter=True,
+            if_exists="append",
+            index_columns=[],
+        )
+
+
 class PostgresDestination(Destination[TypedDataFrame]):
     """A class representing PostgreSQL as a destination for data storage.
 
@@ -42,17 +73,19 @@ def __init__(
         self,
         db_url: str,
         table_name: str,
-        if_exists: TableExistsPolicy = "append",
-        index_columns: list[str] | None = None,
+        config: PGDestConfig | None = None,
     ):
-        if index_columns is None:
-            index_columns = []
+        if config is None:
+            config = PGDestConfig.default()
+        if config.index_columns is None:
+            config.index_columns = []
         self.engine: sqlalchemy.engine.Engine = create_engine(db_url)
+        self.allow_alter: bool = config.allow_alter
         self.table_name: str = table_name
-        self.if_exists: TableExistsPolicy = if_exists
+        self.if_exists: TableExistsPolicy = config.if_exists
         # List of column forming the ON CONFLICT condition.
         # Only relevant for "upsert" TableExistsPolicy
-        self.index_columns: list[str] = index_columns
+        self.index_columns: list[str] = config.index_columns
 
         super().__init__()
 
@@ -80,6 +113,23 @@ def validate_unique_constraints(self) -> None:
                 f"ALTER TABLE {table} ADD CONSTRAINT "
                 f"{constraint_name} UNIQUE ({index_columns_str});"
             )
+            if self.allow_alter:
+                log.info(
+                    "No uniqueness constraint for table %s with column(s) %s. Executing %s",
+                    table,
+                    self.index_columns,
+                    suggestion,
+                )
+                ddl_statement = DDL(suggestion)
+                # begin() commits on success; a bare connect() block would leave
+                # the ALTER TABLE uncommitted under SQLAlchemy 2.x.
+                with self.engine.begin() as conn:
+                    conn.execute(ddl_statement)
+                log.info(
+                    "Successfully executed: %s", ddl_statement.compile(self.engine)
+                )
+                return
+
             message = (
                 "The ON CONFLICT clause requires a unique or exclusion constraint "
                 f"on the column(s): {index_columns_str}. "
diff --git a/tests/e2e_test.py b/tests/e2e_test.py
index 2821b5e..c0a458b 100644
--- a/tests/e2e_test.py
+++ b/tests/e2e_test.py
@@ -14,7 +14,7 @@
 from sqlalchemy.dialects.postgresql import BYTEA, NUMERIC
 
 from src.config import RuntimeConfig
-from src.destinations.postgres import PostgresDestination
+from src.destinations.postgres import PGDestConfig, PostgresDestination
 from src.sources.dune import dune_result_to_df
 from tests import config_root, fixtures_root
 from tests.db_util import query_pg
@@ -115,7 +115,11 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
 
     def test_dune_results_to_db(self):
         pg = PostgresDestination(
-            db_url=DB_URL, table_name="test_table", if_exists="replace"
+            db_url=DB_URL,
+            table_name="test_table",
+            config=PGDestConfig(
+                if_exists="replace",
+            ),
         )
 
         df, types = dune_result_to_df(SAMPLE_DUNE_RESULTS.result)
diff --git a/tests/unit/destinations_test.py b/tests/unit/destinations_test.py
index e89ae3d..ac8a318 100644
--- a/tests/unit/destinations_test.py
+++ b/tests/unit/destinations_test.py
@@ -8,7 +8,7 @@
 from dune_client.models import DuneError
 
 from src.destinations.dune import DuneDestination
-from src.destinations.postgres import PostgresDestination
+from src.destinations.postgres import PGDestConfig, PostgresDestination
 from tests.db_util import create_table, drop_table, raw_exec, select_star
 
 
@@ -125,8 +125,10 @@ def test_failed_validation(self):
             PostgresDestination(
                 db_url=self.db_url,
                 table_name="foo",
-                if_exists="upsert",
-                index_columns=[],
+                config=PGDestConfig(
+                    if_exists="upsert",
+                    index_columns=[],
+                ),
             )
 
         self.assertIn(
@@ -155,8 +157,10 @@ def test_validate_unique_constraints(self):
         pg_dest = PostgresDestination(
             db_url=self.db_url,
             table_name=table_name,
-            if_exists="upsert",
-            index_columns=["id"],
+            config=PGDestConfig(
+                if_exists="upsert",
+                index_columns=["id"],
+            ),
         )
         drop_table(pg_dest.engine, table_name)
         # No such table.
@@ -214,8 +218,10 @@ def test_upsert(self):
         pg_dest = PostgresDestination(
             db_url=self.db_url,
             table_name=table_name,
-            if_exists="upsert",
-            index_columns=["id"],
+            config=PGDestConfig(
+                if_exists="upsert",
+                index_columns=["id"],
+            ),
         )
         df1 = pd.DataFrame({"id": [1], "value": ["alice"]})
         df2 = pd.DataFrame({"id": [2], "value": ["bob"]})
@@ -270,8 +276,10 @@ def test_insert_ignore(self):
         pg_dest = PostgresDestination(
             db_url=self.db_url,
             table_name=table_name,
-            if_exists="insert_ignore",
-            index_columns=["id"],
+            config=PGDestConfig(
+                if_exists="insert_ignore",
+                index_columns=["id"],
+            ),
         )
         df1 = pd.DataFrame({"id": [1], "value": ["alice"]})
         df2 = pd.DataFrame({"id": [2], "value": ["bob"]})
@@ -323,7 +331,9 @@ def test_replace(self):
         pg_dest = PostgresDestination(
             db_url=self.db_url,
             table_name=table_name,
-            if_exists="replace",
+            config=PGDestConfig(
+                if_exists="replace",
+            ),
         )
         df1 = pd.DataFrame({"id": [1, 2], "value": ["alice", "bob"]})
 