From 07f8686bf6680eb12f518ae02cc63687d0b03cef Mon Sep 17 00:00:00 2001
From: Maxime Perrault
Date: Mon, 2 Dec 2024 16:22:44 +0100
Subject: [PATCH] tech: update amp flow to use update instead of delete all

---
 datascience/src/pipeline/flows/amp_cacem.py   | 149 +++++++++++-------
 datascience/src/pipeline/flows/regulations.py |  90 ++++-------
 .../pipeline/shared_tasks/update_queries.py   |  67 ++++++++
 .../test_flows/test_amp_cacem.py              | 130 ++++++++-------
 .../test_flows/test_regulations.py            |  87 +++-------
 .../test_shared_tasks/test_update_queries.py  |  50 ++++++
 6 files changed, 338 insertions(+), 235 deletions(-)
 create mode 100644 datascience/src/pipeline/shared_tasks/update_queries.py
 create mode 100644 datascience/tests/test_pipeline/test_shared_tasks/test_update_queries.py

diff --git a/datascience/src/pipeline/flows/amp_cacem.py b/datascience/src/pipeline/flows/amp_cacem.py
index 22fc5770f0..e26f1daffb 100644
--- a/datascience/src/pipeline/flows/amp_cacem.py
+++ b/datascience/src/pipeline/flows/amp_cacem.py
@@ -3,8 +3,13 @@
 import pandas as pd
 import prefect
 from prefect import Flow, case, task
+from sqlalchemy import text
 
+from src.db_config import create_engine
 from src.pipeline.generic_tasks import delete_rows, extract, load
+from src.pipeline.utils import psql_insert_copy
+from src.pipeline.shared_tasks.update_queries import delete_required, insert_required, merge_hashes, select_ids_to_delete, select_ids_to_insert, select_ids_to_update, update_required
+
 
 
 @task(checkpoint=False)
@@ -34,56 +39,6 @@ def extract_remote_hashes() -> pd.DataFrame:
     )
 
 
-@task(checkpoint=False)
-def merge_hashes(
-    local_hashes: pd.DataFrame, remote_hashes: pd.DataFrame
-) -> pd.DataFrame:
-    return pd.merge(local_hashes, remote_hashes, on="id", how="outer")
-
-
-@task(checkpoint=False)
-def select_ids_to_upsert(hashes: pd.DataFrame) -> set:
-    ids_to_upsert = set(
-        hashes.loc[
-            (hashes.cacem_row_hash.notnull())
-            & (hashes.cacem_row_hash != hashes.monitorenv_row_hash),
-            "id",
-        ]
-    )
-
-    return ids_to_upsert
-
-
-@task(checkpoint=False)
-def select_ids_to_delete(hashes: pd.DataFrame) -> set:
-    return set(hashes.loc[hashes.cacem_row_hash.isna(), "id"])
-
-
-@task(checkpoint=False)
-def upsert_required(ids_to_upsert: set) -> bool:
-    logger = prefect.context.get("logger")
-    n = len(ids_to_upsert)
-    if n > 0:
-        logger.info(f"Found {n} row(s) to add or upsert.")
-        res = True
-    else:
-        logger.info("No row to add or upsert was found.")
-        res = False
-    return res
-
-
-@task(checkpoint=False)
-def delete_required(ids_to_delete: set) -> bool:
-    logger = prefect.context.get("logger")
-    n = len(ids_to_delete)
-    if n > 0:
-        logger.info(f"Found {n} row(s) to delete.")
-        res = True
-    else:
-        logger.info("No row to delete was found.")
-        res = False
-    return res
 
 @task(checkpoint=False)
 def delete(ids_to_delete: set):
@@ -107,8 +62,75 @@ def extract_new_amp(ids_to_upsert: set) -> pd.DataFrame:
     )
 
 
 @task(checkpoint=False)
-def load_new_amp(new_amp: pd.DataFrame):
+def update_amps(new_amps: pd.DataFrame):
+    """Update the ``amp_cacem`` table with the output of the
+    ``extract_new_amp`` task.
+
+    Args:
+        new_amps (pd.DataFrame): output of ``extract_new_amp`` task.
+    """
+    e = create_engine("monitorenv_remote")
+    logger = prefect.context.get("logger")
+
+    with e.begin() as connection:
+        logger.info("Creating temporary table")
+        connection.execute(
+            text(
+                """CREATE TEMP TABLE tmp_amp_cacem(
+                    id serial,
+                    geom public.geometry(MultiPolygon,4326),
+                    mpa_oriname text,
+                    des_desigfr text,
+                    row_hash text,
+                    mpa_type text,
+                    ref_reg text,
+                    url_legicem text)
+                ON COMMIT DROP;"""
+            )
+        )
+
+        columns_to_load = [
+            "id",
+            "geom",
+            "mpa_oriname",
+            "des_desigfr",
+            "row_hash",
+            "mpa_type",
+            "ref_reg",
+            "url_legicem",
+        ]
+
+        logger.info("Loading to temporary table")
+
+        new_amps[columns_to_load].to_sql(
+            "tmp_amp_cacem",
+            connection,
+            if_exists="append",
+            index=False,
+            method=psql_insert_copy,
+        )
+
+        logger.info(f"Updating {len(new_amps)} row(s) in amp_cacem from temporary table")
+        connection.execute(
+            text(
+                """UPDATE amp_cacem amp
+                SET geom = tmp.geom,
+                    mpa_oriname = tmp.mpa_oriname,
+                    des_desigfr = tmp.des_desigfr,
+                    row_hash = tmp.row_hash,
+                    mpa_type = tmp.mpa_type,
+                    ref_reg = tmp.ref_reg,
+                    url_legicem = tmp.url_legicem
+                FROM tmp_amp_cacem tmp
+                WHERE amp.id = tmp.id;
+                """
+            )
+        )
+
+@task(checkpoint=False)
+def load_new_amps(new_amp: pd.DataFrame):
     """Load the output of ``extract_rows_to_upsert`` task into ``amp``
     table.
@@ -121,27 +143,34 @@ def load_new_amp(new_amp: pd.DataFrame):
         schema="public",
         db_name="monitorenv_remote",
         logger=prefect.context.get("logger"),
-        how="upsert",
+        how="append",
         table_id_column="id",
         df_id_column="id",
     )
 
 
 with Flow("import amp cacem") as flow:
-
     local_hashes = extract_local_hashes()
     remote_hashes = extract_remote_hashes()
 
-    hashes = merge_hashes(local_hashes, remote_hashes)
+    outer_hashes = merge_hashes(local_hashes, remote_hashes)
+    inner_merged = merge_hashes(local_hashes, remote_hashes, "inner")
 
-    ids_to_delete = select_ids_to_delete(hashes)
+    ids_to_delete = select_ids_to_delete(outer_hashes)
     cond_delete = delete_required(ids_to_delete)
     with case(cond_delete, True):
         delete(ids_to_delete)
 
-    ids_to_upsert = select_ids_to_upsert(hashes)
-    cond_upsert = upsert_required(ids_to_upsert)
-    with case(cond_upsert, True):
-        new_amp = extract_new_amp(ids_to_upsert)
-        load_new_amp(new_amp)
+    ids_to_update = select_ids_to_update(inner_merged)
+    cond_update = update_required(ids_to_update)
+    with case(cond_update, True):
+        new_amps = extract_new_amp(ids_to_update)
+        update_amps(new_amps)
+
+    ids_to_insert = select_ids_to_insert(outer_hashes)
+    cond_insert = insert_required(ids_to_insert)
+    with case(cond_insert, True):
+        new_amps = extract_new_amp(ids_to_insert)
+        load_new_amps(new_amps)
+
 
 flow.file_name = Path(__file__).name
diff --git a/datascience/src/pipeline/flows/regulations.py b/datascience/src/pipeline/flows/regulations.py
index a439b9232e..c6a746ffd7 100644
--- a/datascience/src/pipeline/flows/regulations.py
+++ b/datascience/src/pipeline/flows/regulations.py
@@ -4,8 +4,9 @@
 from prefect import Flow, case, task
 from sqlalchemy import text
 
 from src.db_config import create_engine
-from src.pipeline.generic_tasks import delete_rows, extract
+from src.pipeline.generic_tasks import delete_rows, extract, load
 from src.pipeline.processing import prepare_df_for_loading
+from src.pipeline.shared_tasks.update_queries import delete_required, insert_required, merge_hashes, select_ids_to_delete, select_ids_to_insert, select_ids_to_update, update_required
 from src.pipeline.utils import psql_insert_copy
 from src.read_query import read_query
 
@@ -36,58 +37,6 @@ def extract_remote_hashes() -> 
pd.DataFrame: query_filepath="monitorenv/regulations_hashes.sql", ) - -@task(checkpoint=False) -def merge_hashes( - local_hashes: pd.DataFrame, remote_hashes: pd.DataFrame -) -> pd.DataFrame: - return pd.merge(local_hashes, remote_hashes, on="id", how="outer") - - -@task(checkpoint=False) -def select_ids_to_update(hashes: pd.DataFrame) -> set: - ids_to_update = set( - hashes.loc[ - (hashes.cacem_row_hash.notnull()) - & (hashes.cacem_row_hash != hashes.monitorenv_row_hash), - "id", - ] - ) - - return ids_to_update - - -@task(checkpoint=False) -def select_ids_to_delete(hashes: pd.DataFrame) -> set: - return set(hashes.loc[hashes.cacem_row_hash.isna(), "id"]) - - -@task(checkpoint=False) -def update_required(ids_to_update: set) -> bool: - logger = prefect.context.get("logger") - n = len(ids_to_update) - if n > 0: - logger.info(f"Found {n} row(s) to add or update.") - res = True - else: - logger.info("No row to add or update was found.") - res = False - return res - - -@task(checkpoint=False) -def delete_required(ids_to_delete: set) -> bool: - logger = prefect.context.get("logger") - n = len(ids_to_delete) - if n > 0: - logger.info(f"Found {n} row(s) to delete.") - res = True - else: - logger.info("No row to delete was found.") - res = False - return res - - @task(checkpoint=False) def delete(ids_to_delete: set): logger = prefect.context.get("logger") @@ -110,7 +59,7 @@ def extract_new_regulations(ids_to_update: set) -> pd.DataFrame: ) @task(checkpoint=False) -def load_new_regulations(new_regulations: pd.DataFrame): +def update_regulations(new_regulations: pd.DataFrame): """Load the output of ``extract_rows_to_update`` task into ``regulations`` table. @@ -204,21 +153,48 @@ def load_new_regulations(new_regulations: pd.DataFrame): ) +@task(checkpoint=False) +def load_new_regulations(new_amp: pd.DataFrame): + """Load the output of ``extract_new_regulations`` task into ``regulations_cacem`` + table. + + Args: + new_amp (pd.DataFrame): output of ``extract_new_regulations`` task. 
+    """
+    load(
+        new_amp,
+        table_name="regulations_cacem",
+        schema="public",
+        db_name="monitorenv_remote",
+        logger=prefect.context.get("logger"),
+        how="append",
+        table_id_column="id",
+        df_id_column="id",
+    )
+
 
 with Flow("Regulations") as flow:
     local_hashes = extract_local_hashes()
     remote_hashes = extract_remote_hashes()
 
-    hashes = merge_hashes(local_hashes, remote_hashes)
+    outer_hashes = merge_hashes(local_hashes, remote_hashes)
+    inner_merged = merge_hashes(local_hashes, remote_hashes, "inner")
 
-    ids_to_delete = select_ids_to_delete(hashes)
+    ids_to_delete = select_ids_to_delete(outer_hashes)
     cond_delete = delete_required(ids_to_delete)
     with case(cond_delete, True):
         delete(ids_to_delete)
 
-    ids_to_update = select_ids_to_update(hashes)
+    ids_to_update = select_ids_to_update(inner_merged)
     cond_update = update_required(ids_to_update)
     with case(cond_update, True):
         new_regulations = extract_new_regulations(ids_to_update)
+        update_regulations(new_regulations)
+
+    ids_to_insert = select_ids_to_insert(outer_hashes)
+    cond_insert = insert_required(ids_to_insert)
+    with case(cond_insert, True):
+        new_regulations = extract_new_regulations(ids_to_insert)
         load_new_regulations(new_regulations)
+
 
 flow.file_name = Path(__file__).name
diff --git a/datascience/src/pipeline/shared_tasks/update_queries.py b/datascience/src/pipeline/shared_tasks/update_queries.py
new file mode 100644
index 0000000000..087435e7a1
--- /dev/null
+++ b/datascience/src/pipeline/shared_tasks/update_queries.py
@@ -0,0 +1,67 @@
+
+import pandas as pd
+from prefect import task
+import prefect
+
+
+@task(checkpoint=False)
+def update_required(ids_to_update: set) -> bool:
+    logger = prefect.context.get("logger")
+    n = len(ids_to_update)
+    if n > 0:
+        logger.info(f"Found {n} row(s) to update.")
+        res = True
+    else:
+        logger.info("No row to update was found.")
+        res = False
+    return res
+
+@task(checkpoint=False)
+def merge_hashes(
+    local_hashes: pd.DataFrame, remote_hashes: pd.DataFrame, how: str = "outer"
+) -> pd.DataFrame:
+    return pd.merge(local_hashes, remote_hashes, on="id", how=how)
+
+@task(checkpoint=False)
+def select_ids_to_update(hashes: pd.DataFrame) -> set:
+    ids_to_update = set(
+        hashes.loc[
+            (hashes.cacem_row_hash.notnull())
+            & (hashes.cacem_row_hash != hashes.monitorenv_row_hash),
+            "id",
+        ]
+    )
+
+    return ids_to_update
+
+@task(checkpoint=False)
+def select_ids_to_delete(hashes: pd.DataFrame) -> set:
+    return set(hashes.loc[hashes.cacem_row_hash.isna(), "id"])
+
+@task(checkpoint=False)
+def select_ids_to_insert(hashes: pd.DataFrame) -> set:
+    return set(hashes.loc[hashes.monitorenv_row_hash.isna(), "id"])
+
+@task(checkpoint=False)
+def insert_required(ids_to_insert: set) -> bool:
+    logger = prefect.context.get("logger")
+    n = len(ids_to_insert)
+    if n > 0:
+        logger.info(f"Found {n} row(s) to add.")
+        res = True
+    else:
+        logger.info("No row to add was found.")
+        res = False
+    return res
+
+@task(checkpoint=False)
+def delete_required(ids_to_delete: set) -> bool:
+    logger = prefect.context.get("logger")
+    n = len(ids_to_delete)
+    if n > 0:
+        logger.info(f"Found {n} row(s) to delete.")
+        res = True
+    else:
+        logger.info("No row to delete was found.")
+        res = False
+    return res
diff --git a/datascience/tests/test_pipeline/test_flows/test_amp_cacem.py b/datascience/tests/test_pipeline/test_flows/test_amp_cacem.py
index 3dcc780de8..6ab93f433f 100644
--- a/datascience/tests/test_pipeline/test_flows/test_amp_cacem.py
+++ b/datascience/tests/test_pipeline/test_flows/test_amp_cacem.py
@@ -1,59 +1,68 @@
 import 
pandas as pd import pytest -from src.pipeline.flows.amp_cacem import ( - load_new_amp, - merge_hashes, - select_ids_to_delete, - select_ids_to_upsert, -) +from src.pipeline.flows.amp_cacem import load_new_amps, update_amps from src.read_query import read_query @pytest.fixture -def local_hashes() -> pd.DataFrame: +def old_amp() -> pd.DataFrame: return pd.DataFrame( { - "id": [1, 2, 3, 4, 6], - "cacem_row_hash": [ - "cacem_row_hash_1", - "cacem_row_hash_2", - "cacem_row_hash_3", - "cacem_row_hash_4_new", - "cacem_row_hash_6", + "id": [1, 2, 3, 4], + "geom": [ + "0106000020E610000001000000010300000001000000040000001EA36CE84A6F04C028FCC" + "F619D7F47407B5A4C4F4F6904C06878344D997F4740906370C20E6A04C050111641647F47" + "401EA36CE84A6F04C028FCCF619D7F4740", + "0106000020E61000000100000001030000000100000004000000508B8D44B1B304C014238" + "1B3F47F4740A374D56D789004C0C0F2BF049B7F474033F02B2558B104C0CCA0D40BEE7E47" + "40508B8D44B1B304C0142381B3F47F4740", + "0106000020E61000000100000001030000000100000004000000D2204A8709EBE33F541AC" + "4E69B024940B8BC1FBE94F2E33F387D124AAF02494021642107D81FE43F387D124AAF0249" + "40D2204A8709EBE33F541AC4E69B024940", + "0106000020E61000000100000001030000000100000004000000F57994631533F2BFE2B98" + "CD5455446407A715E737969F3BFEAD7CEDEB655464036ED5A29A137F4BF97F69352CC3446" + "40F57994631533F2BFE2B98CD545544640", ], - } - ) - - -@pytest.fixture -def remote_hashes() -> pd.DataFrame: - return pd.DataFrame( - { - "id": [1, 2, 3, 4, 5], - "monitorenv_row_hash": [ + "mpa_oriname": [ + "Calanques - aire d'adhésion", + "dunes, forêt et marais d'Olonne", + "dunes, forêt et marais d'Olonne'", + "estuaire de la Bidassoa et baie de Fontarabie", + ], + "des_desigfr": [ + "Parc national (aire d'adhésion)", + "Zone de protection spéciale (N2000, DO)", + "Zone spéciale de conservation (N2000, DHFF)", + "Zone de protection spéciale (N2000, DO)", + ], + "mpa_type": [ + "Parc national", + "Natura 2000", + "Parc naturel marin", + "Réserve naturelle" + ], + "ref_reg": [ + "arrêté 1", + "arrêté 2", + "arrêté 3", + "arrêté 4", + ], + "url_legicem": [ + "http://dummy_url_1", + "http://dummy_url_2", + "http://dummy_url_3", + "http://dummy_url_4", + ], + "row_hash": [ "cacem_row_hash_1", "cacem_row_hash_2", "cacem_row_hash_3", - "cacem_row_hash_4", - "cacem_row_hash_5", + "cacem_row_hash_4_new", ], } ) - -def test_select_ids_to_delete(remote_hashes, local_hashes): - hashes = merge_hashes.run(local_hashes, remote_hashes) - ids_to_delete = select_ids_to_delete.run(hashes) - assert ids_to_delete == {5} - - -def test_select_ids_to_upsert(remote_hashes, local_hashes): - hashes = merge_hashes.run(local_hashes, remote_hashes) - ids_to_upsert = select_ids_to_upsert.run(hashes) - assert ids_to_upsert == {4, 6} - - @pytest.fixture def new_amp() -> pd.DataFrame: return pd.DataFrame( @@ -93,15 +102,15 @@ def new_amp() -> pd.DataFrame: ], "ref_reg": [ "arrêté 1", - "arrêté 2", + "arrêté 2_updated", "arrêté 3", "arrêté 4", ], "url_legicem": [ - "http://dummy_url_1", - "http://dummy_url_2", - "http://dummy_url_3", - "http://dummy_url_4", + "https://dummy_url_1", + "https://dummy_url_2", + "https://dummy_url_3", + "https://dummy_url_4", ], "row_hash": [ "cacem_row_hash_1", @@ -113,16 +122,29 @@ def new_amp() -> pd.DataFrame: ) -def test_load_new_amp(new_amp): - load_new_amp.run(new_amp) - loaded_amp = read_query( +def test_load_new_amps(old_amp): + load_new_amps.run(old_amp) + loaded_amps = read_query( "monitorenv_remote", - "SELECT id, geom, " - "mpa_oriname, des_desigfr, " - "mpa_type, ref_reg, " - "url_legicem, 
row_hash " - "FROM amp_cacem " - "ORDER BY id" + """SELECT id, geom, + mpa_oriname, des_desigfr, + mpa_type, ref_reg, + url_legicem, row_hash + FROM amp_cacem + ORDER BY id""" ) + pd.testing.assert_frame_equal(loaded_amps, old_amp) + - pd.testing.assert_frame_equal(loaded_amp, new_amp) +def test_update_new_amps(new_amp): + update_amps.run(new_amp) + updated_amps = read_query( + "monitorenv_remote", + """SELECT id, geom, + mpa_oriname, des_desigfr, + mpa_type, ref_reg, + url_legicem, row_hash + FROM amp_cacem + ORDER BY id""" + ) + pd.testing.assert_frame_equal(updated_amps, new_amp) \ No newline at end of file diff --git a/datascience/tests/test_pipeline/test_flows/test_regulations.py b/datascience/tests/test_pipeline/test_flows/test_regulations.py index 7897405282..57595c3601 100644 --- a/datascience/tests/test_pipeline/test_flows/test_regulations.py +++ b/datascience/tests/test_pipeline/test_flows/test_regulations.py @@ -2,58 +2,11 @@ import prefect import pytest -from src.pipeline.flows.regulations import ( - load_new_regulations, - merge_hashes, - select_ids_to_delete, - select_ids_to_update, -) -from src.pipeline.generic_tasks import load -from src.read_query import read_query - - -@pytest.fixture -def local_hashes() -> pd.DataFrame: - return pd.DataFrame( - { - "id": [1, 2, 3, 4, 6], - "cacem_row_hash": [ - "cacem_row_hash_1", - "cacem_row_hash_2", - "cacem_row_hash_3", - "cacem_row_hash_4_new", - "cacem_row_hash_6", - ], - } - ) - - -@pytest.fixture -def remote_hashes() -> pd.DataFrame: - return pd.DataFrame( - { - "id": [1, 2, 3, 4, 5], - "monitorenv_row_hash": [ - "cacem_row_hash_1", - "cacem_row_hash_2", - "cacem_row_hash_3", - "cacem_row_hash_4", - "cacem_row_hash_5", - ], - } - ) - - -def test_select_ids_to_delete(remote_hashes, local_hashes): - hashes = merge_hashes.run(local_hashes, remote_hashes) - ids_to_delete = select_ids_to_delete.run(hashes) - assert ids_to_delete == {5} - -def test_select_ids_to_upsert(remote_hashes, local_hashes): - hashes = merge_hashes.run(local_hashes, remote_hashes) - ids_to_upsert = select_ids_to_update.run(hashes) - assert ids_to_upsert == {4, 6} +from src.pipeline.flows.regulations import load_new_regulations, update_regulations +from src.pipeline.generic_tasks import delete_rows, load +from src.pipeline.shared_tasks.update_queries import merge_hashes, select_ids_to_delete, select_ids_to_insert, select_ids_to_update +from src.read_query import read_query import pandas as pd import pytest @@ -141,19 +94,25 @@ def old_regulations() -> pd.DataFrame: types=["Arrêté préfectoral", "Décret", "Arrêté inter-préfectoral", None] ) -def test_load_new_regulations(new_regulations, old_regulations): - load( - old_regulations, - table_name="regulations_cacem", - schema="public", - db_name="monitorenv_remote", - logger=prefect.context.get("logger"), - how="upsert", - table_id_column="id", - df_id_column="id", + +def test_load_new_regulations(old_regulations): + load_new_regulations.run(old_regulations) + loaded_regulations = read_query( + "monitorenv_remote", + """SELECT + id, geom, entity_name, layer_name, facade, + ref_reg, url, row_hash, edition, editeur, + source, observation, thematique, date, + duree_validite, temporalite, type + FROM public.regulations_cacem + ORDER BY id""" ) - load_new_regulations.run(new_regulations) - loaded_new_regulations = read_query( + pd.testing.assert_frame_equal(loaded_regulations, old_regulations) + + +def test_update_new_regulations(new_regulations): + update_regulations.run(new_regulations) + updated_regulations = 
read_query( "monitorenv_remote", """SELECT id, geom, entity_name, layer_name, facade, @@ -163,4 +122,4 @@ def test_load_new_regulations(new_regulations, old_regulations): FROM public.regulations_cacem ORDER BY id""" ) - pd.testing.assert_frame_equal(loaded_new_regulations, new_regulations) + pd.testing.assert_frame_equal(updated_regulations, new_regulations) diff --git a/datascience/tests/test_pipeline/test_shared_tasks/test_update_queries.py b/datascience/tests/test_pipeline/test_shared_tasks/test_update_queries.py new file mode 100644 index 0000000000..27f0dd1294 --- /dev/null +++ b/datascience/tests/test_pipeline/test_shared_tasks/test_update_queries.py @@ -0,0 +1,50 @@ + +import pandas as pd +import pytest +from src.pipeline.shared_tasks.update_queries import merge_hashes, select_ids_to_delete, select_ids_to_insert, select_ids_to_update + +@pytest.fixture +def local_hashes() -> pd.DataFrame: + return pd.DataFrame( + { + "id": [1, 2, 3, 4, 6], + "cacem_row_hash": [ + "cacem_row_hash_1", + "cacem_row_hash_2", + "cacem_row_hash_3", + "cacem_row_hash_4_new", + "cacem_row_hash_6", + ], + } + ) + +@pytest.fixture +def remote_hashes() -> pd.DataFrame: + return pd.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "monitorenv_row_hash": [ + "cacem_row_hash_1", + "cacem_row_hash_2", + "cacem_row_hash_3", + "cacem_row_hash_4", + "cacem_row_hash_5", + ], + } + ) + +def test_select_ids_to_delete(remote_hashes, local_hashes): + hashes = merge_hashes.run(local_hashes, remote_hashes) + ids_to_delete = select_ids_to_delete.run(hashes) + assert ids_to_delete == {5} + + +def test_select_ids_to_update(remote_hashes, local_hashes): + hashes = merge_hashes.run(local_hashes, remote_hashes, "inner") + ids_to_update = select_ids_to_update.run(hashes) + assert ids_to_update == {4} + +def test_select_ids_to_insert(remote_hashes, local_hashes): + hashes = merge_hashes.run(local_hashes, remote_hashes) + ids_to_insert = select_ids_to_insert.run(hashes) + assert ids_to_insert == {6} \ No newline at end of file