diff --git a/datascience/src/pipeline/flows/regulations.py b/datascience/src/pipeline/flows/regulations.py index f9bf00bc3f..b9bf870cc9 100644 --- a/datascience/src/pipeline/flows/regulations.py +++ b/datascience/src/pipeline/flows/regulations.py @@ -1,10 +1,12 @@ -from pathlib import Path - +from logging import Logger import pandas as pd import prefect +from pathlib import Path from prefect import Flow, case, task - +from sqlalchemy import text +from src.db_config import create_engine from src.pipeline.generic_tasks import delete_rows, extract, load +from src.pipeline.utils import psql_insert_copy @task(checkpoint=False) @@ -107,7 +109,26 @@ def extract_new_regulations(ids_to_update: set) -> pd.DataFrame: ) -@task(checkpoint=False) +# @task(checkpoint=False) +# def load_new_regulations(new_regulations: pd.DataFrame): +# """Load the output of ``extract_rows_to_update`` task into ``regulations`` +# table. + +# Args: +# new_regulations (pd.DataFrame): output of ``extract_rows_to_update`` task. +# """ +# load( +# new_regulations, +# table_name="regulations_cacem", +# schema="public", +# db_name="monitorenv_remote", +# logger=prefect.context.get("logger"), +# how="upsert", +# table_id_column="id", +# df_id_column="id", +# ) + + def load_new_regulations(new_regulations: pd.DataFrame): """Load the output of ``extract_rows_to_update`` task into ``regulations`` table. @@ -115,20 +136,97 @@ def load_new_regulations(new_regulations: pd.DataFrame): Args: new_regulations (pd.DataFrame): output of ``extract_rows_to_update`` task. """ - load( - new_regulations, - table_name="regulations_cacem", - schema="public", - db_name="monitorenv_remote", - logger=prefect.context.get("logger"), - how="upsert", - table_id_column="id", - df_id_column="id", - ) + e = create_engine("monitorenv_remote") + logger = prefect.context.get("logger") + with e.begin() as connection: + logger.info("Creating temporary table") + connection.execute( + text( + """CREATE TEMP TABLE tmp_regulations_cacem( + id serial, + geom public.geometry(MultiPolygon,4326), + entity_name character varying, + url character varying, + layer_name character varying, + facade character varying, + ref_reg character varying, + edition character varying, + editeur character varying, + source character varying, + observation character varying, + thematique character varying, + date character varying, + duree_validite character varying, + date_fin character varying, + temporalite character varying, + type character varying, + ) + ON COMMIT DROP;""" + ) + ) + + # a voir pour les geometries + # new_regulations = prepare_df_for_loading( + # new_regulations, + # logger, + # ) + + columns_to_load = [ + "id", + "geom", + "entity_name", + "url", + "layer_name", + "facade", + "ref_reg", + "edition", + "editeur", + "source", + "observation", + "thematique", + "date", + "duree_validite", + "temporalite", + "type", + ] -with Flow("Regulations") as flow: + logger.info("Loading to temporary table") + + new_regulations[columns_to_load].to_sql( + "tmp_regulations_cacem", + connection, + if_exists="append", + index=False, + method=psql_insert_copy, + ) + + logger.info("Updating regulations_cacem from temporary table") + connection.execute( + text( + """UPDATE public.regulations_cacem reg + SET reg.geom = tmp.geom, + reg.entity_name = tmp.entity_name, + reg.url = tmp.url, + reg.layer_name = tmp.layer_name, + reg.facade = tmp.facade, + reg.ref_reg = tmp.ref_reg, + reg.edition = tmp.edition, + reg.editeur = tmp.editeur, + reg.source = tmp.source, + reg.observation = tmp.observation, + reg.thematique = tmp.thematique, + reg.date = tmp.date, + reg.duree_validite = tmp.duree_validite, + reg.temporalite = tmp.temporalite, + reg.type = tmp.type, + FROM tmp_regulations_cacem tmp + where reg.id = tmp.id;""" + ), + ) + +with Flow("Regulations") as flow: local_hashes = extract_local_hashes() remote_hashes = extract_remote_hashes() hashes = merge_hashes(local_hashes, remote_hashes) diff --git a/datascience/tests/test_pipeline/test_flows/test_regulations.py b/datascience/tests/test_pipeline/test_flows/test_regulations.py new file mode 100644 index 0000000000..17d251d1f8 --- /dev/null +++ b/datascience/tests/test_pipeline/test_flows/test_regulations.py @@ -0,0 +1,128 @@ +import pandas as pd +import pytest + +from src.pipeline.flows.regulations import ( + load_new_regulations, + merge_hashes, + select_ids_to_delete, + select_ids_to_update, +) +from src.read_query import read_query + + +@pytest.fixture +def local_hashes() -> pd.DataFrame: + return pd.DataFrame( + { + "id": [1, 2, 3, 4, 6], + "cacem_row_hash": [ + "cacem_row_hash_1", + "cacem_row_hash_2", + "cacem_row_hash_3", + "cacem_row_hash_4_new", + "cacem_row_hash_6", + ], + } + ) + + +@pytest.fixture +def remote_hashes() -> pd.DataFrame: + return pd.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "monitorenv_row_hash": [ + "cacem_row_hash_1", + "cacem_row_hash_2", + "cacem_row_hash_3", + "cacem_row_hash_4", + "cacem_row_hash_5", + ], + } + ) + + +def test_select_ids_to_delete(remote_hashes, local_hashes): + hashes = merge_hashes.run(local_hashes, remote_hashes) + ids_to_delete = select_ids_to_delete.run(hashes) + assert ids_to_delete == {5} + + +def test_select_ids_to_upsert(remote_hashes, local_hashes): + hashes = merge_hashes.run(local_hashes, remote_hashes) + ids_to_upsert = select_ids_to_update.run(hashes) + assert ids_to_upsert == {4, 6} + + +@pytest.fixture +def new_regulations() -> pd.DataFrame: + return pd.DataFrame( + { + "id": [1, 2, 3, 4], + "geom": [ + "0106000020E610000001000000010300000001000000040000001EA36CE84A6F04C028FCC" + "F619D7F47407B5A4C4F4F6904C06878344D997F4740906370C20E6A04C050111641647F47" + "401EA36CE84A6F04C028FCCF619D7F4740", + "0106000020E61000000100000001030000000100000004000000508B8D44B1B304C014238" + "1B3F47F4740A374D56D789004C0C0F2BF049B7F474033F02B2558B104C0CCA0D40BEE7E47" + "40508B8D44B1B304C0142381B3F47F4740", + "0106000020E61000000100000001030000000100000004000000D2204A8709EBE33F541AC" + "4E69B024940B8BC1FBE94F2E33F387D124AAF02494021642107D81FE43F387D124AAF0249" + "40D2204A8709EBE33F541AC4E69B024940", + "0106000020E61000000100000001030000000100000004000000F57994631533F2BFE2B98" + "CD5455446407A715E737969F3BFEAD7CEDEB655464036ED5A29A137F4BF97F69352CC3446" + "40F57994631533F2BFE2B98CD545544640", + ], + "entity_name": [ + "Zone 1", + "Zone 2", + "Zone 3'", + "Zone 4", + ], + "layer_name": [ + "Layer 1", + "Layer 2", + "Layer 3", + "Layer 4", + ], + "facade": [ + "NAMO", + "NAMO", + "MED", + "MED" + ], + "ref_reg": [ + "arrêté 1", + "arrêté 2", + "arrêté 3", + "arrêté 4", + ], + "url": [ + "http://dummy_url_1", + "http://dummy_url_2", + "http://dummy_url_3", + "http://dummy_url_4", + ], + "row_hash": [ + "cacem_row_hash_1", + "cacem_row_hash_2", + "cacem_row_hash_3", + "cacem_row_hash_4_new", + ], + } + ) + + +def test_load_new_regulations(new_regulations): + load_new_regulations.run(new_regulations) + load_new_regulations = read_query( + "monitorenv_remote", + "SELECT id, geom, " + "mpa_oriname, des_desigfr, " + "mpa_type, ref_reg, " + "url_legicem, row_hash " + "FROM amp_cacem " + "ORDER BY id" + ) + + pd.testing.assert_frame_equal(load_new_regulations, new_regulations)