diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..39247d4f 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + changed: + - Refactored household endpoints to match new API structure \ No newline at end of file diff --git a/policyengine_api/api.py b/policyengine_api/api.py index 0f32ae56..684bb3f9 100644 --- a/policyengine_api/api.py +++ b/policyengine_api/api.py @@ -13,7 +13,9 @@ # from werkzeug.middleware.profiler import ProfilerMiddleware # Endpoints +from policyengine_api.routes.error_routes import error_bp from policyengine_api.routes.economy_routes import economy_bp +from policyengine_api.routes.household_routes import household_bp from policyengine_api.routes.simulation_analysis_routes import ( simulation_analysis_bp, ) @@ -22,9 +24,6 @@ from .endpoints import ( get_home, - get_household, - post_household, - update_household, get_policy, set_policy, get_policy_search, @@ -56,19 +55,13 @@ CORS(app) +app.register_blueprint(error_bp) + app.route("/", methods=["GET"])(get_home) app.register_blueprint(metadata_bp) -app.route("//household/", methods=["GET"])( - get_household -) - -app.route("//household", methods=["POST"])(post_household) - -app.route("//household/", methods=["PUT"])( - update_household -) +app.register_blueprint(household_bp) app.route("//policy/", methods=["GET"])(get_policy) @@ -94,12 +87,10 @@ ) # Routes for economy microsimulation -app.register_blueprint(economy_bp, url_prefix="//economy") +app.register_blueprint(economy_bp) # Routes for AI analysis of economy microsim runs -app.register_blueprint( - simulation_analysis_bp, url_prefix="//simulation-analysis" -) +app.register_blueprint(simulation_analysis_bp) app.route("//user-policy", methods=["POST"])(set_user_policy) @@ -117,9 +108,7 @@ app.route("/simulations", methods=["GET"])(get_simulations) -app.register_blueprint( - tracer_analysis_bp, url_prefix="//tracer-analysis" -) +app.register_blueprint(tracer_analysis_bp) @app.route("/liveness-check", methods=["GET"]) diff --git a/policyengine_api/endpoints/__init__.py b/policyengine_api/endpoints/__init__.py index 8dc30d42..0b7bd51e 100644 --- a/policyengine_api/endpoints/__init__.py +++ b/policyengine_api/endpoints/__init__.py @@ -1,10 +1,7 @@ from .home import get_home from .household import ( - get_household, - post_household, get_household_under_policy, get_calculate, - update_household, ) from .policy import ( get_policy, diff --git a/policyengine_api/endpoints/household.py b/policyengine_api/endpoints/household.py index d5d551cf..b1331856 100644 --- a/policyengine_api/endpoints/household.py +++ b/policyengine_api/endpoints/household.py @@ -74,174 +74,6 @@ def get_household_year(household): return household_year -@validate_country -def get_household(country_id: str, household_id: str) -> dict: - """Get a household's input data with a given ID. - - Args: - country_id (str): The country ID. - household_id (str): The household ID. - """ - - # Retrieve from the household table - row = database.query( - f"SELECT * FROM household WHERE id = ? AND country_id = ?", - (household_id, country_id), - ).fetchone() - - if row is not None: - household = dict(row) - household["household_json"] = json.loads(household["household_json"]) - return dict( - status="ok", - message=None, - result=household, - ) - else: - response_body = dict( - status="error", - message=f"Household #{household_id} not found.", - ) - return Response( - json.dumps(response_body), - status=404, - mimetype="application/json", - ) - - -@validate_country -def post_household(country_id: str) -> dict: - """Set a household's input data. - - Args: - country_id (str): The country ID. - """ - - payload = request.json - label = payload.get("label") - household_json = payload.get("data") - household_hash = hash_object(household_json) - api_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) - - try: - database.query( - f"INSERT INTO household (country_id, household_json, household_hash, label, api_version) VALUES (?, ?, ?, ?, ?)", - ( - country_id, - json.dumps(household_json), - household_hash, - label, - api_version, - ), - ) - except sqlalchemy.exc.IntegrityError: - pass - - household_id = database.query( - f"SELECT id FROM household WHERE country_id = ? AND household_hash = ?", - (country_id, household_hash), - ).fetchone()["id"] - - response_body = dict( - status="ok", - message=None, - result=dict( - household_id=household_id, - ), - ) - return Response( - json.dumps(response_body), - status=201, - mimetype="application/json", - ) - - -@validate_country -def update_household(country_id: str, household_id: str) -> Response: - """ - Update a household via UPDATE request - - Args: country_id (str): The country ID - """ - - # Fetch existing household first - try: - row = database.query( - f"SELECT * FROM household WHERE id = ? AND country_id = ?", - (household_id, country_id), - ).fetchone() - - if row is not None: - household = dict(row) - household["household_json"] = json.loads( - household["household_json"] - ) - household["label"] - else: - response_body = dict( - status="error", - message=f"Household #{household_id} not found.", - ) - return Response( - json.dumps(response_body), - status=404, - mimetype="application/json", - ) - except Exception as e: - logging.exception(e) - response_body = dict( - status="error", - message=f"Error fetching household #{household_id} while updating: {e}", - ) - return Response( - json.dumps(response_body), - status=500, - mimetype="application/json", - ) - - payload = request.json - label = payload.get("label") or household["label"] - household_json = payload.get("data") or household["household_json"] - household_hash = hash_object(household_json) - api_version = COUNTRY_PACKAGE_VERSIONS.get(country_id) - - try: - database.query( - f"UPDATE household SET household_json = ?, household_hash = ?, label = ?, api_version = ? WHERE id = ?", - ( - json.dumps(household_json), - household_hash, - label, - api_version, - household_id, - ), - ) - except Exception as e: - logging.exception(e) - response_body = dict( - status="error", - message=f"Error fetching household #{household_id} while updating: {e}", - ) - return Response( - json.dumps(response_body), - status=500, - mimetype="application/json", - ) - - response_body = dict( - status="ok", - message=None, - result=dict( - household_id=household_id, - ), - ) - return Response( - json.dumps(response_body), - status=200, - mimetype="application/json", - ) - - @validate_country def get_household_under_policy( country_id: str, household_id: str, policy_id: str diff --git a/policyengine_api/routes/economy_routes.py b/policyengine_api/routes/economy_routes.py index a1e922b4..cb731dde 100644 --- a/policyengine_api/routes/economy_routes.py +++ b/policyengine_api/routes/economy_routes.py @@ -3,7 +3,7 @@ from policyengine_api.utils import get_current_law_policy_id from policyengine_api.utils.payload_validators import validate_country from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS -from flask import request, Response +from flask import request import json economy_bp = Blueprint("economy", __name__) @@ -11,7 +11,10 @@ @validate_country -@economy_bp.route("//over/", methods=["GET"]) +@economy_bp.route( + "//economy//over/", + methods=["GET"], +) def get_economic_impact(country_id, policy_id, baseline_policy_id): policy_id = int(policy_id or get_current_law_policy_id(country_id)) @@ -30,25 +33,14 @@ def get_economic_impact(country_id, policy_id, baseline_policy_id): "version", COUNTRY_PACKAGE_VERSIONS.get(country_id) ) - try: - result = economy_service.get_economic_impact( - country_id, - policy_id, - baseline_policy_id, - region, - dataset, - time_period, - options, - api_version, - ) - return result - except Exception as e: - return Response( - { - "status": "error", - "message": "An error occurred while calculating the economic impact. Details: " - + str(e), - "result": None, - }, - 500, - ) + result = economy_service.get_economic_impact( + country_id, + policy_id, + baseline_policy_id, + region, + dataset, + time_period, + options, + api_version, + ) + return result diff --git a/policyengine_api/routes/error_routes.py b/policyengine_api/routes/error_routes.py new file mode 100644 index 00000000..e9fced1c --- /dev/null +++ b/policyengine_api/routes/error_routes.py @@ -0,0 +1,67 @@ +import json +from flask import Response, Blueprint +from werkzeug.exceptions import ( + HTTPException, +) + +error_bp = Blueprint("error", __name__) + + +@error_bp.app_errorhandler(404) +def response_404(error) -> Response: + """Specific handler for 404 Not Found errors""" + return make_error_response(error, 404) + + +@error_bp.app_errorhandler(400) +def response_400(error) -> Response: + """Specific handler for 400 Bad Request errors""" + return make_error_response(error, 400) + + +@error_bp.app_errorhandler(401) +def response_401(error) -> Response: + """Specific handler for 401 Unauthorized errors""" + return make_error_response(error, 401) + + +@error_bp.app_errorhandler(403) +def response_403(error) -> Response: + """Specific handler for 403 Forbidden errors""" + return make_error_response(error, 403) + + +@error_bp.app_errorhandler(500) +def response_500(error) -> Response: + """Specific handler for 500 Internal Server errors""" + return make_error_response(error, 500) + + +@error_bp.app_errorhandler(HTTPException) +def response_http_exception(error: HTTPException) -> Response: + """Generic handler for HTTPException; should be raised if no specific handler is found""" + return make_error_response(str(error), error.code) + + +@error_bp.app_errorhandler(Exception) +def response_generic_error(error: Exception) -> Response: + """Handler for any unhandled exceptions""" + return make_error_response(str(error), 500) + + +def make_error_response( + error, + status_code: int, +) -> Response: + """Create a generic error response""" + return Response( + json.dumps( + { + "status": "error", + "message": str(error), + "result": None, + } + ), + status_code, + mimetype="application/json", + ) diff --git a/policyengine_api/routes/household_routes.py b/policyengine_api/routes/household_routes.py new file mode 100644 index 00000000..893d6def --- /dev/null +++ b/policyengine_api/routes/household_routes.py @@ -0,0 +1,137 @@ +from flask import Blueprint, Response, request +from werkzeug.exceptions import NotFound, BadRequest +import json + +from policyengine_api.services.household_service import HouseholdService +from policyengine_api.utils.payload_validators import ( + validate_household_payload, + validate_country, +) + + +household_bp = Blueprint("household", __name__) +household_service = HouseholdService() + + +@household_bp.route( + "//household/", methods=["GET"] +) +@validate_country +def get_household(country_id: str, household_id: int) -> Response: + """ + Get a household's input data with a given ID. + + Args: + country_id (str): The country ID. + household_id (int): The household ID. + """ + print(f"Got request for household {household_id} in country {country_id}") + + household: dict | None = household_service.get_household( + country_id, household_id + ) + if household is None: + raise NotFound(f"Household #{household_id} not found.") + else: + return Response( + json.dumps( + { + "status": "ok", + "message": None, + "result": household, + } + ), + status=200, + mimetype="application/json", + ) + + +@household_bp.route("//household", methods=["POST"]) +@validate_country +def post_household(country_id: str) -> Response: + """ + Set a household's input data. + + Args: + country_id (str): The country ID. + """ + + # Validate payload + payload = request.json + is_payload_valid, message = validate_household_payload(payload) + if not is_payload_valid: + raise BadRequest(f"Unable to create new household; details: {message}") + + # The household label appears to be unimplemented at this time, + # thus it should always be 'None' + label: str | None = payload.get("label") + household_json: dict = payload.get("data") + + household_id = household_service.create_household( + country_id, household_json, label + ) + + return Response( + json.dumps( + { + "status": "ok", + "message": None, + "result": { + "household_id": household_id, + }, + } + ), + status=201, + mimetype="application/json", + ) + + +@household_bp.route( + "//household/", methods=["PUT"] +) +@validate_country +def update_household(country_id: str, household_id: int) -> Response: + """ + Update a household's input data. + + Args: + country_id (str): The country ID. + household_id (int): The household ID. + """ + + # Validate payload + payload = request.json + is_payload_valid, message = validate_household_payload(payload) + if not is_payload_valid: + raise BadRequest( + f"Unable to update household #{household_id}; details: {message}" + ) + + # First, attempt to fetch the existing household + label: str | None = payload.get("label") + household_json: dict = payload.get("data") + + household: dict | None = household_service.get_household( + country_id, household_id + ) + if household is None: + raise NotFound(f"Household #{household_id} not found.") + + # Next, update the household + updated_household: dict = household_service.update_household( + country_id, household_id, household_json, label + ) + return Response( + json.dumps( + { + "status": "ok", + "message": None, + "result": { + "household_id": household_id, + "household_json": updated_household["household_json"], + }, + } + ), + status=200, + mimetype="application/json", + ) diff --git a/policyengine_api/routes/metadata_routes.py b/policyengine_api/routes/metadata_routes.py index 583f8e5b..89bdfbea 100644 --- a/policyengine_api/routes/metadata_routes.py +++ b/policyengine_api/routes/metadata_routes.py @@ -1,4 +1,5 @@ -from flask import Blueprint +import json +from flask import Blueprint, Response from policyengine_api.utils.payload_validators import validate_country from policyengine_api.services.metadata_service import MetadataService @@ -9,10 +10,22 @@ @metadata_bp.route("//metadata", methods=["GET"]) @validate_country -def get_metadata(country_id: str) -> dict: +def get_metadata(country_id: str) -> Response: """Get metadata for a country. Args: country_id (str): The country ID. """ - return metadata_service.get_metadata(country_id) + metadata = metadata_service.get_metadata(country_id) + + return Response( + json.dumps( + { + "status": "ok", + "message": None, + "result": metadata, + } + ), + status=200, + mimetype="application/json", + ) diff --git a/policyengine_api/routes/simulation_analysis_routes.py b/policyengine_api/routes/simulation_analysis_routes.py index 5326989a..d86e4b79 100644 --- a/policyengine_api/routes/simulation_analysis_routes.py +++ b/policyengine_api/routes/simulation_analysis_routes.py @@ -1,5 +1,5 @@ from flask import Blueprint, request, Response, stream_with_context -import json +from werkzeug.exceptions import BadRequest from policyengine_api.utils.payload_validators import validate_country from policyengine_api.services.simulation_analysis_service import ( SimulationAnalysisService, @@ -13,7 +13,9 @@ simulation_analysis_service = SimulationAnalysisService() -@simulation_analysis_bp.route("", methods=["POST"]) +@simulation_analysis_bp.route( + "//simulation-analysis", methods=["POST"] +) @validate_country def execute_simulation_analysis(country_id): print("Got POST request for simulation analysis") @@ -24,9 +26,7 @@ def execute_simulation_analysis(country_id): is_payload_valid, message = validate_sim_analysis_payload(payload) if not is_payload_valid: - return Response( - status=400, response=f"Invalid JSON data; details: {message}" - ) + raise BadRequest(f"Invalid JSON data; details: {message}") currency: str = payload.get("currency") selected_version: str = payload.get("selected_version") @@ -41,41 +41,28 @@ def execute_simulation_analysis(country_id): ) audience = payload.get("audience", "") - try: - analysis = simulation_analysis_service.execute_analysis( - country_id, - currency, - selected_version, - time_period, - impact, - policy_label, - policy, - region, - relevant_parameters, - relevant_parameter_baseline_values, - audience, - ) + analysis = simulation_analysis_service.execute_analysis( + country_id, + currency, + selected_version, + time_period, + impact, + policy_label, + policy, + region, + relevant_parameters, + relevant_parameter_baseline_values, + audience, + ) - # Create streaming response - response = Response( - stream_with_context(analysis), - status=200, - ) + # Create streaming response + response = Response( + stream_with_context(analysis), + status=200, + ) - # Set header to prevent buffering on Google App Engine deployment - # (see https://cloud.google.com/appengine/docs/flexible/how-requests-are-handled?tab=python#x-accel-buffering) - response.headers["X-Accel-Buffering"] = "no" + # Set header to prevent buffering on Google App Engine deployment + # (see https://cloud.google.com/appengine/docs/flexible/how-requests-are-handled?tab=python#x-accel-buffering) + response.headers["X-Accel-Buffering"] = "no" - return response - except Exception as e: - return Response( - json.dumps( - { - "status": "error", - "message": "An error occurred while executing the simulation analysis. Details: " - + str(e), - "result": None, - } - ), - status=500, - ) + return response diff --git a/policyengine_api/routes/tracer_analysis_routes.py b/policyengine_api/routes/tracer_analysis_routes.py index d15b0dd4..d4c48aa4 100644 --- a/policyengine_api/routes/tracer_analysis_routes.py +++ b/policyengine_api/routes/tracer_analysis_routes.py @@ -1,4 +1,5 @@ from flask import Blueprint, request, Response, stream_with_context +from werkzeug.exceptions import BadRequest from policyengine_api.utils.payload_validators import ( validate_country, validate_tracer_analysis_payload, @@ -6,13 +7,12 @@ from policyengine_api.services.tracer_analysis_service import ( TracerAnalysisService, ) -import json tracer_analysis_bp = Blueprint("tracer_analysis", __name__) tracer_analysis_service = TracerAnalysisService() -@tracer_analysis_bp.route("", methods=["POST"]) +@tracer_analysis_bp.route("//tracer-analysis", methods=["POST"]) @validate_country def execute_tracer_analysis(country_id): @@ -20,56 +20,27 @@ def execute_tracer_analysis(country_id): is_payload_valid, message = validate_tracer_analysis_payload(payload) if not is_payload_valid: - return Response( - status=400, response=f"Invalid JSON data; details: {message}" - ) + raise BadRequest(f"Invalid JSON data; details: {message}") household_id = payload.get("household_id") policy_id = payload.get("policy_id") variable = payload.get("variable") - try: - # Create streaming response - response = Response( - stream_with_context( - tracer_analysis_service.execute_analysis( - country_id, - household_id, - policy_id, - variable, - ) - ), - status=200, - ) - - # Set header to prevent buffering on Google App Engine deployment - # (see https://cloud.google.com/appengine/docs/flexible/how-requests-are-handled?tab=python#x-accel-buffering) - response.headers["X-Accel-Buffering"] = "no" - - return response - except KeyError as e: - """ - This exception is raised when the tracer can't find a household tracer record - """ - return Response( - json.dumps( - { - "status": "not found", - "message": "No household simulation tracer found", - "result": None, - } - ), - 404, - ) - except Exception as e: - return Response( - json.dumps( - { - "status": "error", - "message": "An error occurred while executing the tracer analysis. Details: " - + str(e), - "result": None, - } - ), - 500, - ) + # Create streaming response + response = Response( + stream_with_context( + tracer_analysis_service.execute_analysis( + country_id, + household_id, + policy_id, + variable, + ) + ), + status=200, + ) + + # Set header to prevent buffering on Google App Engine deployment + # (see https://cloud.google.com/appengine/docs/flexible/how-requests-are-handled?tab=python#x-accel-buffering) + response.headers["X-Accel-Buffering"] = "no" + + return response diff --git a/policyengine_api/services/household_service.py b/policyengine_api/services/household_service.py new file mode 100644 index 00000000..bac1125c --- /dev/null +++ b/policyengine_api/services/household_service.py @@ -0,0 +1,129 @@ +import json +from sqlalchemy.engine.row import LegacyRow + +from policyengine_api.data import database +from policyengine_api.utils import hash_object +from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS + + +class HouseholdService: + + def get_household(self, country_id: str, household_id: int) -> dict | None: + """ + Get a household's input data with a given ID. + + Args: + country_id (str): The country ID. + household_id (int): The household ID. + """ + print("Getting household data") + + try: + row: LegacyRow | None = database.query( + f"SELECT * FROM household WHERE id = ? AND country_id = ?", + (household_id, country_id), + ).fetchone() + + # If row is present, we must JSON.loads the household_json + household = None + if row is not None: + household = dict(row) + if household["household_json"]: + household["household_json"] = json.loads( + household["household_json"] + ) + return household + + except Exception as e: + print( + f"Error fetching household #{household_id}. Details: {str(e)}" + ) + raise e + + def create_household( + self, + country_id: str, + household_json: dict, + label: str | None, + ) -> int: + """ + Create a new household with the given data. + + Args: + country_id (str): The country ID. + household_json (dict): The household data. + household_hash (int): The hash of the household data. + label (str): The label for the household. + api_version (str): The API version. + """ + + print("Creating new household") + + try: + household_hash: str = hash_object(household_json) + api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id) + + database.query( + f"INSERT INTO household (country_id, household_json, household_hash, label, api_version) VALUES (?, ?, ?, ?, ?)", + ( + country_id, + json.dumps(household_json), + household_hash, + label, + api_version, + ), + ) + + household_id = database.query( + f"SELECT id FROM household WHERE country_id = ? AND household_hash = ?", + (country_id, household_hash), + ).fetchone()["id"] + + return household_id + except Exception as e: + print(f"Error creating household. Details: {str(e)}") + raise e + + def update_household( + self, + country_id: str, + household_id: int, + household_json: dict, + label: str, + ) -> dict: + """ + Update a household with the given data. + + Args: + country_id (str): The country ID. + household_id (int): The household ID. + payload (dict): The data to update the household with. + """ + print("Updating household") + + try: + + household_hash: str = hash_object(household_json) + api_version: str = COUNTRY_PACKAGE_VERSIONS.get(country_id) + + database.query( + f"UPDATE household SET household_json = ?, household_hash = ?, label = ?, api_version = ? WHERE id = ?", + ( + json.dumps(household_json), + household_hash, + label, + api_version, + household_id, + ), + ) + + # Fetch the updated JSON back from the table + updated_household: dict = self.get_household( + country_id, household_id + ) + return updated_household + except Exception as e: + print( + f"Error updating household #{household_id}. Details: {str(e)}" + ) + raise e diff --git a/policyengine_api/services/tracer_analysis_service.py b/policyengine_api/services/tracer_analysis_service.py index 162245d2..7cc76d4b 100644 --- a/policyengine_api/services/tracer_analysis_service.py +++ b/policyengine_api/services/tracer_analysis_service.py @@ -5,6 +5,7 @@ import re import anthropic from policyengine_api.services.ai_analysis_service import AIAnalysisService +from werkzeug.exceptions import NotFound class TracerAnalysisService(AIAnalysisService): @@ -80,7 +81,7 @@ def get_tracer( ).fetchone() if row is None: - raise KeyError("No tracer found for this household") + raise NotFound("No household simulation tracer found") tracer_output_list = json.loads(row["tracer_output"]) return tracer_output_list diff --git a/policyengine_api/utils/payload_validators/__init__.py b/policyengine_api/utils/payload_validators/__init__.py index 4bb15c1f..520707a8 100644 --- a/policyengine_api/utils/payload_validators/__init__.py +++ b/policyengine_api/utils/payload_validators/__init__.py @@ -1,3 +1,4 @@ from .validate_sim_analysis_payload import validate_sim_analysis_payload from .validate_tracer_analysis_payload import validate_tracer_analysis_payload from .validate_country import validate_country +from .validate_household_payload import validate_household_payload diff --git a/policyengine_api/utils/payload_validators/validate_country.py b/policyengine_api/utils/payload_validators/validate_country.py index d9c5ee1b..c891f98a 100644 --- a/policyengine_api/utils/payload_validators/validate_country.py +++ b/policyengine_api/utils/payload_validators/validate_country.py @@ -6,25 +6,26 @@ def validate_country(func): - """Validate that a country ID is valid. If not, return a 404 response. + """Validate that a country ID is valid. If not, return a 400 response. Args: country_id (str): The country ID to validate. Returns: - Response(404) if country is not valid, else continues + Response(400) if country is not valid, else continues """ @wraps(func) def validate_country_wrapper( country_id: str, *args, **kwargs ) -> Union[None, Response]: + print("Validating country") if country_id not in COUNTRIES: body = dict( status="error", message=f"Country {country_id} not found. Available countries are: {', '.join(COUNTRIES)}", ) - return Response(json.dumps(body), status=404) + return Response(json.dumps(body), status=400) return func(country_id, *args, **kwargs) return validate_country_wrapper diff --git a/policyengine_api/utils/payload_validators/validate_household_payload.py b/policyengine_api/utils/payload_validators/validate_household_payload.py new file mode 100644 index 00000000..7b4f7d95 --- /dev/null +++ b/policyengine_api/utils/payload_validators/validate_household_payload.py @@ -0,0 +1,31 @@ +import json + + +def validate_household_payload(payload): + """ + Validate the payload for a POST request to set a household's input data. + + Args: + payload (dict): The payload to validate. + + Returns: + tuple[bool, str]: A tuple containing a boolean indicating whether the payload is valid and a message. + """ + # Check that all required keys are present + required_keys = ["data"] + missing_keys = [key for key in required_keys if key not in payload] + if missing_keys: + return False, f"Missing required keys: {missing_keys}" + + # Check that label is either string or None, if present + if "label" in payload: + if payload["label"] is not None and not isinstance( + payload["label"], str + ): + return False, "Label must be a string or None" + + # Check that data is a dictionary + if not isinstance(payload["data"], dict): + return False, "Unable to parse household JSON data" + + return True, None diff --git a/setup.py b/setup.py index 4a8af5f3..d78f2c82 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ "rq", "sqlalchemy>=1.4,<2", "streamlit", + "werkzeug", "Flask-Caching>=2,<3", ], extras_require={ diff --git a/tests/api/test_us_create_empty_household.yaml b/tests/api/test_us_create_empty_household.yaml index 04b70090..9c36731a 100644 --- a/tests/api/test_us_create_empty_household.yaml +++ b/tests/api/test_us_create_empty_household.yaml @@ -3,6 +3,7 @@ endpoint: /us/household method: POST data: label: Empty Household + data: {} response: data: status: ok diff --git a/tests/fixtures/household_fixtures.py b/tests/fixtures/household_fixtures.py new file mode 100644 index 00000000..733b6857 --- /dev/null +++ b/tests/fixtures/household_fixtures.py @@ -0,0 +1,39 @@ +import pytest +import json +from unittest.mock import patch + + +SAMPLE_HOUSEHOLD_DATA = { + "data": {"people": {"person1": {"age": 30, "income": 50000}}}, + "label": "Test Household", +} + +SAMPLE_DB_ROW = { + "id": 1, + "country_id": "us", + "household_json": json.dumps(SAMPLE_HOUSEHOLD_DATA["data"]), + "household_hash": "some-hash", + "label": "Test Household", + "api_version": "3.0.0", +} + + +# These will be moved to the correct location once +# testing PR that creates folder structure is merged +@pytest.fixture +def mock_database(): + """Mock the database module.""" + with patch( + "policyengine_api.services.household_service.database" + ) as mock_db: + yield mock_db + + +@pytest.fixture +def mock_hash_object(): + """Mock the hash_object function.""" + with patch( + "policyengine_api.services.household_service.hash_object" + ) as mock: + mock.return_value = "some-hash" + yield mock diff --git a/tests/python/test_errors.py b/tests/python/test_errors.py new file mode 100644 index 00000000..c36d75b2 --- /dev/null +++ b/tests/python/test_errors.py @@ -0,0 +1,120 @@ +import pytest +from flask import Flask +from policyengine_api.routes.error_routes import error_bp +from werkzeug.exceptions import ( + NotFound, + BadRequest, + Unauthorized, + Forbidden, + InternalServerError, +) + + +@pytest.fixture +def app(): + """Create and configure a new app instance for each test.""" + app = Flask(__name__) + app.register_blueprint(error_bp) + return app + + +@pytest.fixture +def client(app): + """Create a test client for the app.""" + return app.test_client() + + +def test_404_handler(app, client): + """Test 404 Not Found error handling""" + + @app.route("/nonexistent") + def nonexistent(): + raise NotFound("Custom not found message") + + response = client.get("/nonexistent") + data = response.get_json() + + assert response.status_code == 404 + assert data["status"] == "error" + assert "Custom not found message" in data["message"] + assert data["result"] is None + + +def test_400_handler(app, client): + """Test 400 Bad Request error handling""" + + @app.route("/bad-request") + def bad_request(): + raise BadRequest("Invalid parameters") + + response = client.get("/bad-request") + data = response.get_json() + + assert response.status_code == 400 + assert data["status"] == "error" + assert "Invalid parameters" in data["message"] + assert data["result"] is None + + +def test_401_handler(app, client): + """Test 401 Unauthorized error handling""" + + @app.route("/unauthorized") + def unauthorized(): + raise Unauthorized("Invalid credentials") + + response = client.get("/unauthorized") + data = response.get_json() + + assert response.status_code == 401 + assert data["status"] == "error" + assert "Invalid credentials" in data["message"] + assert data["result"] is None + + +def test_403_handler(app, client): + """Test 403 Forbidden error handling""" + + @app.route("/forbidden") + def forbidden(): + raise Forbidden("Access denied") + + response = client.get("/forbidden") + data = response.get_json() + + assert response.status_code == 403 + assert data["status"] == "error" + assert "Access denied" in data["message"] + assert data["result"] is None + + +def test_500_handler(app, client): + """Test 500 Internal Server Error handling""" + + @app.route("/server-error") + def server_error(): + raise InternalServerError("Database connection failed") + + response = client.get("/server-error") + data = response.get_json() + + assert response.status_code == 500 + assert data["status"] == "error" + assert "Database connection failed" in data["message"] + assert data["result"] is None + + +def test_generic_exception_handler(app, client): + """Test handling of generic exceptions""" + + @app.route("/generic-error") + def generic_error(): + raise ValueError("Something went wrong") + + response = client.get("/generic-error") + data = response.get_json() + + assert response.status_code == 500 + assert data["status"] == "error" + assert "Something went wrong" in data["message"] + assert data["result"] is None diff --git a/tests/python/test_household.py b/tests/python/test_household.py new file mode 100644 index 00000000..85c9d89d --- /dev/null +++ b/tests/python/test_household.py @@ -0,0 +1,406 @@ +import pytest +import json +from unittest.mock import MagicMock, patch +from sqlalchemy.engine.row import LegacyRow + +from policyengine_api.routes.household_routes import household_bp +from policyengine_api.services.household_service import HouseholdService +from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS + +from tests.fixtures.household_fixtures import ( + SAMPLE_HOUSEHOLD_DATA, + SAMPLE_DB_ROW, + mock_database, + mock_hash_object, +) + + +class TestGetHousehold: + def test_get_existing_household(self, rest_client, mock_database): + """Test getting an existing household.""" + # Mock database response + mock_row = MagicMock(spec=LegacyRow) + mock_row.__getitem__.side_effect = lambda x: SAMPLE_DB_ROW[x] + mock_row.keys.return_value = SAMPLE_DB_ROW.keys() + mock_database.query().fetchone.return_value = mock_row + + # Make request + response = rest_client.get("/us/household/1") + data = json.loads(response.data) + + assert response.status_code == 200 + assert data["status"] == "ok" + assert ( + data["result"]["household_json"] == SAMPLE_HOUSEHOLD_DATA["data"] + ) + + def test_get_nonexistent_household(self, rest_client, mock_database): + """Test getting a non-existent household.""" + mock_database.query().fetchone.return_value = None + + response = rest_client.get("/us/household/999") + data = json.loads(response.data) + + assert response.status_code == 404 + assert data["status"] == "error" + assert "not found" in data["message"] + + def test_get_household_invalid_id(self, rest_client): + """Test getting a household with invalid ID.""" + response = rest_client.get("/us/household/invalid") + + assert response.status_code == 404 + assert ( + b"The requested URL was not found on the server" in response.data + ) + + +class TestCreateHousehold: + def test_create_household_success( + self, rest_client, mock_database, mock_hash_object + ): + """Test successfully creating a new household.""" + # Mock database responses + mock_row = MagicMock(spec=LegacyRow) + mock_row.__getitem__.side_effect = lambda x: {"id": 1}[x] + mock_database.query().fetchone.return_value = mock_row + + response = rest_client.post( + "/us/household", + json=SAMPLE_HOUSEHOLD_DATA, + content_type="application/json", + ) + data = json.loads(response.data) + + assert response.status_code == 201 + assert data["status"] == "ok" + assert data["result"]["household_id"] == 1 + + def test_create_household_invalid_payload(self, rest_client): + """Test creating a household with invalid payload.""" + invalid_payload = { + "label": "Test", + # Missing required 'data' field + } + + response = rest_client.post( + "/us/household", + json=invalid_payload, + content_type="application/json", + ) + + assert response.status_code == 400 + assert b"Missing required keys" in response.data + + def test_create_household_invalid_label(self, rest_client): + """Test creating a household with invalid label type.""" + invalid_payload = { + "data": {}, + "label": 123, # Should be string or None + } + + response = rest_client.post( + "/us/household", + json=invalid_payload, + content_type="application/json", + ) + + assert response.status_code == 400 + assert b"Label must be a string or None" in response.data + + +class TestUpdateHousehold: + def test_update_household_success( + self, rest_client, mock_database, mock_hash_object + ): + """Test successfully updating an existing household.""" + # Mock getting existing household + mock_row = MagicMock(spec=LegacyRow) + mock_row.__getitem__.side_effect = lambda x: SAMPLE_DB_ROW[x] + mock_row.keys.return_value = SAMPLE_DB_ROW.keys() + mock_database.query().fetchone.return_value = mock_row + + updated_household = { + "people": {"person1": {"age": 31, "income": 55000}} + } + + updated_data = { + "data": updated_household, + "label": SAMPLE_HOUSEHOLD_DATA["label"], + } + + response = rest_client.put( + "/us/household/1", + json=updated_data, + content_type="application/json", + ) + data = json.loads(response.data) + + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["result"]["household_id"] == 1 + # assert data["result"]["household_json"] == updated_data["data"] + mock_database.query.assert_any_call( + "UPDATE household SET household_json = ?, household_hash = ?, label = ?, api_version = ? WHERE id = ?", + ( + json.dumps(updated_household), + "some-hash", + SAMPLE_HOUSEHOLD_DATA["label"], + COUNTRY_PACKAGE_VERSIONS.get("us"), + 1, + ), + ) + + def test_update_nonexistent_household(self, rest_client, mock_database): + """Test updating a non-existent household.""" + mock_database.query().fetchone.return_value = None + + response = rest_client.put( + "/us/household/999", + json=SAMPLE_HOUSEHOLD_DATA, + content_type="application/json", + ) + data = json.loads(response.data) + + assert response.status_code == 404 + assert data["status"] == "error" + assert "not found" in data["message"] + + def test_update_household_invalid_payload(self, rest_client): + """Test updating a household with invalid payload.""" + invalid_payload = { + "label": "Test", + # Missing required 'data' field + } + + response = rest_client.put( + "/us/household/1", + json=invalid_payload, + content_type="application/json", + ) + + assert response.status_code == 400 + assert b"Missing required keys" in response.data + + +# Service level tests +class TestHouseholdService: + def test_get_household(self, mock_database): + """Test HouseholdService.get_household method.""" + service = HouseholdService() + + # Mock database response + mock_row = MagicMock(spec=LegacyRow) + mock_row.__getitem__.side_effect = lambda x: SAMPLE_DB_ROW[x] + mock_row.keys.return_value = SAMPLE_DB_ROW.keys() + mock_database.query().fetchone.return_value = mock_row + + result = service.get_household("us", 1) + + assert result is not None + assert result["household_json"] == SAMPLE_HOUSEHOLD_DATA["data"] + + def test_create_household(self, mock_database, mock_hash_object): + """Test HouseholdService.create_household method.""" + service = HouseholdService() + + # Mock database response for the ID query + mock_row = MagicMock(spec=LegacyRow) + mock_row.__getitem__.side_effect = lambda x: {"id": 1}[x] + mock_database.query().fetchone.return_value = mock_row + + household_id = service.create_household( + "us", SAMPLE_HOUSEHOLD_DATA["data"], SAMPLE_HOUSEHOLD_DATA["label"] + ) + + assert household_id == 1 + mock_database.query.assert_called() + + def test_update_household(self, mock_database, mock_hash_object): + """Test HouseholdService.update_household method.""" + service = HouseholdService() + mock_database.query().fetchone.return_value = SAMPLE_DB_ROW + + service.update_household( + "us", + 1, + SAMPLE_HOUSEHOLD_DATA["data"], + SAMPLE_HOUSEHOLD_DATA["label"], + ) + + assert mock_hash_object.called + mock_database.query.assert_any_call( + "UPDATE household SET household_json = ?, household_hash = ?, label = ?, api_version = ? WHERE id = ?", + ( + json.dumps(SAMPLE_HOUSEHOLD_DATA["data"]), + "some-hash", + SAMPLE_HOUSEHOLD_DATA["label"], + COUNTRY_PACKAGE_VERSIONS.get("us"), + 1, + ), + ) + + +class TestHouseholdRouteValidation: + """Test validation and error handling in household routes.""" + + @pytest.mark.parametrize( + "invalid_payload", + [ + {}, # Empty payload + {"label": "Test"}, # Missing data field + {"data": None}, # None data + {"data": "not_a_dict"}, # Non-dict data + {"data": {}, "label": 123}, # Invalid label type + ], + ) + def test_post_household_invalid_payload( + self, rest_client, invalid_payload + ): + """Test POST endpoint with various invalid payloads.""" + response = rest_client.post( + "/us/household", + json=invalid_payload, + content_type="application/json", + ) + + assert response.status_code == 400 + assert b"Unable to create new household" in response.data + + @pytest.mark.parametrize( + "invalid_id", + [ + "abc", # Non-numeric + "1.5", # Float + ], + ) + def test_get_household_invalid_id(self, rest_client, invalid_id): + """Test GET endpoint with invalid household IDs.""" + response = rest_client.get(f"/us/household/{invalid_id}") + + # Default Werkzeug validation returns 404, not 400 + assert response.status_code == 404 + assert ( + b"The requested URL was not found on the server" in response.data + ) + + @pytest.mark.parametrize( + "country_id", + [ + "123", # Numeric + "us!!", # Special characters + "zz", # Non-ISO + "a" * 100, # Too long + ], + ) + def test_invalid_country_id(self, rest_client, country_id): + """Test endpoints with invalid country IDs.""" + # Test GET + get_response = rest_client.get(f"/{country_id}/household/1") + assert get_response.status_code == 400 + + # Test POST + post_response = rest_client.post( + f"/{country_id}/household", + json={"data": {}}, + content_type="application/json", + ) + assert post_response.status_code == 400 + + # Test PUT + put_response = rest_client.put( + f"/{country_id}/household/1", + json={"data": {}}, + content_type="application/json", + ) + assert put_response.status_code == 400 + + +class TestHouseholdRouteServiceErrors: + """Test handling of service-level errors in routes.""" + + @patch( + "policyengine_api.services.household_service.HouseholdService.get_household" + ) + def test_get_household_service_error(self, mock_get, rest_client): + """Test GET endpoint when service raises an error.""" + mock_get.side_effect = Exception("Database connection failed") + + response = rest_client.get("/us/household/1") + data = json.loads(response.data) + + assert response.status_code == 500 + assert data["status"] == "error" + assert "Database connection failed" in data["message"] + + @patch( + "policyengine_api.services.household_service.HouseholdService.create_household" + ) + def test_post_household_service_error(self, mock_create, rest_client): + """Test POST endpoint when service raises an error.""" + mock_create.side_effect = Exception("Failed to create household") + + response = rest_client.post( + "/us/household", + json={"data": {"valid": "payload"}}, + content_type="application/json", + ) + data = json.loads(response.data) + + assert response.status_code == 500 + assert data["status"] == "error" + assert "Failed to create household" in data["message"] + + @patch( + "policyengine_api.services.household_service.HouseholdService.update_household" + ) + def test_put_household_service_error(self, mock_update, rest_client): + """Test PUT endpoint when service raises an error.""" + mock_update.side_effect = Exception("Failed to update household") + + # First mock the get_household call that checks existence + with patch( + "policyengine_api.services.household_service.HouseholdService.get_household" + ) as mock_get: + mock_get.return_value = {"id": 1} # Simulate existing household + + response = rest_client.put( + "/us/household/1", + json={"data": {"valid": "payload"}}, + content_type="application/json", + ) + data = json.loads(response.data) + + assert response.status_code == 500 + assert data["status"] == "error" + assert "Failed to update household" in data["message"] + + def test_missing_json_body(self, rest_client): + """Test endpoints when JSON body is missing.""" + # Test POST without JSON + post_response = rest_client.post("/us/household") + # Actually intercepted by server, which responds with 415, + # before we can even return a 400 + assert post_response.status_code in [400, 415] + + # Test PUT without JSON + put_response = rest_client.put("/us/household/1") + assert put_response.status_code in [400, 415] + + def test_malformed_json_body(self, rest_client): + """Test endpoints with malformed JSON body.""" + # Test POST with malformed JSON + post_response = rest_client.post( + "/us/household", + data="invalid json{", + content_type="application/json", + ) + assert post_response.status_code == 400 + + # Test PUT with malformed JSON + put_response = rest_client.put( + "/us/household/1", + data="invalid json{", + content_type="application/json", + ) + assert put_response.status_code == 400 diff --git a/tests/python/test_policy.py b/tests/python/test_policy.py index 9d674344..e79fa892 100644 --- a/tests/python/test_policy.py +++ b/tests/python/test_policy.py @@ -46,7 +46,7 @@ def test_create_nonunique_policy(self, rest_client): def test_create_policy_invalid_country(self, rest_client): res = rest_client.post("/au/policy", json=self.test_policy) - assert res.status_code == 404 + assert res.status_code == 400 class TestPolicySearch: diff --git a/tests/python/test_simulation_analysis.py b/tests/python/test_simulation_analysis.py index 9a3292c6..29b7dc5c 100644 --- a/tests/python/test_simulation_analysis.py +++ b/tests/python/test_simulation_analysis.py @@ -14,61 +14,56 @@ test_service = SimulationAnalysisService() -@pytest.fixture -def app(): - app = Flask(__name__) - app.config["TESTING"] = True - return app +def test_execute_simulation_analysis_existing_analysis(rest_client): + with patch( + "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" + ) as mock_get_existing: + mock_get_existing.return_value = (s for s in ["Existing analysis"]) -def test_execute_simulation_analysis_existing_analysis(app, rest_client): + response = rest_client.post("/us/simulation-analysis", json=test_json) - with app.test_request_context(json=test_json): - with patch( - "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" - ) as mock_get_existing: - mock_get_existing.return_value = (s for s in ["Existing analysis"]) + assert response.status_code == 200 + assert b"Existing analysis" in response.data - response = execute_simulation_analysis("us") - assert response.status_code == 200 - assert b"Existing analysis" in response.data - - -def test_execute_simulation_analysis_new_analysis(app, rest_client): - with app.test_request_context(json=test_json): +def test_execute_simulation_analysis_new_analysis(rest_client): + with patch( + "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" + ) as mock_get_existing: + mock_get_existing.return_value = None with patch( - "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" - ) as mock_get_existing: - mock_get_existing.return_value = None - with patch( - "policyengine_api.services.simulation_analysis_service.AIAnalysisService.trigger_ai_analysis" - ) as mock_trigger: - mock_trigger.return_value = (s for s in ["New analysis"]) + "policyengine_api.services.simulation_analysis_service.AIAnalysisService.trigger_ai_analysis" + ) as mock_trigger: + mock_trigger.return_value = (s for s in ["New analysis"]) - response = execute_simulation_analysis("us") + response = rest_client.post( + "/us/simulation-analysis", json=test_json + ) - assert response.status_code == 200 - assert b"New analysis" in response.data + assert response.status_code == 200 + assert b"New analysis" in response.data -def test_execute_simulation_analysis_error(app, rest_client): - with app.test_request_context(json=test_json): +def test_execute_simulation_analysis_error(rest_client): + with patch( + "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" + ) as mock_get_existing: + mock_get_existing.return_value = None with patch( - "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" - ) as mock_get_existing: - mock_get_existing.return_value = None - with patch( - "policyengine_api.services.ai_analysis_service.AIAnalysisService.trigger_ai_analysis" - ) as mock_trigger: - mock_trigger.side_effect = Exception("Test error") + "policyengine_api.services.ai_analysis_service.AIAnalysisService.trigger_ai_analysis" + ) as mock_trigger: + mock_trigger.side_effect = Exception("Test error") - response = execute_simulation_analysis("us") + response = rest_client.post( + "/us/simulation-analysis", json=test_json + ) - assert response.status_code == 500 + assert response.status_code == 500 + assert b"Test error" in response.data -def test_execute_simulation_analysis_enhanced_cps(app, rest_client): +def test_execute_simulation_analysis_enhanced_cps(rest_client): policy_details = dict(policy_json="policy details") test_json_enhanced_us = { @@ -86,35 +81,36 @@ def test_execute_simulation_analysis_enhanced_cps(app, rest_client): ], "audience": "Normal", } - with app.test_request_context(json=test_json_enhanced_us): + with patch( + "policyengine_api.services.simulation_analysis_service.SimulationAnalysisService._generate_simulation_analysis_prompt" + ) as mock_generate_prompt: with patch( - "policyengine_api.services.simulation_analysis_service.SimulationAnalysisService._generate_simulation_analysis_prompt" - ) as mock_generate_prompt: + "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" + ) as mock_get_existing: + mock_get_existing.return_value = None with patch( - "policyengine_api.services.ai_analysis_service.AIAnalysisService.get_existing_analysis" - ) as mock_get_existing: - mock_get_existing.return_value = None - with patch( - "policyengine_api.services.ai_analysis_service.AIAnalysisService.trigger_ai_analysis" - ) as mock_trigger: - mock_trigger.return_value = ( - s for s in ["Enhanced CPS analysis"] - ) - - response = execute_simulation_analysis("us") - - assert response.status_code == 200 - assert b"Enhanced CPS analysis" in response.data - mock_generate_prompt.assert_called_once_with( - "2023", - "enhanced_us", - "USD", - policy_details, - test_impact, - ["param1", "param2"], - [{"param1": 100}, {"param2": 200}], - True, - "2023", - "us", - "Test Policy", - ) + "policyengine_api.services.ai_analysis_service.AIAnalysisService.trigger_ai_analysis" + ) as mock_trigger: + mock_trigger.return_value = ( + s for s in ["Enhanced CPS analysis"] + ) + + response = rest_client.post( + "/us/simulation-analysis", json=test_json_enhanced_us + ) + + assert response.status_code == 200 + assert b"Enhanced CPS analysis" in response.data + mock_generate_prompt.assert_called_once_with( + "2023", + "enhanced_us", + "USD", + policy_details, + test_impact, + ["param1", "param2"], + [{"param1": 100}, {"param2": 200}], + True, + "2023", + "us", + "Test Policy", + ) diff --git a/tests/python/test_tracer.py b/tests/python/test_tracer.py index 2e08bbaf..f5e81cf2 100644 --- a/tests/python/test_tracer.py +++ b/tests/python/test_tracer.py @@ -2,9 +2,6 @@ from flask import Flask, json from unittest.mock import patch -from policyengine_api.routes.tracer_analysis_routes import ( - execute_tracer_analysis, -) from policyengine_api.services.tracer_analysis_service import ( TracerAnalysisService, ) @@ -12,13 +9,6 @@ test_service = TracerAnalysisService() -@pytest.fixture -def app(): - app = Flask(__name__) - app.config["TESTING"] = True - return app - - # Test cases for parse_tracer_output function def test_parse_tracer_output(): @@ -51,7 +41,7 @@ def test_parse_tracer_output(): "policyengine_api.services.tracer_analysis_service.TracerAnalysisService.trigger_ai_analysis" ) def test_execute_tracer_analysis_success( - mock_trigger_ai_analysis, mock_db, app, rest_client + mock_trigger_ai_analysis, mock_db, rest_client ): mock_db.query.return_value.fetchone.return_value = { "tracer_output": json.dumps( @@ -64,33 +54,31 @@ def test_execute_tracer_analysis_success( # Set this to US current law test_policy_id = 2 - with app.test_request_context( - "/us/tracer_analysis", + response = rest_client.post( + "/us/tracer-analysis", json={ "household_id": test_household_id, "policy_id": test_policy_id, "variable": "disposable_income", }, - ): - response = execute_tracer_analysis("us") + ) assert response.status_code == 200 assert b"AI analysis result" in response.data @patch("policyengine_api.services.tracer_analysis_service.local_database") -def test_execute_tracer_analysis_no_tracer(mock_db, app, rest_client): +def test_execute_tracer_analysis_no_tracer(mock_db, rest_client): mock_db.query.return_value.fetchone.return_value = None - with app.test_request_context( - "/us/tracer_analysis", + response = rest_client.post( + "/us/tracer-analysis", json={ "household_id": "test_household", "policy_id": "test_policy", "variable": "disposable_income", }, - ): - response = execute_tracer_analysis("us") + ) assert response.status_code == 404 assert ( @@ -104,7 +92,7 @@ def test_execute_tracer_analysis_no_tracer(mock_db, app, rest_client): "policyengine_api.services.tracer_analysis_service.TracerAnalysisService.trigger_ai_analysis" ) def test_execute_tracer_analysis_ai_error( - mock_trigger_ai_analysis, mock_db, app, rest_client + mock_trigger_ai_analysis, mock_db, rest_client ): mock_db.query.return_value.fetchone.return_value = { "tracer_output": json.dumps( @@ -114,22 +102,20 @@ def test_execute_tracer_analysis_ai_error( mock_trigger_ai_analysis.side_effect = Exception(KeyError) test_household_id = 1500 - - # Set this to US current law test_policy_id = 2 - with app.test_request_context( - "/us/tracer_analysis", + # Use the test client to make the request instead of calling the function directly + response = rest_client.post( + "/us/tracer-analysis", json={ "household_id": test_household_id, "policy_id": test_policy_id, "variable": "disposable_income", }, - ): - response = execute_tracer_analysis("us") + ) assert response.status_code == 500 - assert "An error occurred" in json.loads(response.data)["message"] + assert json.loads(response.data)["status"] == "error" # Test invalid country @@ -142,5 +128,5 @@ def test_invalid_country(rest_client): "variable": "disposable_income", }, ) - assert response.status_code == 404 + assert response.status_code == 400 assert b"Country invalid_country not found" in response.data diff --git a/tests/python/test_units.py b/tests/python/test_units.py index 341a5094..53d82a40 100644 --- a/tests/python/test_units.py +++ b/tests/python/test_units.py @@ -1,8 +1,10 @@ -from policyengine_api.routes.metadata_routes import get_metadata +from policyengine_api.services.metadata_service import MetadataService + +metadata_service = MetadataService() def test_units(): - m = get_metadata("us") + m = metadata_service.get_metadata("us") assert ( m["result"]["parameters"][ "gov.states.md.tax.income.rates.head[0].rate" diff --git a/tests/python/test_user_profile.py b/tests/python/test_user_profile.py index 72b1c48d..52c29f09 100644 --- a/tests/python/test_user_profile.py +++ b/tests/python/test_user_profile.py @@ -38,7 +38,6 @@ def test_set_and_get_record(self, rest_client): res = rest_client.get(f"/us/user-profile?auth0_id={self.auth0_id}") return_object = json.loads(res.text) - print(return_object) assert res.status_code == 200 assert return_object["status"] == "ok" @@ -100,6 +99,5 @@ def test_non_existent_record(self, rest_client): f"/us/user-profile?auth0_id={non_existent_auth0_id}" ) return_object = json.loads(res.text) - print(return_object) assert res.status_code == 404 diff --git a/tests/python/test_validate_country.py b/tests/python/test_validate_country.py index b162e4df..d927198a 100644 --- a/tests/python/test_validate_country.py +++ b/tests/python/test_validate_country.py @@ -24,4 +24,4 @@ def test_valid_country(self): def test_invalid_country(self): result = foo("baz", "extra_arg") assert isinstance(result, Response) - assert result.status_code == 404 + assert result.status_code == 400 diff --git a/tests/python/test_yearly_var_removal.py b/tests/python/test_yearly_var_removal.py index 72433248..dc84b990 100644 --- a/tests/python/test_yearly_var_removal.py +++ b/tests/python/test_yearly_var_removal.py @@ -2,12 +2,14 @@ import json from policyengine_api.endpoints.household import get_household_under_policy -from policyengine_api.routes.metadata_routes import get_metadata +from policyengine_api.services.metadata_service import MetadataService from policyengine_api.endpoints.policy import get_policy from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS from policyengine_api.data import database from policyengine_api.api import app +metadata_service = MetadataService() + @pytest.fixture def client(): @@ -107,7 +109,7 @@ def interface_test_household_under_policy( is_test_passing = True # Fetch live country metadata - metadata = get_metadata(country_id)["result"] + metadata = metadata_service.get_metadata(country_id)["result"] # Create the test household on the local db instance create_test_household(TEST_HOUSEHOLD_ID, country_id) @@ -244,7 +246,7 @@ def test_get_calculate(client): excluded_vars = ["members"] # Fetch live country metadata - metadata = get_metadata(COUNTRY_ID)["result"] + metadata = metadata_service.get_metadata(COUNTRY_ID)["result"] with open( f"./tests/python/data/us_household.json", "r", encoding="utf-8"