From d13edfee8a5555cafc6f3779ed47d5a0e62e47bc Mon Sep 17 00:00:00 2001 From: Benjamin Smith Date: Tue, 3 Dec 2024 21:38:48 +0100 Subject: [PATCH] Postgres Destintion Schema --- .pylintrc | 4 +++- config.yaml | 2 +- src/destinations/postgres.py | 38 +++++++++++++++++++++++++++++++++--- src/interfaces.py | 5 ++++- src/main.py | 28 ++++++++++---------------- 5 files changed, 53 insertions(+), 24 deletions(-) diff --git a/.pylintrc b/.pylintrc index 5040b50..e4564d6 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,2 +1,4 @@ [MASTER] -disable=fixme,too-few-public-methods \ No newline at end of file +disable=fixme,too-few-public-methods,too-many-positional-arguments +# Maximum number of arguments for function / method +max-args=6 diff --git a/config.yaml b/config.yaml index c185aa5..b0d106e 100644 --- a/config.yaml +++ b/config.yaml @@ -46,5 +46,5 @@ jobs: poll_frequency: 5 destination: ref: PG - table_name: app_data + table_name: cow.solvers if_exists: replace diff --git a/src/destinations/postgres.py b/src/destinations/postgres.py index 24e18fe..f149787 100644 --- a/src/destinations/postgres.py +++ b/src/destinations/postgres.py @@ -42,6 +42,7 @@ def __init__( self, db_url: str, table_name: str, + schema: str = "public", if_exists: TableExistsPolicy = "append", index_columns: list[str] | None = None, ): @@ -49,6 +50,15 @@ def __init__( index_columns = [] self.engine: sqlalchemy.engine.Engine = create_engine(db_url) self.table_name: str = table_name + self.schema: str | None = schema + + # Split table_name if it contains schema + if "." in table_name: + self.schema, self.table_name = table_name.split(".", 1) + else: + self.schema = schema + self.table_name = table_name + self.if_exists: TableExistsPolicy = if_exists # List of column forming the ON CONFLICT condition. # Only relevant for "upsert" TableExistsPolicy @@ -58,6 +68,19 @@ def __init__( def validate(self) -> bool: """Validate the destination setup.""" + # Check if schema exists + inspector = inspect(self.engine) + available_schemas = inspector.get_schema_names() + if self.schema not in available_schemas: + log.error( + "Schema '%s' does not exist. Available schemas: %s\n" + "To create this schema, run the following SQL command:\n" + "CREATE SCHEMA %s;", + self.schema, + ", ".join(available_schemas), + self.schema, + ) + return False if self.if_exists == "upsert" and len(self.index_columns) == 0: log.error("Upsert without index columns.") return False @@ -66,7 +89,9 @@ def validate(self) -> bool: def validate_unique_constraints(self) -> None: """Validate table has unique or exclusion constraint for index columns.""" inspector = inspect(self.engine) - constraints = inspector.get_unique_constraints(self.table_name) + constraints = inspector.get_unique_constraints( + self.table_name, schema=self.schema + ) index_columns_set = set(self.index_columns) for constraint in constraints: @@ -96,7 +121,7 @@ def table_exists(self) -> bool: :return: True if the table exists, False otherwise. """ inspector = inspect(self.engine) - tables = inspector.get_table_names() + tables = inspector.get_table_names(schema=self.schema) return self.table_name in tables def save( @@ -146,6 +171,7 @@ def replace( df.to_sql( self.table_name, connection, + schema=self.schema, if_exists="replace", index=False, dtype=dtypes, @@ -161,6 +187,7 @@ def append( df.to_sql( self.table_name, connection, + schema=self.schema, if_exists="append", index=False, dtype=dtypes, @@ -185,7 +212,12 @@ def insert( columns = df.columns.tolist() metadata = MetaData() - table = Table(self.table_name, metadata, autoload_with=self.engine) + table = Table( + self.table_name, + metadata, + autoload_with=self.engine, + schema=self.schema, + ) statement = insert(table).values(**{col: df[col] for col in columns}) if on_conflict == "update": diff --git a/src/interfaces.py b/src/interfaces.py index 92752eb..12c3668 100644 --- a/src/interfaces.py +++ b/src/interfaces.py @@ -22,7 +22,10 @@ class Validate(ABC): def __init__(self) -> None: if not self.validate(): - raise ValueError(f"Config for {self.__class__.__name__} is invalid") + raise ValueError( + f"Config for {self.__class__.__name__} is invalid. " + "See ERROR log for details." + ) @abstractmethod def validate(self) -> bool: diff --git a/src/main.py b/src/main.py index 256a3dd..7b3669e 100644 --- a/src/main.py +++ b/src/main.py @@ -25,23 +25,22 @@ from src.args import Args from src.config import RuntimeConfig +from src.job import Job -async def main() -> None: - """Load configuration and execute jobs asynchronously. - - The function: - 1. Parses command line arguments - 2. Loads the configuration from the specified config file (defaults to config.yaml) - 3. Executes each configured job - 4. Logs the completion of each job +async def main(jobs: list[Job]) -> None: + """Asynchronously execute a list of jobs. Raises: - FileNotFoundError: If config file is not found - yaml.YAMLError: If config file is invalid Various exceptions depending on job configuration and execution """ + tasks = [job.run() for job in jobs] + for completed_task in asyncio.as_completed(tasks): + await completed_task + + +if __name__ == "__main__": args = Args.from_command_line() config = RuntimeConfig.load(args.config) @@ -51,11 +50,4 @@ async def main() -> None: if args.jobs is not None else config.jobs ) - - tasks = [job.run() for job in jobs_to_run] - for completed_task in asyncio.as_completed(tasks): - await completed_task - - -if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(main(jobs_to_run))