Skip to content

Commit

Permalink
Allow env var templating in config
Browse files Browse the repository at this point in the history
  • Loading branch information
mooster531 committed Nov 27, 2024
1 parent 2f43f79 commit 14c9b1a
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 18 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ The configuration file consists of three main sections:
- `data_sources`: Defines available databases
- `jobs`: Defines synchronization jobs that connect sources to destinations

The config file may contain environment variable placeholders in an
[envsubst](https://www.gnu.org/software/gettext/manual/html_node/envsubst-Invocation.html)-compatible format:
- `$VAR_NAME`
- `${VAR_NAME}`
- `$varname` (lowercase and mixed-case names work too)

**Note**: Every variable referenced this way __must__ be defined at runtime,
otherwise the program will exit with an error.

#### Data Source Definitions

Sources are defined as a list of configurations, each containing:
Expand Down
16 changes: 12 additions & 4 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from pathlib import Path
from string import Template
from typing import Any
from typing import Any, TextIO

import yaml
from dotenv import load_dotenv
Expand Down Expand Up @@ -131,7 +131,15 @@ class RuntimeConfig:
jobs: list[Job]

@classmethod
def load_from_yaml(cls, file_path: Path | str = "config.yaml") -> RuntimeConfig:
def read_yaml(cls, file_handle: TextIO) -> Any:
    """Parse YAML from an open text stream, expanding env placeholders first.

    Args:
        file_handle: Readable text stream holding the raw YAML document.

    Returns:
        Whatever ``yaml.safe_load`` produces for the interpolated text.
    """
    # Called ahead of interpolation; presumably makes the variables
    # available to Env.interpolate -- confirm against Env.
    Env.load()
    raw_text = str(file_handle.read())
    # NOTE(review): placeholders are expanded in the raw text *before* YAML
    # parsing, so a substituted value is typed by the YAML parser and a value
    # containing YAML syntax could alter the parsed structure -- confirm this
    # is acceptable for the values expected here.
    return yaml.safe_load(Env.interpolate(raw_text))

@classmethod
def load(cls, file_path: Path | str = "config.yaml") -> RuntimeConfig:
"""Load and parse a YAML configuration file.
Args:
Expand All @@ -146,8 +154,8 @@ def load_from_yaml(cls, file_path: Path | str = "config.yaml") -> RuntimeConfig:
ValueError: If the configuration contains invalid database types
"""
with open(file_path, "rb") as _handle:
data = yaml.safe_load(_handle)
with open(file_path, encoding="utf-8") as _handle:
data = cls.read_yaml(_handle)

# Load data sources map
sources = {}
Expand Down
2 changes: 1 addition & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def main() -> None:
)
args = parser.parse_args()

config = RuntimeConfig.load_from_yaml(args.config)
config = RuntimeConfig.load(args.config)

tasks = [job.run() for job in config.jobs]
for job, completed_task in zip(
Expand Down
6 changes: 3 additions & 3 deletions tests/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,21 +188,21 @@ async def test_dune_to_local_job_run(self, mock_env, mock_dune_client):

# everything is okay
mock_dune_client.return_value = good_client
conf = RuntimeConfig.load_from_yaml(config_root / "dune_to_postgres.yaml")
conf = RuntimeConfig.load(config_root / "dune_to_postgres.yaml")
await conf.jobs[0].run()

mock_dune_client.reset_mock()

# Dune returned a None result
mock_dune_client.return_value = bad_client_returned_none
conf = RuntimeConfig.load_from_yaml(config_root / "dune_to_postgres.yaml")
conf = RuntimeConfig.load(config_root / "dune_to_postgres.yaml")
with self.assertRaises(ValueError):
await conf.jobs[0].run()

# Dune returned an empty result
mock_dune_client.reset_mock()
mock_dune_client.return_value = empty_result_client
conf = RuntimeConfig.load_from_yaml(config_root / "dune_to_postgres.yaml")
conf = RuntimeConfig.load(config_root / "dune_to_postgres.yaml")
with self.assertLogs(level=WARNING) as logs:
await conf.jobs[0].run()

Expand Down
18 changes: 18 additions & 0 deletions tests/fixtures/config/basic_with_env_missing_vars.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
---
data_sources:
- name: dune
type: dune
key: ${DUNE_API_KEY}
- name: postgres
type: postgres
key: ${DB_URL}

jobs:
- name: Some job
source:
ref: postgres
table_name: foo_table
query_string: SELECT 1;
destination:
ref: dune
table_name: $UNDEFINED_VAR
39 changes: 39 additions & 0 deletions tests/fixtures/config/basic_with_env_placeholders.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
---
data_sources:
- name: dune
type: dune
key: ${DUNE_API_KEY}
- name: postgres
type: postgres
key: ${DB_URL}

jobs:
- name: Download simple test query to local postgres
source:
ref: dune
query_id: $Query_ID
query_engine: medium
poll_frequency: ${POLL_FREQUENCY_DUNE_PG}
parameters:
- name: blockchain
value: $BLOCKCHAIN_NAME
type: ENUM
- name: blocktime
value: 2024-09-01 00:00:00
type: DATE
- name: result_limit
value: 10
type: NUMBER
destination:
ref: postgres
table_name: parameterized_results_4238114
if_exists: $WHAT_IF_EXISTS

- name: Some other job
source:
ref: postgres
table_name: foo_table
query_string: SELECT 1;
destination:
ref: dune
table_name: $table_name
43 changes: 35 additions & 8 deletions tests/unit/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def setUpClass(cls):
{
"DUNE_API_KEY": "test_key",
"DB_URL": "postgresql://postgres:postgres@localhost:5432/postgres",
"Query_ID": "123321",
"POLL_FREQUENCY_DUNE_PG": "192",
"BLOCKCHAIN_NAME": "moosis",
"WHAT_IF_EXISTS": "replace",
"table_name": "my_pg_table",
},
clear=True,
)
Expand All @@ -59,32 +64,54 @@ def tearDownClass(cls):

def test_load_basic_conf(self):
config_file = config_root / "basic.yaml"
conf = RuntimeConfig.load_from_yaml(config_file.absolute())
conf = RuntimeConfig.load(config_file.absolute())
self.assertEqual(2, len(conf.jobs))
# TODO: come up with more explicit assertions.
dune_to_pg_job = conf.jobs[0]
pg_to_dune_job = conf.jobs[1]
self.assertEqual("test_key", dune_to_pg_job.source.client.token)
self.assertEqual(
"postgresql://postgres:***@localhost:5432/postgres",
str(pg_to_dune_job.source.engine.url),
)

def test_load_templated_conf(self):
    """Check placeholder substitution on load, and KeyError when a variable is undefined."""
    templated_path = config_root / "basic_with_env_placeholders.yaml"
    conf = RuntimeConfig.load(templated_path.absolute())
    self.assertEqual(2, len(conf.jobs))
    # Safe to unpack: the assertion above guarantees exactly two jobs.
    dune_to_pg_job, pg_to_dune_job = conf.jobs

    dune_source = dune_to_pg_job.source
    self.assertEqual(123321, dune_source.query.query_id)
    self.assertEqual(192, dune_source.poll_frequency)
    self.assertEqual("moosis", dune_source.query.params[0].value)
    self.assertEqual("replace", dune_to_pg_job.destination.if_exists)
    self.assertEqual("my_pg_table", pg_to_dune_job.destination.table_name)

    # This fixture references $UNDEFINED_VAR, so loading must fail.
    missing_vars_path = config_root / "basic_with_env_missing_vars.yaml"
    with self.assertRaises(KeyError):
        RuntimeConfig.load(missing_vars_path.absolute())

def test_load_unsupported_conf(self):
with self.assertRaises(ValueError) as context:
RuntimeConfig.load_from_yaml(config_root / "unsupported_source.yaml")
RuntimeConfig.load(config_root / "unsupported_source.yaml")
self.assertIn("Unsupported source_db type", str(context.exception))

with self.assertRaises(ValueError) as context:
RuntimeConfig.load_from_yaml(config_root / "unsupported_dest.yaml")
RuntimeConfig.load(config_root / "unsupported_dest.yaml")
self.assertIn("Unsupported destination_db type", str(context.exception))

def test_load_buggy_conf(self):
with self.assertRaises(KeyError) as context:
RuntimeConfig.load_from_yaml(config_root / "buggy.yaml")
RuntimeConfig.load(config_root / "buggy.yaml")
self.assertIn("'table_name'", str(context.exception))

with self.assertRaises(SystemExit):
RuntimeConfig.load_from_yaml(config_root / "unknown_src.yaml")
RuntimeConfig.load(config_root / "unknown_src.yaml")

with self.assertRaises(SystemExit):
RuntimeConfig.load_from_yaml(config_root / "unknown_dest.yaml")
RuntimeConfig.load(config_root / "unknown_dest.yaml")

with self.assertRaises(SystemExit):
RuntimeConfig.load_from_yaml(config_root / "no_data_sources.yaml")
RuntimeConfig.load(config_root / "no_data_sources.yaml")


class TestParseQueryParameters(unittest.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/sources_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@ def setUpClass(cls):
def test_load_sql_file(self):
os.chdir(fixtures_root)

RuntimeConfig.load_from_yaml(config_root / "sql_file.yaml")
RuntimeConfig.load(config_root / "sql_file.yaml")

# ensure the missing file really is missing
missing_file = fixtures_root / "missing-file.sql"
missing_file.unlink(missing_ok=True)
with self.assertRaises(RuntimeError):
RuntimeConfig.load_from_yaml(config_root / "invalid_sql_file.yaml")
RuntimeConfig.load(config_root / "invalid_sql_file.yaml")

def test_invalid_query_string(self):
with self.assertRaises(ValueError) as context:
Expand Down

0 comments on commit 14c9b1a

Please sign in to comment.