tech: update amp flow to use update instead of delete all
maximeperrault committed Dec 2, 2024
1 parent d8ac43d commit cef4c59
Showing 6 changed files with 334 additions and 231 deletions.
149 changes: 89 additions & 60 deletions datascience/src/pipeline/flows/amp_cacem.py
@@ -3,8 +3,13 @@
import pandas as pd
import prefect
from prefect import Flow, case, task
from sqlalchemy import text

from src.db_config import create_engine
from src.pipeline.generic_tasks import delete_rows, extract, load
from src.pipeline.utils import psql_insert_copy
from src.pipeline.shared_tasks.update_queries import (
    delete_required,
    insert_required,
    merge_hashes,
    select_ids_to_delete,
    select_ids_to_insert,
    select_ids_to_update,
    update_required,
)



@task(checkpoint=False)
@@ -34,56 +39,6 @@ def extract_remote_hashes() -> pd.DataFrame:
)


@task(checkpoint=False)
def merge_hashes(
local_hashes: pd.DataFrame, remote_hashes: pd.DataFrame
) -> pd.DataFrame:
return pd.merge(local_hashes, remote_hashes, on="id", how="outer")


@task(checkpoint=False)
def select_ids_to_upsert(hashes: pd.DataFrame) -> set:
ids_to_upsert = set(
hashes.loc[
(hashes.cacem_row_hash.notnull())
& (hashes.cacem_row_hash != hashes.monitorenv_row_hash),
"id",
]
)

return ids_to_upsert


@task(checkpoint=False)
def select_ids_to_delete(hashes: pd.DataFrame) -> set:
return set(hashes.loc[hashes.cacem_row_hash.isna(), "id"])


@task(checkpoint=False)
def upsert_required(ids_to_upsert: set) -> bool:
logger = prefect.context.get("logger")
n = len(ids_to_upsert)
if n > 0:
logger.info(f"Found {n} row(s) to add or upsert.")
res = True
else:
logger.info("No row to add or upsert was found.")
res = False
return res


@task(checkpoint=False)
def delete_required(ids_to_delete: set) -> bool:
logger = prefect.context.get("logger")
n = len(ids_to_delete)
if n > 0:
logger.info(f"Found {n} row(s) to delete.")
res = True
else:
logger.info("No row to delete was found.")
res = False
return res


@task(checkpoint=False)
def delete(ids_to_delete: set):
@@ -107,8 +62,75 @@ def extract_new_amp(ids_to_upsert: set) -> pd.DataFrame:
)



@task(checkpoint=False)
def load_new_amp(new_amp: pd.DataFrame):
def update_amps(new_amps: pd.DataFrame):
"""Load the output of ``extract_rows_to_update`` task into ``amp``
table.
Args:
new_regulations (pd.DataFrame): output of ``extract_rows_to_update`` task.
"""
e = create_engine("monitorenv_remote")
logger = prefect.context.get("logger")

with e.begin() as connection:
logger.info("Creating temporary table")
connection.execute(
text(
"""CREATE TEMP TABLE tmp_amp_cacem(
id serial,
geom public.geometry(MultiPolygon,4326),
mpa_oriname text,
des_desigfr text,
row_hash text,
mpa_type text,
ref_reg text,
url_legicem text)
ON COMMIT DROP;"""
)
)

columns_to_load = [
"id",
"geom",
"mpa_oriname",
"des_desigfr",
"row_hash",
"mpa_type",
"ref_reg",
"url_legicem"
]

logger.info("Loading to temporary table")

new_amps[columns_to_load].to_sql(
"tmp_amp_cacem",
connection,
if_exists="append",
index=False,
method=psql_insert_copy,
)

logger.info(f"Updating amp_cacem from temporary table {len(new_amps)}")
connection.execute(
text(
"""UPDATE amp_cacem amp
SET geom = tmp.geom,
mpa_oriname = tmp.mpa_oriname,
des_desigfr = tmp.des_desigfr,
row_hash = tmp.row_hash,
mpa_type = tmp.mpa_type,
ref_reg = tmp.ref_reg,
url_legicem = tmp.url_legicem
FROM tmp_amp_cacem tmp
WHERE amp.id = tmp.id;
"""
)
)

@task(checkpoint=False)
def load_new_amps(new_amp: pd.DataFrame):
"""Load the output of ``extract_rows_to_upsert`` task into ``amp``
table.
@@ -121,27 +143,34 @@ def load_new_amp(new_amp: pd.DataFrame):
schema="public",
db_name="monitorenv_remote",
logger=prefect.context.get("logger"),
how="upsert",
how="append",
table_id_column="id",
df_id_column="id",
)


with Flow("import amp cacem") as flow:

local_hashes = extract_local_hashes()
remote_hashes = extract_remote_hashes()
hashes = merge_hashes(local_hashes, remote_hashes)
outer_hashes = merge_hashes(local_hashes, remote_hashes)
inner_hashes = merge_hashes(local_hashes, remote_hashes, how="inner")

ids_to_delete = select_ids_to_delete(hashes)
ids_to_delete = select_ids_to_delete(outer_hashes)
cond_delete = delete_required(ids_to_delete)
with case(cond_delete, True):
delete(ids_to_delete)

ids_to_upsert = select_ids_to_upsert(hashes)
cond_upsert = upsert_required(ids_to_upsert)
with case(cond_upsert, True):
new_amp = extract_new_amp(ids_to_upsert)
load_new_amp(new_amp)
ids_to_update = select_ids_to_update(inner_hashes)
cond_update = update_required(ids_to_update)
with case(cond_update, True):
new_amps = extract_new_amp(ids_to_update)
update_amps(new_amps)

ids_to_insert = select_ids_to_insert(outer_hashes)
cond_insert = insert_required(ids_to_insert)
with case(cond_insert, True):
new_amps = extract_new_amp(ids_to_insert)
load_new_amps(new_amps)


flow.file_name = Path(__file__).name
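
Per the commit message, the change above replaces the delete-all-and-reload step with an in-place update: changed rows are copied into a transaction-scoped temporary table, then applied to ``amp_cacem`` with a single set-based UPDATE. A minimal standalone sketch of that pattern, assuming a reachable Postgres database (the DSN and column set are hypothetical, and pandas' default insert method stands in for ``psql_insert_copy``):

import pandas as pd
from sqlalchemy import create_engine, text

engine = create_engine("postgresql:///monitorenv")  # hypothetical DSN

changed_rows = pd.DataFrame({"id": [1, 2], "row_hash": ["a1", "b2"]})

with engine.begin() as connection:
    # ON COMMIT DROP scopes the temporary table to this transaction.
    connection.execute(
        text("CREATE TEMP TABLE tmp_amp(id int, row_hash text) ON COMMIT DROP;")
    )
    # Stage the changed rows in the temporary table.
    changed_rows.to_sql("tmp_amp", connection, if_exists="append", index=False)
    # One set-based UPDATE applies all changed rows at once.
    connection.execute(
        text(
            """UPDATE amp_cacem amp
            SET row_hash = tmp.row_hash
            FROM tmp_amp tmp
            WHERE amp.id = tmp.id;"""
        )
    )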
90 changes: 33 additions & 57 deletions datascience/src/pipeline/flows/regulations.py
@@ -4,8 +4,9 @@
from prefect import Flow, case, task
from sqlalchemy import text
from src.db_config import create_engine
from src.pipeline.generic_tasks import delete_rows, extract
from src.pipeline.generic_tasks import delete_rows, extract, load
from src.pipeline.processing import prepare_df_for_loading
from src.pipeline.shared_tasks.update_queries import (
    delete_required,
    insert_required,
    merge_hashes,
    select_ids_to_delete,
    select_ids_to_insert,
    select_ids_to_update,
    update_required,
)
from src.pipeline.utils import psql_insert_copy
from src.read_query import read_query

@@ -36,58 +37,6 @@ def extract_remote_hashes() -> pd.DataFrame:
query_filepath="monitorenv/regulations_hashes.sql",
)


@task(checkpoint=False)
def merge_hashes(
local_hashes: pd.DataFrame, remote_hashes: pd.DataFrame
) -> pd.DataFrame:
return pd.merge(local_hashes, remote_hashes, on="id", how="outer")


@task(checkpoint=False)
def select_ids_to_update(hashes: pd.DataFrame) -> set:
ids_to_update = set(
hashes.loc[
(hashes.cacem_row_hash.notnull())
& (hashes.cacem_row_hash != hashes.monitorenv_row_hash),
"id",
]
)

return ids_to_update


@task(checkpoint=False)
def select_ids_to_delete(hashes: pd.DataFrame) -> set:
return set(hashes.loc[hashes.cacem_row_hash.isna(), "id"])


@task(checkpoint=False)
def update_required(ids_to_update: set) -> bool:
logger = prefect.context.get("logger")
n = len(ids_to_update)
if n > 0:
logger.info(f"Found {n} row(s) to add or update.")
res = True
else:
logger.info("No row to add or update was found.")
res = False
return res


@task(checkpoint=False)
def delete_required(ids_to_delete: set) -> bool:
logger = prefect.context.get("logger")
n = len(ids_to_delete)
if n > 0:
logger.info(f"Found {n} row(s) to delete.")
res = True
else:
logger.info("No row to delete was found.")
res = False
return res


@task(checkpoint=False)
def delete(ids_to_delete: set):
logger = prefect.context.get("logger")
@@ -110,7 +59,7 @@ def extract_new_regulations(ids_to_update: set) -> pd.DataFrame:
)

@task(checkpoint=False)
def load_new_regulations(new_regulations: pd.DataFrame):
def update_regulations(new_regulations: pd.DataFrame):
"""Load the output of ``extract_rows_to_update`` task into ``regulations``
table.
@@ -204,21 +153,48 @@ def load_new_regulations(new_regulations: pd.DataFrame):
)


@task(checkpoint=False)
def load_new_regulations(new_regulations: pd.DataFrame):
"""Load the output of ``extract_new_regulations`` task into ``regulations_cacem``
table.
Args:
new_regulations (pd.DataFrame): output of ``extract_new_regulations`` task.
"""
load(
new_regulations,
table_name="regulations_cacem",
schema="public",
db_name="monitorenv_remote",
logger=prefect.context.get("logger"),
how="append",
table_id_column="id",
df_id_column="id",
)


with Flow("Regulations") as flow:
local_hashes = extract_local_hashes()
remote_hashes = extract_remote_hashes()
hashes = merge_hashes(local_hashes, remote_hashes)
outer_hashes = merge_hashes(local_hashes, remote_hashes)
inner_hashes = merge_hashes(local_hashes, remote_hashes, how="inner")

ids_to_delete = select_ids_to_delete(hashes)
ids_to_delete = select_ids_to_delete(outer_hashes)
cond_delete = delete_required(ids_to_delete)
with case(cond_delete, True):
delete(ids_to_delete)

ids_to_update = select_ids_to_update(hashes)
ids_to_update = select_ids_to_update(inner_hashes)
cond_update = update_required(ids_to_update)
with case(cond_update, True):
new_regulations = extract_new_regulations(ids_to_update)
update_regulations(new_regulations)

ids_to_insert = select_ids_to_insert(outer_hashes)
cond_insert = insert_required(ids_to_insert)
with case(cond_insert, True):
new_regulations = extract_new_regulations(ids_to_insert)
load_new_regulations(new_regulations)


flow.file_name = Path(__file__).name
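
Both flows gate their branches the same way: a boolean task (``update_required``, ``insert_required``, ``delete_required``) feeds Prefect 1.x's ``case`` context manager, so a branch only runs when there is work to do. A tiny standalone sketch of that conditional wiring, with illustrative task names only:

import prefect
from prefect import Flow, case, task


@task(checkpoint=False)
def work_required(ids: set) -> bool:
    # Plays the role of update_required / insert_required / delete_required.
    return len(ids) > 0


@task(checkpoint=False)
def process(ids: set):
    prefect.context.get("logger").info(f"Processing {len(ids)} id(s)")


with Flow("demo") as demo_flow:
    ids = {1, 2, 3}
    with case(work_required(ids), True):  # branch is skipped entirely when False
        process(ids)


demo_flow.run()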
67 changes: 67 additions & 0 deletions datascience/src/pipeline/shared_tasks/update_queries.py
@@ -0,0 +1,67 @@

import pandas as pd
import prefect
from prefect import task


@task(checkpoint=False)
def update_required(ids_to_update: set) -> bool:
logger = prefect.context.get("logger")
n = len(ids_to_update)
if n > 0:
logger.info(f"Found {n} row(s) to update.")
res = True
else:
logger.info("No row update was found.")
res = False
return res

@task(checkpoint=False)
def merge_hashes(
local_hashes: pd.DataFrame, remote_hashes: pd.DataFrame, how: str = "outer"
) -> pd.DataFrame:
return pd.merge(local_hashes, remote_hashes, on="id", how=how)

@task(checkpoint=False)
def select_ids_to_update(hashes: pd.DataFrame) -> set:
ids_to_update = set(
hashes.loc[
(hashes.cacem_row_hash.notnull())
& (hashes.cacem_row_hash != hashes.monitorenv_row_hash),
"id",
]
)

return ids_to_update

@task(checkpoint=False)
def select_ids_to_delete(hashes: pd.DataFrame) -> set:
return set(hashes.loc[hashes.cacem_row_hash.isna(), "id"])

@task(checkpoint=False)
def select_ids_to_insert(hashes: pd.DataFrame) -> set:
return set(hashes.loc[hashes.monitorenv_row_hash.isna(), "id"])

@task(checkpoint=False)
def insert_required(ids_to_insert: set) -> bool:
logger = prefect.context.get("logger")
n = len(ids_to_insert)
if n > 0:
logger.info(f"Found {n} row(s) to add.")
res = True
else:
logger.info("No row to add was found.")
res = False
return res

@task(checkpoint=False)
def delete_required(ids_to_delete: set) -> bool:
logger = prefect.context.get("logger")
n = len(ids_to_delete)
if n > 0:
logger.info(f"Found {n} row(s) to delete.")
res = True
else:
logger.info("No row to delete was found.")
res = False
return res
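
Stripped of the ``@task`` decorators, the selection tasks above implement a plain hash diff. A toy illustration of which ids land in each set (on the inner merge both hash columns are always present, which is why the ``notnull`` guard in ``select_ids_to_update`` mainly matters if it is ever fed an outer merge):

import pandas as pd

local = pd.DataFrame({"id": [1, 2, 3], "cacem_row_hash": ["a", "b", "c"]})
remote = pd.DataFrame({"id": [2, 3, 4], "monitorenv_row_hash": ["b", "x", "d"]})

outer = pd.merge(local, remote, on="id", how="outer")
inner = pd.merge(local, remote, on="id", how="inner")

ids_to_insert = set(outer.loc[outer.monitorenv_row_hash.isna(), "id"])  # {1}: local only
ids_to_delete = set(outer.loc[outer.cacem_row_hash.isna(), "id"])       # {4}: remote only
ids_to_update = set(
    inner.loc[inner.cacem_row_hash != inner.monitorenv_row_hash, "id"]  # {3}: hash changed
)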