Skip to content

Commit

Permalink
review: add setup / teardown for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maximeperrault committed Dec 3, 2024
1 parent e983855 commit 515aee5
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 10 deletions.
2 changes: 0 additions & 2 deletions datascience/src/pipeline/flows/amp_cacem.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,6 @@ def load_new_amps(new_amp: pd.DataFrame):
db_name="monitorenv_remote",
logger=prefect.context.get("logger"),
how="append",
table_id_column="id",
df_id_column="id",
)


Expand Down
6 changes: 2 additions & 4 deletions datascience/src/pipeline/flows/regulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,20 @@ def update_regulations(new_regulations: pd.DataFrame):


@task(checkpoint=False)
def load_new_regulations(new_amp: pd.DataFrame):
def load_new_regulations(new_regulations: pd.DataFrame):
"""Load the output of ``extract_new_regulations`` task into ``regulations_cacem``
table.
Args:
new_amp (pd.DataFrame): output of ``extract_new_regulations`` task.
"""
load(
new_amp,
new_regulations,
table_name="regulations_cacem",
schema="public",
db_name="monitorenv_remote",
logger=prefect.context.get("logger"),
how="append",
table_id_column="id",
df_id_column="id",
)


Expand Down
31 changes: 29 additions & 2 deletions datascience/tests/test_pipeline/test_flows/test_amp_cacem.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pandas as pd
import prefect
import pytest

from src.pipeline.flows.amp_cacem import load_new_amps, update_amps
from src.pipeline.generic_tasks import delete_rows, load
from src.read_query import read_query


Expand Down Expand Up @@ -122,7 +124,15 @@ def new_amp() -> pd.DataFrame:
)


def test_load_new_amps(old_amp):
def test_load_new_amps(reset_test_data, old_amp):
delete_rows(
table_name="amp_cacem",
schema="public",
db_name="monitorenv_remote",
table_id_column="id",
ids_to_delete=set(old_amp.id),
logger=prefect.context.get("logger"),
)
load_new_amps.run(old_amp)
loaded_amps = read_query(
"monitorenv_remote",
Expand All @@ -136,7 +146,24 @@ def test_load_new_amps(old_amp):
pd.testing.assert_frame_equal(loaded_amps, old_amp)


def test_update_new_amps(new_amp):
def test_update_new_amps(reset_test_data, new_amp, old_amp):
delete_rows(
table_name="amp_cacem",
schema="public",
db_name="monitorenv_remote",
table_id_column="id",
ids_to_delete=set(old_amp.id),
logger=prefect.context.get("logger"),
)
load(
old_amp,
table_name="amp_cacem",
schema="public",
db_name="monitorenv_remote",
logger=prefect.context.get("logger"),
how="append",
)

update_amps.run(new_amp)
updated_amps = read_query(
"monitorenv_remote",
Expand Down
28 changes: 26 additions & 2 deletions datascience/tests/test_pipeline/test_flows/test_regulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,15 @@ def old_regulations() -> pd.DataFrame:
)


def test_load_new_regulations(old_regulations):
def test_load_new_regulations(reset_test_data, old_regulations):
delete_rows(
table_name="regulations_cacem",
schema="public",
db_name="monitorenv_remote",
table_id_column="id",
ids_to_delete=set(old_regulations.id),
logger=prefect.context.get("logger"),
)
load_new_regulations.run(old_regulations)
loaded_regulations = read_query(
"monitorenv_remote",
Expand All @@ -110,7 +118,23 @@ def test_load_new_regulations(old_regulations):
pd.testing.assert_frame_equal(loaded_regulations, old_regulations)


def test_update_new_regulations(new_regulations):
def test_update_new_regulations(reset_test_data, new_regulations, old_regulations):
delete_rows(
table_name="regulations_cacem",
schema="public",
db_name="monitorenv_remote",
table_id_column="id",
ids_to_delete=set(old_regulations.id),
logger=prefect.context.get("logger"),
)
load(
old_regulations,
table_name="regulations_cacem",
schema="public",
db_name="monitorenv_remote",
logger=prefect.context.get("logger"),
how="append",
)
update_regulations.run(new_regulations)
updated_regulations = read_query(
"monitorenv_remote",
Expand Down

0 comments on commit 515aee5

Please sign in to comment.