From f195d64372c1cb1f147cb590a1d89a92c3ed8326 Mon Sep 17 00:00:00 2001 From: Uchechukwu Orji Date: Fri, 21 Jun 2024 00:02:42 +0100 Subject: [PATCH] split backend settings, use fully qualified import names --- backend/pyproject.toml | 4 +- .../src/mirrors_qa_backend/cryptography.py | 33 ++------- backend/src/mirrors_qa_backend/db/__init__.py | 7 +- backend/src/mirrors_qa_backend/db/country.py | 14 ++-- backend/src/mirrors_qa_backend/db/mirrors.py | 2 +- backend/src/mirrors_qa_backend/db/models.py | 10 ++- backend/src/mirrors_qa_backend/db/tests.py | 22 +++--- backend/src/mirrors_qa_backend/db/worker.py | 70 ++++++++----------- backend/src/mirrors_qa_backend/entrypoint.py | 9 +-- backend/src/mirrors_qa_backend/enums.py | 2 +- backend/src/mirrors_qa_backend/exceptions.py | 6 ++ backend/src/mirrors_qa_backend/extract.py | 4 +- backend/src/mirrors_qa_backend/main.py | 6 +- .../88e49e681048_add_country_code_to_tests.py | 40 +++++++++++ backend/src/mirrors_qa_backend/routes/auth.py | 45 ++++++------ .../mirrors_qa_backend/routes/dependencies.py | 26 +++---- .../src/mirrors_qa_backend/routes/tests.py | 34 +++++---- backend/src/mirrors_qa_backend/scheduler.py | 36 +++++----- backend/src/mirrors_qa_backend/schemas.py | 2 +- backend/src/mirrors_qa_backend/serializer.py | 2 +- .../{settings.py => settings/__init__.py} | 21 ++---- .../src/mirrors_qa_backend/settings/api.py | 11 +++ .../mirrors_qa_backend/settings/scheduler.py | 12 ++++ backend/src/mirrors_qa_backend/tokens.py | 17 +++++ backend/tests/conftest.py | 69 ++++++++++++------ backend/tests/db/test_mirrors.py | 28 ++++---- backend/tests/db/test_tests.py | 41 ++++++----- backend/tests/db/test_worker.py | 24 ++++--- backend/tests/routes/test_auth_endpoints.py | 6 +- dev/docker-compose.yaml | 1 - 30 files changed, 359 insertions(+), 245 deletions(-) create mode 100644 backend/src/mirrors_qa_backend/migrations/versions/88e49e681048_add_country_code_to_tests.py rename backend/src/mirrors_qa_backend/{settings.py => settings/__init__.py} (50%) create mode 100644 backend/src/mirrors_qa_backend/settings/api.py create mode 100644 backend/src/mirrors_qa_backend/settings/scheduler.py create mode 100644 backend/src/mirrors_qa_backend/tokens.py diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 9a8405a..2329e21 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -190,6 +190,8 @@ ignore = [ "S603", # Ignore complexity "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + # Ignore warnings on missing timezone info + "DTZ005", "DTZ001", "DTZ006", ] unfixable = [ # Don't touch unused imports @@ -216,7 +218,7 @@ testpaths = ["tests"] pythonpath = [".", "src"] addopts = "--strict-markers" markers = [ - "num_tests(num=10, *, status=..., country=...): create num tests in the database using status and/or country. Random data is chosen for country or status if either is not set", + "num_tests(num=10, *, status=..., country_code=...): create num tests in the database using status and/or country_code. Random data is chosen for country_code or status if either is not set", ] [tool.coverage.paths] diff --git a/backend/src/mirrors_qa_backend/cryptography.py b/backend/src/mirrors_qa_backend/cryptography.py index 3dee16d..dd1f161 100644 --- a/backend/src/mirrors_qa_backend/cryptography.py +++ b/backend/src/mirrors_qa_backend/cryptography.py @@ -1,15 +1,13 @@ # pyright: strict, reportGeneralTypeIssues=false -import datetime +from pathlib import Path -import jwt import paramiko from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import padding, rsa +from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey from mirrors_qa_backend.exceptions import PEMPublicKeyLoadError -from mirrors_qa_backend.settings import Settings def verify_signed_message(public_key: bytes, signature: bytes, message: bytes) -> bool: @@ -45,16 +43,11 @@ def sign_message(private_key: RSAPrivateKey, message: bytes) -> bytes: ) -def generate_private_key(key_size: int = 2048) -> RSAPrivateKey: - return rsa.generate_private_key(public_exponent=65537, key_size=key_size) - - -def serialize_private_key(private_key: RSAPrivateKey) -> bytes: - return private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) +def load_private_key_from_path(private_key_fpath: Path) -> RSAPrivateKey: + with private_key_fpath.open("rb") as key_file: + return serialization.load_pem_private_key( + key_file.read(), password=None + ) # pyright: ignore[reportReturnType] def generate_public_key(private_key: RSAPrivateKey) -> RSAPublicKey: @@ -73,15 +66,3 @@ def get_public_key_fingerprint(public_key: RSAPublicKey) -> str: return paramiko.RSAKey( key=public_key ).fingerprint # pyright: ignore[reportUnknownMemberType, UnknownVariableType] - - -def generate_access_token(worker_id: str) -> str: - issue_time = datetime.datetime.now(datetime.UTC) - expire_time = issue_time + datetime.timedelta(hours=Settings.TOKEN_EXPIRY) - payload = { - "iss": "mirrors-qa-backend", # issuer - "exp": expire_time.timestamp(), # expiration time - "iat": issue_time.timestamp(), # issued at - "subject": worker_id, - } - return jwt.encode(payload, key=Settings.JWT_SECRET, algorithm="HS256") diff --git a/backend/src/mirrors_qa_backend/db/__init__.py b/backend/src/mirrors_qa_backend/db/__init__.py index 4336f88..688f3ae 100644 --- a/backend/src/mirrors_qa_backend/db/__init__.py +++ b/backend/src/mirrors_qa_backend/db/__init__.py @@ -7,7 +7,8 @@ from sqlalchemy.orm import sessionmaker from mirrors_qa_backend import logger -from mirrors_qa_backend.db import mirrors, models +from mirrors_qa_backend.db import models +from mirrors_qa_backend.db.mirrors import create_or_update_mirror_status from mirrors_qa_backend.extract import get_current_mirrors from mirrors_qa_backend.settings import Settings @@ -46,14 +47,14 @@ def initialize_mirrors() -> None: if not current_mirrors: logger.info(f"No mirrors were found on {Settings.MIRRORS_URL!r}") return - result = mirrors.create_or_update_status(session, current_mirrors) + result = create_or_update_mirror_status(session, current_mirrors) logger.info( f"Registered {result.nb_mirrors_added} mirrors " f"from {Settings.MIRRORS_URL!r}" ) else: logger.info(f"Found {nb_mirrors} mirrors in database.") - result = mirrors.create_or_update_status(session, current_mirrors) + result = create_or_update_mirror_status(session, current_mirrors) logger.info( f"Added {result.nb_mirrors_added} mirrors. " f"Disabled {result.nb_mirrors_disabled} mirrors." diff --git a/backend/src/mirrors_qa_backend/db/country.py b/backend/src/mirrors_qa_backend/db/country.py index e4ecc25..a3c173c 100644 --- a/backend/src/mirrors_qa_backend/db/country.py +++ b/backend/src/mirrors_qa_backend/db/country.py @@ -1,12 +1,16 @@ from sqlalchemy import select from sqlalchemy.orm import Session as OrmSession -from mirrors_qa_backend.db import models +from mirrors_qa_backend.db.models import Country -def get_countries_by_name(session: OrmSession, *countries: str) -> list[models.Country]: +def get_countries(session: OrmSession, *country_codes: str) -> list[Country]: return list( - session.scalars( - select(models.Country).where(models.Country.name.in_(countries)) - ).all() + session.scalars(select(Country).where(Country.code.in_(country_codes))).all() ) + + +def get_country_or_none(session: OrmSession, country_code: str) -> Country | None: + return session.scalars( + select(Country).where(Country.code == country_code) + ).one_or_none() diff --git a/backend/src/mirrors_qa_backend/db/mirrors.py b/backend/src/mirrors_qa_backend/db/mirrors.py index 8e07777..c38e2c3 100644 --- a/backend/src/mirrors_qa_backend/db/mirrors.py +++ b/backend/src/mirrors_qa_backend/db/mirrors.py @@ -56,7 +56,7 @@ def create_mirrors(session: OrmSession, mirrors: list[schemas.Mirror]) -> int: return nb_created -def create_or_update_status( +def create_or_update_mirror_status( session: OrmSession, mirrors: list[schemas.Mirror] ) -> MirrorsUpdateResult: """Updates the status of mirrors in the database and creates any new mirrors. diff --git a/backend/src/mirrors_qa_backend/db/models.py b/backend/src/mirrors_qa_backend/db/models.py index 3ad3a35..125095e 100644 --- a/backend/src/mirrors_qa_backend/db/models.py +++ b/backend/src/mirrors_qa_backend/db/models.py @@ -64,6 +64,8 @@ class Country(Base): cascade="all, delete-orphan", ) + tests: Mapped[list[Test]] = relationship(back_populates="country", init=False) + __table_args__ = (UniqueConstraint("name", "code"),) @@ -131,7 +133,11 @@ class Test(Base): ip_address: Mapped[IPv4Address | None] = mapped_column(default=None) # autonomous system based on IP asn: Mapped[str | None] = mapped_column(default=None) - country: Mapped[str | None] = mapped_column(default=None) # country based on IP + country_code: Mapped[str | None] = mapped_column( + ForeignKey("country.code"), + init=False, + default=None, + ) location: Mapped[str | None] = mapped_column(default=None) # city based on IP latency: Mapped[int | None] = mapped_column(default=None) # milliseconds download_size: Mapped[int | None] = mapped_column(default=None) # bytes @@ -142,3 +148,5 @@ class Test(Base): ) worker: Mapped[Worker | None] = relationship(back_populates="tests", init=False) + + country: Mapped[Country | None] = relationship(back_populates="tests", init=False) diff --git a/backend/src/mirrors_qa_backend/db/tests.py b/backend/src/mirrors_qa_backend/db/tests.py index c14587e..e8e9013 100644 --- a/backend/src/mirrors_qa_backend/db/tests.py +++ b/backend/src/mirrors_qa_backend/db/tests.py @@ -1,4 +1,3 @@ -# ruff: noqa: DTZ005, DTZ001 import datetime from dataclasses import dataclass from ipaddress import IPv4Address @@ -8,6 +7,7 @@ from sqlalchemy.orm import Session as OrmSession from mirrors_qa_backend.db import models +from mirrors_qa_backend.db.country import get_country_or_none from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError from mirrors_qa_backend.enums import SortDirectionEnum, StatusEnum, TestSortColumnEnum from mirrors_qa_backend.settings import Settings @@ -25,7 +25,7 @@ def filter_test( test: models.Test, *, worker_id: str | None = None, - country: str | None = None, + country_code: str | None = None, statuses: list[StatusEnum] | None = None, ) -> bool: """Checks if a test has the same attribute as the provided attribute. @@ -35,7 +35,7 @@ def filter_test( """ if worker_id is not None and test.worker_id != worker_id: return False - if country is not None and test.country != country: + if country_code is not None and test.country_code != country_code: return False if statuses is not None and test.status not in statuses: return False @@ -52,7 +52,7 @@ def list_tests( session: OrmSession, *, worker_id: str | None = None, - country: str | None = None, + country_code: str | None = None, statuses: list[StatusEnum] | None = None, page_num: int = 1, page_size: int = Settings.MAX_PAGE_SIZE, @@ -88,7 +88,7 @@ def list_tests( select(func.count().over().label("total_records"), models.Test) .where( (models.Test.worker_id == worker_id) | (worker_id is None), - (models.Test.country == country) | (country is None), + (models.Test.country_code == country_code) | (country_code is None), (models.Test.status.in_(statuses)), ) .order_by(*order_by) @@ -114,7 +114,7 @@ def create_or_update_test( error: str | None = None, ip_address: IPv4Address | None = None, asn: str | None = None, - country: str | None = None, + country_code: str | None = None, location: str | None = None, latency: int | None = None, download_size: int | None = None, @@ -136,7 +136,9 @@ def create_or_update_test( test.error = error if error else test.error test.ip_address = ip_address if ip_address else test.ip_address test.asn = asn if asn else test.asn - test.country = country if country else test.country + test.country = ( + get_country_or_none(session, country_code) if country_code else test.country + ) test.location = location if location else test.location test.latency = latency if latency else test.latency test.download_size = download_size if download_size else test.download_size @@ -158,7 +160,7 @@ def create_test( error: str | None = None, ip_address: IPv4Address | None = None, asn: str | None = None, - country: str | None = None, + country_code: str | None = None, location: str | None = None, latency: int | None = None, download_size: int | None = None, @@ -174,7 +176,7 @@ def create_test( error=error, ip_address=ip_address, asn=asn, - country=country, + country_code=country_code, location=location, latency=latency, download_size=download_size, @@ -189,7 +191,7 @@ def expire_tests( ) -> list[models.Test]: """Change the status of PENDING tests created before the interval to MISSED""" end = datetime.datetime.now() - interval - begin = datetime.datetime(1970, 1, 1) + begin = datetime.datetime.fromtimestamp(0) return list( session.scalars( update(models.Test) diff --git a/backend/src/mirrors_qa_backend/db/worker.py b/backend/src/mirrors_qa_backend/db/worker.py index e5f5fad..acc0f37 100644 --- a/backend/src/mirrors_qa_backend/db/worker.py +++ b/backend/src/mirrors_qa_backend/db/worker.py @@ -1,55 +1,51 @@ -# ruff: noqa: DTZ005, DTZ001 import datetime from pathlib import Path from sqlalchemy import select from sqlalchemy.orm import Session as OrmSession -from mirrors_qa_backend import cryptography -from mirrors_qa_backend.db import country, models +from mirrors_qa_backend.cryptography import ( + generate_public_key, + get_public_key_fingerprint, + load_private_key_from_path, + serialize_public_key, +) +from mirrors_qa_backend.db.country import get_countries from mirrors_qa_backend.db.exceptions import DuplicatePrimaryKeyError +from mirrors_qa_backend.db.models import Worker +from mirrors_qa_backend.exceptions import PEMPrivateKeyLoadError -def get_worker(session: OrmSession, worker_id: str) -> models.Worker | None: - return session.scalars( - select(models.Worker).where(models.Worker.id == worker_id) - ).one_or_none() +def get_worker(session: OrmSession, worker_id: str) -> Worker | None: + return session.scalars(select(Worker).where(Worker.id == worker_id)).one_or_none() def create_worker( session: OrmSession, worker_id: str, - countries: list[str], - private_key_filename: str | Path | None = None, -) -> models.Worker: - """Creates a worker and writes private key contents to private_key_filename. - - If no private_key_filename is provided, defaults to {worker_id}.pem. - """ + country_codes: list[str], + private_key_fpath: Path, +) -> Worker: + """Creates a worker using RSA private key.""" if get_worker(session, worker_id) is not None: raise DuplicatePrimaryKeyError( f"A worker with id {worker_id!r} already exists." ) - - if private_key_filename is None: - private_key_filename = f"{worker_id}.pem" - - private_key = cryptography.generate_private_key() - public_key = cryptography.generate_public_key(private_key) - public_key_pkcs8 = cryptography.serialize_public_key(public_key).decode( - encoding="ascii" - ) - with open(private_key_filename, "wb") as fp: - fp.write(cryptography.serialize_private_key(private_key)) - - worker = models.Worker( + try: + private_key = load_private_key_from_path(private_key_fpath) + except Exception as exc: + raise PEMPrivateKeyLoadError("unable to load private key from file") from exc + + public_key = generate_public_key(private_key) + public_key_pkcs8 = serialize_public_key(public_key).decode(encoding="ascii") + worker = Worker( id=worker_id, pubkey_pkcs8=public_key_pkcs8, - pubkey_fingerprint=cryptography.get_public_key_fingerprint(public_key), + pubkey_fingerprint=get_public_key_fingerprint(public_key), ) session.add(worker) - for db_country in country.get_countries_by_name(session, *countries): + for db_country in get_countries(session, *country_codes): db_country.worker_id = worker_id session.add(db_country) @@ -58,20 +54,18 @@ def create_worker( def get_workers_last_seen_in_range( session: OrmSession, begin: datetime.datetime, end: datetime.datetime -) -> list[models.Worker]: +) -> list[Worker]: """Get workers whose last_seen_on falls between begin and end dates""" return list( session.scalars( - select(models.Worker).where( - models.Worker.last_seen_on.between(begin, end), + select(Worker).where( + Worker.last_seen_on.between(begin, end), ) ).all() ) -def get_idle_workers( - session: OrmSession, interval: datetime.timedelta -) -> list[models.Worker]: +def get_idle_workers(session: OrmSession, interval: datetime.timedelta) -> list[Worker]: end = datetime.datetime.now() - interval begin = datetime.datetime(1970, 1, 1) return get_workers_last_seen_in_range(session, begin, end) @@ -79,15 +73,13 @@ def get_idle_workers( def get_active_workers( session: OrmSession, interval: datetime.timedelta -) -> list[models.Worker]: +) -> list[Worker]: end = datetime.datetime.now() begin = end - interval return get_workers_last_seen_in_range(session, begin, end) -def update_worker_last_seen( - session: OrmSession, worker: models.Worker -) -> models.Worker: +def update_worker_last_seen(session: OrmSession, worker: Worker) -> Worker: worker.last_seen_on = datetime.datetime.now() session.add(worker) return worker diff --git a/backend/src/mirrors_qa_backend/entrypoint.py b/backend/src/mirrors_qa_backend/entrypoint.py index c8699c7..69f314e 100644 --- a/backend/src/mirrors_qa_backend/entrypoint.py +++ b/backend/src/mirrors_qa_backend/entrypoint.py @@ -1,8 +1,9 @@ import argparse import logging -from mirrors_qa_backend import db, logger -from mirrors_qa_backend.db import mirrors +from mirrors_qa_backend import logger +from mirrors_qa_backend.db import Session +from mirrors_qa_backend.db.mirrors import create_or_update_mirror_status from mirrors_qa_backend.extract import get_current_mirrors @@ -17,8 +18,8 @@ def main(): if args.verbose: logger.setLevel(logging.DEBUG) - with db.Session.begin() as session: - mirrors.create_or_update_status(session, get_current_mirrors()) + with Session.begin() as session: + create_or_update_mirror_status(session, get_current_mirrors()) if __name__ == "__main__": diff --git a/backend/src/mirrors_qa_backend/enums.py b/backend/src/mirrors_qa_backend/enums.py index 7dab556..bc3369f 100644 --- a/backend/src/mirrors_qa_backend/enums.py +++ b/backend/src/mirrors_qa_backend/enums.py @@ -17,7 +17,7 @@ class TestSortColumnEnum(Enum): started_on = "started_on" status = "status" worker_id = "worker_id" - country = "country" + country_code = "country_code" city = "city" diff --git a/backend/src/mirrors_qa_backend/exceptions.py b/backend/src/mirrors_qa_backend/exceptions.py index 0c2f85c..0398309 100644 --- a/backend/src/mirrors_qa_backend/exceptions.py +++ b/backend/src/mirrors_qa_backend/exceptions.py @@ -17,3 +17,9 @@ class PEMPublicKeyLoadError(Exception): """Unable to deserialize a public key from PEM encoded data""" pass + + +class PEMPrivateKeyLoadError(Exception): + """Unable to deserialize a private key from PEM encoded data""" + + pass diff --git a/backend/src/mirrors_qa_backend/extract.py b/backend/src/mirrors_qa_backend/extract.py index f24ba69..f5ab4bf 100644 --- a/backend/src/mirrors_qa_backend/extract.py +++ b/backend/src/mirrors_qa_backend/extract.py @@ -28,7 +28,9 @@ def is_country_row(tag: Tag) -> bool: return tag.name == "tr" and tag.findChild("td", class_="newregion") is None try: - resp = requests.get(Settings.MIRRORS_URL, timeout=Settings.REQUESTS_TIMEOUT) + resp = requests.get( + Settings.MIRRORS_URL, timeout=Settings.REQUESTS_TIMEOUT_SECONDS + ) resp.raise_for_status() except requests.RequestException as exc: raise MirrorsRequestError( diff --git a/backend/src/mirrors_qa_backend/main.py b/backend/src/mirrors_qa_backend/main.py index df17118..06933e7 100644 --- a/backend/src/mirrors_qa_backend/main.py +++ b/backend/src/mirrors_qa_backend/main.py @@ -2,14 +2,14 @@ from fastapi import FastAPI -from mirrors_qa_backend import db +from mirrors_qa_backend.db import initialize_mirrors, upgrade_db_schema from mirrors_qa_backend.routes import auth, tests @asynccontextmanager async def lifespan(_: FastAPI): - db.upgrade_db_schema() - db.initialize_mirrors() + upgrade_db_schema() + initialize_mirrors() yield diff --git a/backend/src/mirrors_qa_backend/migrations/versions/88e49e681048_add_country_code_to_tests.py b/backend/src/mirrors_qa_backend/migrations/versions/88e49e681048_add_country_code_to_tests.py new file mode 100644 index 0000000..7175a56 --- /dev/null +++ b/backend/src/mirrors_qa_backend/migrations/versions/88e49e681048_add_country_code_to_tests.py @@ -0,0 +1,40 @@ +"""add country code to tests + +Revision ID: 88e49e681048 +Revises: 5c376f6fb191 +Create Date: 2024-06-20 21:43:32.830017 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "88e49e681048" +down_revision = "5c376f6fb191" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("test", sa.Column("country_code", sa.String(), nullable=True)) + op.create_foreign_key( + op.f("fk_test_country_code_country"), + "test", + "country", + ["country_code"], + ["code"], + ) + op.drop_column("test", "country") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "test", sa.Column("country", sa.VARCHAR(), autoincrement=False, nullable=True) + ) + op.drop_constraint(op.f("fk_test_country_code_country"), "test", type_="foreignkey") + op.drop_column("test", "country_code") + # ### end Alembic commands ### diff --git a/backend/src/mirrors_qa_backend/routes/auth.py b/backend/src/mirrors_qa_backend/routes/auth.py index 59ace0e..1463f78 100644 --- a/backend/src/mirrors_qa_backend/routes/auth.py +++ b/backend/src/mirrors_qa_backend/routes/auth.py @@ -5,12 +5,19 @@ from fastapi import APIRouter, Header -from mirrors_qa_backend import cryptography, logger, schemas -from mirrors_qa_backend.db import worker +from mirrors_qa_backend import logger +from mirrors_qa_backend.cryptography import verify_signed_message +from mirrors_qa_backend.db.worker import get_worker from mirrors_qa_backend.exceptions import PEMPublicKeyLoadError -from mirrors_qa_backend.routes import http_errors from mirrors_qa_backend.routes.dependencies import DbSession -from mirrors_qa_backend.settings import Settings +from mirrors_qa_backend.routes.http_errors import ( + BadRequestError, + ForbiddenError, + UnauthorizedError, +) +from mirrors_qa_backend.schemas import Token +from mirrors_qa_backend.settings.api import APISettings +from mirrors_qa_backend.tokens import generate_access_token router = APIRouter(prefix="/auth", tags=["auth"]) @@ -25,52 +32,50 @@ def authenticate_worker( x_sshauth_signature: Annotated[ str, Header(description="signature, base64-encoded") ], -) -> schemas.Token: +) -> Token: """Authenticate using signed message and generate tokens.""" try: signature = base64.standard_b64decode(x_sshauth_signature) except binascii.Error as exc: - raise http_errors.BadRequestError( - "Invalid signature format (not base64)" - ) from exc + raise BadRequestError("Invalid signature format (not base64)") from exc try: # decode message: worker_id:timestamp(UTC ISO) worker_id, timestamp_str = x_sshauth_message.split(":", 1) timestamp = datetime.datetime.fromisoformat(timestamp_str) except ValueError as exc: - raise http_errors.BadRequestError("Invalid message format.") from exc + raise BadRequestError("Invalid message format.") from exc # verify timestamp is less than MESSAGE_VALIDITY if ( datetime.datetime.now(datetime.UTC) - timestamp - ).total_seconds() > Settings.MESSAGE_VALIDITY: - raise http_errors.UnauthorizedError( + ).total_seconds() > APISettings.MESSAGE_VALIDITY: + raise UnauthorizedError( "Difference betweeen message time and server time is " - f"greater than {Settings.MESSAGE_VALIDITY}s" + f"greater than {APISettings.MESSAGE_VALIDITY}s" ) # verify worker with worker_id exists in database - db_worker = worker.get_worker(session, worker_id) + db_worker = get_worker(session, worker_id) if db_worker is None: - raise http_errors.UnauthorizedError() + raise UnauthorizedError() # verify signature of message with worker's public keys try: - if not cryptography.verify_signed_message( + if not verify_signed_message( bytes(db_worker.pubkey_pkcs8, encoding="ascii"), signature, bytes(x_sshauth_message, encoding="ascii"), ): - raise http_errors.UnauthorizedError() + raise UnauthorizedError() except PEMPublicKeyLoadError as exc: logger.exception("error while verifying message using public key") - raise http_errors.ForbiddenError("Unable to load public_key") from exc + raise ForbiddenError("Unable to load public_key") from exc # generate tokens - access_token = cryptography.generate_access_token(worker_id) - return schemas.Token( + access_token = generate_access_token(worker_id) + return Token( access_token=access_token, token_type="bearer", - expires_in=datetime.timedelta(hours=Settings.TOKEN_EXPIRY).total_seconds(), + expires_in=datetime.timedelta(hours=APISettings.TOKEN_EXPIRY).total_seconds(), ) diff --git a/backend/src/mirrors_qa_backend/routes/dependencies.py b/backend/src/mirrors_qa_backend/routes/dependencies.py index 6d317aa..963c811 100644 --- a/backend/src/mirrors_qa_backend/routes/dependencies.py +++ b/backend/src/mirrors_qa_backend/routes/dependencies.py @@ -9,9 +9,11 @@ from sqlalchemy.orm import Session from mirrors_qa_backend import schemas -from mirrors_qa_backend.db import gen_dbsession, models, tests, worker -from mirrors_qa_backend.routes import http_errors -from mirrors_qa_backend.settings import Settings +from mirrors_qa_backend.db import gen_dbsession, models +from mirrors_qa_backend.db.tests import get_test as db_get_test +from mirrors_qa_backend.db.worker import get_worker +from mirrors_qa_backend.routes.http_errors import NotFoundError, UnauthorizedError +from mirrors_qa_backend.settings.api import APISettings DbSession = Annotated[Session, Depends(gen_dbsession)] @@ -25,22 +27,22 @@ def get_current_worker( ) -> models.Worker: token = authorization.credentials try: - jwt_claims = jwt.decode(token, Settings.JWT_SECRET, algorithms=["HS256"]) + jwt_claims = jwt.decode(token, APISettings.JWT_SECRET, algorithms=["HS256"]) except jwt_exceptions.ExpiredSignatureError as exc: - raise http_errors.UnauthorizedError("Token has expired.") from exc + raise UnauthorizedError("Token has expired.") from exc except (jwt_exceptions.InvalidTokenError, jwt_exceptions.PyJWTError) as exc: - raise http_errors.UnauthorizedError from exc + raise UnauthorizedError from exc try: claims = schemas.JWTClaims(**jwt_claims) except PydanticValidationError as exc: - raise http_errors.UnauthorizedError from exc + raise UnauthorizedError from exc # At this point, we know that the JWT is all OK and we can # trust the data in it. We extract the worker_id from the claims - db_worker = worker.get_worker(session, claims.subject) + db_worker = get_worker(session, claims.subject) if db_worker is None: - raise http_errors.UnauthorizedError() + raise UnauthorizedError() return db_worker @@ -49,9 +51,9 @@ def get_current_worker( def get_test(session: DbSession, test_id: Annotated[UUID4, Path()]) -> models.Test: """Fetches the test specified in the request.""" - test = tests.get_test(session, test_id) + test = db_get_test(session, test_id) if test is None: - raise http_errors.NotFoundError(f"Test with id {test_id} does not exist.") + raise NotFoundError(f"Test with id {test_id} does not exist.") return test @@ -60,4 +62,4 @@ def get_test(session: DbSession, test_id: Annotated[UUID4, Path()]) -> models.Te def verify_worker_owns_test(worker: CurrentWorker, test: RetrievedTest): if test.worker_id != worker.id: - raise http_errors.UnauthorizedError("Insufficient privileges to update test.") + raise UnauthorizedError("Insufficient privileges to update test.") diff --git a/backend/src/mirrors_qa_backend/routes/tests.py b/backend/src/mirrors_qa_backend/routes/tests.py index 3c54a58..a463368 100644 --- a/backend/src/mirrors_qa_backend/routes/tests.py +++ b/backend/src/mirrors_qa_backend/routes/tests.py @@ -3,8 +3,10 @@ from fastapi import APIRouter, Depends, Query from fastapi import status as status_codes -from mirrors_qa_backend import schemas, serializer -from mirrors_qa_backend.db import tests, worker +from mirrors_qa_backend import schemas +from mirrors_qa_backend.db.tests import create_or_update_test +from mirrors_qa_backend.db.tests import list_tests as db_list_tests +from mirrors_qa_backend.db.worker import update_worker_last_seen from mirrors_qa_backend.enums import SortDirectionEnum, StatusEnum, TestSortColumnEnum from mirrors_qa_backend.routes.dependencies import ( CurrentWorker, @@ -12,6 +14,8 @@ RetrievedTest, verify_worker_owns_test, ) +from mirrors_qa_backend.schemas import Test, TestsList, calculate_pagination_metadata +from mirrors_qa_backend.serializer import serialize_test from mirrors_qa_backend.settings import Settings router = APIRouter(prefix="/tests", tags=["tests"]) @@ -27,7 +31,7 @@ def list_tests( session: DbSession, worker_id: Annotated[str | None, Query()] = None, - country: Annotated[str | None, Query(min_length=3)] = None, + country_code: Annotated[str | None, Query(min_length=2, max_length=2)] = None, status: Annotated[list[StatusEnum] | None, Query()] = None, page_size: Annotated[ int, Query(le=Settings.MAX_PAGE_SIZE, ge=1) @@ -35,11 +39,11 @@ def list_tests( page_num: Annotated[int, Query(ge=1)] = 1, sort_by: Annotated[TestSortColumnEnum, Query()] = TestSortColumnEnum.requested_on, order: Annotated[SortDirectionEnum, Query()] = SortDirectionEnum.asc, -) -> schemas.TestsList: - result = tests.list_tests( +) -> TestsList: + result = db_list_tests( session, worker_id=worker_id, - country=country, + country_code=country_code, statuses=status, page_size=page_size, page_num=page_num, @@ -47,8 +51,8 @@ def list_tests( sort_direction=order, ) return schemas.TestsList( - tests=[serializer.serialize_test(test) for test in result.tests], - metadata=schemas.calculate_pagination_metadata( + tests=[serialize_test(test) for test in result.tests], + metadata=calculate_pagination_metadata( result.nb_tests, page_size=page_size, current_page=page_num ), ) @@ -64,8 +68,8 @@ def list_tests( }, }, ) -def get_test(test: RetrievedTest) -> schemas.Test: - return serializer.serialize_test(test) +def get_test(test: RetrievedTest) -> Test: + return serialize_test(test) @router.patch( @@ -81,10 +85,10 @@ def update_test( current_worker: CurrentWorker, test: RetrievedTest, update: schemas.UpdateTestModel, -) -> schemas.Test: +) -> Test: data = update.model_dump(exclude_unset=True) body = schemas.UpdateTestModel().model_copy(update=data) - updated_test = tests.create_or_update_test( + updated_test = create_or_update_test( session, test_id=test.id, worker_id=current_worker.id, @@ -92,12 +96,12 @@ def update_test( error=body.error, ip_address=body.ip_address, asn=body.asn, - country=body.country, + country_code=body.country_code, location=body.location, latency=body.latency, download_size=body.download_size, duration=body.duration, speed=body.speed, ) - worker.update_worker_last_seen(session, current_worker) - return serializer.serialize_test(updated_test) + update_worker_last_seen(session, current_worker) + return serialize_test(updated_test) diff --git a/backend/src/mirrors_qa_backend/scheduler.py b/backend/src/mirrors_qa_backend/scheduler.py index fbc527b..bc741a7 100644 --- a/backend/src/mirrors_qa_backend/scheduler.py +++ b/backend/src/mirrors_qa_backend/scheduler.py @@ -2,29 +2,31 @@ import time from mirrors_qa_backend import logger -from mirrors_qa_backend.db import Session, tests, worker +from mirrors_qa_backend.db import Session +from mirrors_qa_backend.db.tests import create_test, expire_tests, list_tests +from mirrors_qa_backend.db.worker import get_idle_workers from mirrors_qa_backend.enums import StatusEnum -from mirrors_qa_backend.settings import Settings +from mirrors_qa_backend.settings.scheduler import SchedulerSettings def main(): while True: with Session.begin() as session: - # expire tesst whose results have not been reported - expired_tests = tests.expire_tests( + # expire tests whose results have not been reported + expired_tests = expire_tests( session, - interval=datetime.timedelta(hours=Settings.EXPIRE_TEST_INTERVAL), + interval=datetime.timedelta(hours=SchedulerSettings.EXPIRE_TEST_HOURS), ) for expired_test in expired_tests: logger.info( f"Expired test {expired_test.id}, " - f"country: {expired_test.country}, " - f"worker: {expired_test.worker_id}" + f"country: {expired_test.country_code}, " + f"worker: {expired_test.worker_id!r}" ) - idle_workers = worker.get_idle_workers( + idle_workers = get_idle_workers( session, - interval=datetime.timedelta(hours=Settings.IDLE_WORKER_INTERVAL), + interval=datetime.timedelta(hours=SchedulerSettings.IDLE_WORKER_HOURS), ) if not idle_workers: logger.info("No idle workers found.") @@ -33,7 +35,7 @@ def main(): for idle_worker in idle_workers: if not idle_worker.countries: logger.info( - f"No countries registered for idle worker {idle_worker.id}" + f"No countries registered for idle worker {idle_worker.id!r}" ) continue for country in idle_worker.countries: @@ -41,25 +43,25 @@ def main(): # a test for a country might still be PENDING as the interval # for expiration and that of the scheduler might overlap. # In such scenarios, we skip creating a test for that country. - pending_tests = tests.list_tests( + pending_tests = list_tests( session, worker_id=idle_worker.id, statuses=[StatusEnum.PENDING], - country=country.name, + country_code=country.code, ) if pending_tests.nb_tests: logger.info( "Skipping creation of new test entries for " - f"{idle_worker.id} as {pending_tests.nb_tests} " - "tests are still pending." + f"{idle_worker.id!r} as {pending_tests.nb_tests} " + f"tests are still pending for country {country.name}" ) continue - new_test = tests.create_test( + new_test = create_test( session=session, worker_id=idle_worker.id, - country=country.name, + country_code=country.code, status=StatusEnum.PENDING, ) logger.info( @@ -68,7 +70,7 @@ def main(): ) sleep_interval = datetime.timedelta( - hours=Settings.SCHEDULER_SLEEP_INTERVAL + hours=SchedulerSettings.SCHEDULER_SLEEP_HOURS ).total_seconds() logger.info(f"Sleeping for {sleep_interval} seconds.") diff --git a/backend/src/mirrors_qa_backend/schemas.py b/backend/src/mirrors_qa_backend/schemas.py index 16d61cc..33583b5 100644 --- a/backend/src/mirrors_qa_backend/schemas.py +++ b/backend/src/mirrors_qa_backend/schemas.py @@ -39,7 +39,7 @@ class UpdateTestModel(BaseModel): isp: str | None = None ip_address: IPv4Address | None = None asn: str | None = None - country: str | None = None + country_code: str | None = None location: str | None = None latency: int | None = None download_size: int | None = None diff --git a/backend/src/mirrors_qa_backend/serializer.py b/backend/src/mirrors_qa_backend/serializer.py index c9b1b63..fd4d907 100644 --- a/backend/src/mirrors_qa_backend/serializer.py +++ b/backend/src/mirrors_qa_backend/serializer.py @@ -12,7 +12,7 @@ def serialize_test(test: models.Test) -> schemas.Test: isp=test.isp, ip_address=test.ip_address, asn=test.asn, - country=test.country, + country_code=test.country_code, location=test.location, latency=test.latency, download_size=test.download_size, diff --git a/backend/src/mirrors_qa_backend/settings.py b/backend/src/mirrors_qa_backend/settings/__init__.py similarity index 50% rename from backend/src/mirrors_qa_backend/settings.py rename to backend/src/mirrors_qa_backend/settings/__init__.py index fda5ba0..9e27132 100644 --- a/backend/src/mirrors_qa_backend/settings.py +++ b/backend/src/mirrors_qa_backend/settings/__init__.py @@ -15,6 +15,11 @@ class Settings: """Shared backend configuration""" DATABASE_URL: str = getenv("POSTGRES_URI", mandatory=True) + DEBUG = bool(getenv("DEBUG", default=False)) + # number of seconds before requests time out + REQUESTS_TIMEOUT_SECONDS = int(getenv("REQUESTS_TIMEOUT_SECONDS", default=5)) + # maximum number of items to return from a request/query + MAX_PAGE_SIZE = int(getenv("PAGE_SIZE", default=20)) # url to fetch the list of mirrors MIRRORS_URL: str = getenv( "MIRRORS_LIST_URL", default="https://download.kiwix.org/mirrors.html" @@ -23,19 +28,3 @@ class Settings: MIRRORS_EXCLUSION_LIST = getenv( "EXCLUDED_MIRRORS", default="mirror.isoc.org.il" ).split(",") - DEBUG = bool(getenv("DEBUG", default=False)) - # number of seconds before requests time out - REQUESTS_TIMEOUT = int(getenv("REQUESTS_TIMEOUT", default=5)) - # maximum number of items to return from a request - MAX_PAGE_SIZE = int(getenv("PAGE_SIZE", default=20)) - # number of seconds before a message expire - MESSAGE_VALIDITY = int(getenv("MESSAGE_VALIDITY", default=60)) - # number of hours before access tokens expire - TOKEN_EXPIRY = int(getenv("TOKEN_EXPIRY", default=24)) - JWT_SECRET: str = getenv("JWT_SECRET", mandatory=True) - # number of hours the scheduler sleeps before attempting to create tests - SCHEDULER_SLEEP_INTERVAL = int(getenv("SCHEDULER_SLEEP_INTERVAL", default=3)) - # number of hours into the past to determine if a worker is idle - IDLE_WORKER_INTERVAL = int(getenv("IDLE_WORKER_INTERVAL", default=1)) - # number of hours to wait before expiring a test whose data never arrived - EXPIRE_TEST_INTERVAL = int(getenv("EXPIRE_TEST_INTERVAL", default=24)) diff --git a/backend/src/mirrors_qa_backend/settings/api.py b/backend/src/mirrors_qa_backend/settings/api.py new file mode 100644 index 0000000..f61b76b --- /dev/null +++ b/backend/src/mirrors_qa_backend/settings/api.py @@ -0,0 +1,11 @@ +from mirrors_qa_backend.settings import Settings, getenv + + +class APISettings(Settings): + """Backend API settings""" + + JWT_SECRET: str = getenv("JWT_SECRET", mandatory=True) + # number of seconds before a message expire + MESSAGE_VALIDITY = int(getenv("MESSAGE_VALIDITY", default=60)) + # number of hours before access tokens expire + TOKEN_EXPIRY = int(getenv("TOKEN_EXPIRY", default=24)) diff --git a/backend/src/mirrors_qa_backend/settings/scheduler.py b/backend/src/mirrors_qa_backend/settings/scheduler.py new file mode 100644 index 0000000..afd39c7 --- /dev/null +++ b/backend/src/mirrors_qa_backend/settings/scheduler.py @@ -0,0 +1,12 @@ +from mirrors_qa_backend.settings import Settings, getenv + + +class SchedulerSettings(Settings): + """Scheduler settings""" + + # number of hours the scheduler sleeps before attempting to create tests + SCHEDULER_SLEEP_HOURS = int(getenv("SCHEDULER_SLEEP_INTERVAL", default=3)) + # number of hours into the past to determine if a worker is idle + IDLE_WORKER_HOURS = int(getenv("IDLE_WORKER_INTERVAL", default=1)) + # number of hours to wait before expiring a test whose data never arrived + EXPIRE_TEST_HOURS = int(getenv("EXPIRE_TEST_INTERVAL", default=24)) diff --git a/backend/src/mirrors_qa_backend/tokens.py b/backend/src/mirrors_qa_backend/tokens.py new file mode 100644 index 0000000..6ed6a79 --- /dev/null +++ b/backend/src/mirrors_qa_backend/tokens.py @@ -0,0 +1,17 @@ +import datetime + +import jwt + +from mirrors_qa_backend.settings.api import APISettings + + +def generate_access_token(worker_id: str) -> str: + issue_time = datetime.datetime.now(datetime.UTC) + expire_time = issue_time + datetime.timedelta(hours=APISettings.TOKEN_EXPIRY) + payload = { + "iss": "mirrors-qa-backend", # issuer + "exp": expire_time.timestamp(), # expiration time + "iat": issue_time.timestamp(), # issued at + "subject": worker_id, + } + return jwt.encode(payload, key=APISettings.JWT_SECRET, algorithm="HS256") diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 1b7d867..d9d2296 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -4,6 +4,7 @@ from typing import Any import paramiko +import pycountry import pytest from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa @@ -13,7 +14,9 @@ from sqlalchemy.orm import Session as OrmSession from mirrors_qa_backend.cryptography import sign_message -from mirrors_qa_backend.db import Session, models +from mirrors_qa_backend.db import Session +from mirrors_qa_backend.db.country import get_country_or_none +from mirrors_qa_backend.db.models import Base, Country, Test, Worker from mirrors_qa_backend.enums import StatusEnum @@ -22,8 +25,8 @@ def dbsession() -> Generator[OrmSession, None, None]: with Session.begin() as session: # Ensure we are starting with an empty database engine = session.get_bind() - models.Base.metadata.drop_all(bind=engine) - models.Base.metadata.create_all(bind=engine) + Base.metadata.drop_all(bind=engine) + Base.metadata.create_all(bind=engine) yield session session.rollback() @@ -32,9 +35,9 @@ def dbsession() -> Generator[OrmSession, None, None]: def data_gen(faker: Faker) -> Faker: """Adds additional providers to faker. - Registers test_country and test_status as providers. + Registers test_country_code and test_status as providers. data_gen.test_status() returns a status. - data_gen.test_country() returns a country. + data_gen.test_country_code() returns a country code. All other providers from Faker can be used accordingly. """ test_status_provider = DynamicProvider( @@ -42,11 +45,11 @@ def data_gen(faker: Faker) -> Faker: elements=list(StatusEnum), ) test_country_provider = DynamicProvider( - provider_name="test_country", + provider_name="test_country_code", elements=[ - "Nigeria", - "Canada", - "Brazil", + "ng", + "fr", + "us", ], ) faker.add_provider(test_status_provider) @@ -58,8 +61,8 @@ def data_gen(faker: Faker) -> Faker: @pytest.fixture def tests( - dbsession: OrmSession, data_gen: Faker, worker: models.Worker, request: Any -) -> list[models.Test]: + dbsession: OrmSession, data_gen: Faker, worker: Worker, request: Any +) -> list[Test]: """Adds tests to the database using the num_test mark.""" mark = request.node.get_closest_marker("num_tests") if mark and len(mark.args) > 0: @@ -68,17 +71,28 @@ def tests( num_tests = 10 status = mark.kwargs.get("status", None) - country = mark.kwargs.get("country", None) + country_code = mark.kwargs.get("country_code", None) - tests = [ - models.Test( - status=status if status else data_gen.test_status(), - country=country if country else data_gen.test_country(), + for _ in range(num_tests): + test = Test(status=status if status else data_gen.test_status()) + selected_country_code = ( + country_code if country_code else data_gen.test_country_code() ) - for _ in range(num_tests) - ] - worker.tests = tests - dbsession.add_all(tests) + if country := get_country_or_none(dbsession, selected_country_code): + test.country = country + else: + country = Country( + code=selected_country_code.lower(), + name=pycountry.countries.get( + alpha_2=selected_country_code + ).name, # pyright: ignore [reportOptionalMemberAccess] + ) + dbsession.add(country) + test.country = country + + test.worker = worker + dbsession.add(test) + dbsession.flush() return worker.tests @@ -94,14 +108,23 @@ def public_key(private_key: RSAPrivateKey) -> RSAPublicKey: return private_key.public_key() +@pytest.fixture(scope="session") +def private_key_bytes(private_key: RSAPrivateKey) -> bytes: + return private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + @pytest.fixture -def worker(public_key: RSAPublicKey, dbsession: OrmSession) -> models.Worker: +def worker(public_key: RSAPublicKey, dbsession: OrmSession) -> Worker: pubkey_pkcs8 = public_key.public_bytes( serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo, ).decode(encoding="ascii") - worker = models.Worker( + worker = Worker( id="test", pubkey_fingerprint=paramiko.RSAKey(key=public_key).fingerprint, # type: ignore pubkey_pkcs8=pubkey_pkcs8, @@ -111,7 +134,7 @@ def worker(public_key: RSAPublicKey, dbsession: OrmSession) -> models.Worker: @pytest.fixture -def auth_message(worker: models.Worker) -> str: +def auth_message(worker: Worker) -> str: return f"{worker.id}:{datetime.datetime.now(datetime.UTC).isoformat()}" diff --git a/backend/tests/db/test_mirrors.py b/backend/tests/db/test_mirrors.py index 11c2b09..62f138d 100644 --- a/backend/tests/db/test_mirrors.py +++ b/backend/tests/db/test_mirrors.py @@ -2,9 +2,11 @@ from sqlalchemy import select from sqlalchemy.orm import Session as OrmSession -from mirrors_qa_backend import db, schemas, serializer -from mirrors_qa_backend.db import mirrors, models +from mirrors_qa_backend import schemas +from mirrors_qa_backend.db import count_from_stmt, models from mirrors_qa_backend.db.exceptions import EmptyMirrorsError +from mirrors_qa_backend.db.mirrors import create_mirrors, create_or_update_mirror_status +from mirrors_qa_backend.serializer import serialize_mirror @pytest.fixture(scope="session") @@ -29,7 +31,7 @@ def db_mirror() -> models.Mirror: @pytest.fixture(scope="session") def schema_mirror(db_mirror: models.Mirror) -> schemas.Mirror: - return serializer.serialize_mirror(db_mirror) + return serialize_mirror(db_mirror) @pytest.fixture(scope="session") @@ -55,20 +57,20 @@ def new_schema_mirror() -> schemas.Mirror: def test_db_empty(dbsession: OrmSession): - assert db.count_from_stmt(dbsession, select(models.Country)) == 0 + assert count_from_stmt(dbsession, select(models.Country)) == 0 def test_create_no_mirrors(dbsession: OrmSession): - assert mirrors.create_mirrors(dbsession, []) == 0 + assert create_mirrors(dbsession, []) == 0 def test_create_mirrors(dbsession: OrmSession, schema_mirror: schemas.Mirror): - assert mirrors.create_mirrors(dbsession, [schema_mirror]) == 1 + assert create_mirrors(dbsession, [schema_mirror]) == 1 def test_raises_empty_mirrors_error(dbsession: OrmSession): with pytest.raises(EmptyMirrorsError): - mirrors.create_or_update_status(dbsession, []) + create_or_update_mirror_status(dbsession, []) def test_register_new_mirror( @@ -78,7 +80,7 @@ def test_register_new_mirror( new_schema_mirror: schemas.Mirror, ): dbsession.add(db_mirror) - result = mirrors.create_or_update_status( + result = create_or_update_mirror_status( dbsession, [schema_mirror, new_schema_mirror] ) assert result.nb_mirrors_added == 1 @@ -90,7 +92,7 @@ def test_disable_old_mirror( new_schema_mirror: schemas.Mirror, ): dbsession.add(db_mirror) - result = mirrors.create_or_update_status(dbsession, [new_schema_mirror]) + result = create_or_update_mirror_status(dbsession, [new_schema_mirror]) assert result.nb_mirrors_disabled == 1 @@ -98,7 +100,7 @@ def test_no_mirrors_disabled( dbsession: OrmSession, db_mirror: models.Mirror, schema_mirror: schemas.Mirror ): dbsession.add(db_mirror) - result = mirrors.create_or_update_status(dbsession, [schema_mirror]) + result = create_or_update_mirror_status(dbsession, [schema_mirror]) assert result.nb_mirrors_disabled == 0 @@ -106,7 +108,7 @@ def test_no_mirrors_added( dbsession: OrmSession, db_mirror: models.Mirror, schema_mirror: schemas.Mirror ): dbsession.add(db_mirror) - result = mirrors.create_or_update_status(dbsession, [schema_mirror]) + result = create_or_update_mirror_status(dbsession, [schema_mirror]) assert result.nb_mirrors_added == 0 @@ -132,8 +134,8 @@ def test_re_enable_existing_mirror( dbsession.add(db_mirror) # Update the status of the mirror - schema_mirror = serializer.serialize_mirror(db_mirror) + schema_mirror = serialize_mirror(db_mirror) schema_mirror.enabled = True - result = mirrors.create_or_update_status(dbsession, [schema_mirror]) + result = create_or_update_mirror_status(dbsession, [schema_mirror]) assert result.nb_mirrors_added == 1 diff --git a/backend/tests/db/test_tests.py b/backend/tests/db/test_tests.py index a7ed9e3..4adfbc5 100644 --- a/backend/tests/db/test_tests.py +++ b/backend/tests/db/test_tests.py @@ -6,20 +6,26 @@ from sqlalchemy.orm import Session as OrmSession from mirrors_qa_backend.db import models -from mirrors_qa_backend.db import tests as db_tests +from mirrors_qa_backend.db.tests import ( + create_or_update_test, + expire_tests, + filter_test, + get_test, + list_tests, +) from mirrors_qa_backend.enums import StatusEnum @pytest.mark.num_tests(1) def test_get_test(dbsession: OrmSession, tests: list[models.Test]): test = tests[0] - result = db_tests.get_test(dbsession, test.id) + result = get_test(dbsession, test.id) assert result is not None assert result.id == test.id @pytest.mark.parametrize( - ["worker_id", "country", "statuses", "expected"], + ["worker_id", "country_code", "statuses", "expected"], [ (None, None, None, True), ("worker_id", None, None, False), @@ -32,14 +38,14 @@ def test_basic_filter( *, dbsession: OrmSession, worker_id: str | None, - country: str | None, + country_code: str | None, statuses: list[StatusEnum] | None, expected: bool, ): - test = db_tests.create_or_update_test(dbsession, status=StatusEnum.PENDING) + test = create_or_update_test(dbsession, status=StatusEnum.PENDING) assert ( - db_tests.filter_test( - test, worker_id=worker_id, country=country, statuses=statuses + filter_test( + test, worker_id=worker_id, country_code=country_code, statuses=statuses ) == expected ) @@ -47,11 +53,11 @@ def test_basic_filter( @pytest.mark.num_tests @pytest.mark.parametrize( - ["worker_id", "country", "statuses"], + ["worker_id", "country_code", "statuses"], [ (None, None, None), - (None, "Nigeria", None), - (None, "Nigeria", [StatusEnum.PENDING]), + (None, "ng", None), + (None, "ng", [StatusEnum.PENDING]), (None, None, [StatusEnum.PENDING, StatusEnum.MISSED]), ], ) @@ -59,18 +65,18 @@ def test_list_tests( dbsession: OrmSession, tests: list[models.Test], worker_id: str | None, - country: str | None, + country_code: str | None, statuses: list[StatusEnum] | None, ): filtered_tests = [ test for test in tests - if db_tests.filter_test( - test, worker_id=worker_id, country=country, statuses=statuses + if filter_test( + test, worker_id=worker_id, country_code=country_code, statuses=statuses ) ] - result = db_tests.list_tests( - dbsession, worker_id=worker_id, country=country, statuses=statuses + result = list_tests( + dbsession, worker_id=worker_id, country_code=country_code, statuses=statuses ) assert len(filtered_tests) == result.nb_tests @@ -84,7 +90,6 @@ def test_update_test(dbsession: OrmSession, tests: list[models.Test], data_gen: speed = download_size / duration update_values = { "status": data_gen.test_status(), - "country": data_gen.test_country(), "download_size": download_size, "duration": duration, "speed": speed, @@ -92,7 +97,7 @@ def test_update_test(dbsession: OrmSession, tests: list[models.Test], data_gen: "started_on": data_gen.date_time(datetime.UTC), "latency": latency, } - updated_test = db_tests.create_or_update_test(dbsession, test_id, **update_values) # type: ignore + updated_test = create_or_update_test(dbsession, test_id, **update_values) # type: ignore for key, value in update_values.items(): if hasattr(updated_test, key): assert getattr(updated_test, key) == value @@ -115,6 +120,6 @@ def test_expire_tests( for test in tests: assert test.status == StatusEnum.PENDING - db_tests.expire_tests(dbsession, interval) + expire_tests(dbsession, interval) for test in tests: assert test.status == expected_status diff --git a/backend/tests/db/test_worker.py b/backend/tests/db/test_worker.py index 5953ebd..6fe99b7 100644 --- a/backend/tests/db/test_worker.py +++ b/backend/tests/db/test_worker.py @@ -2,30 +2,34 @@ from sqlalchemy.orm import Session as OrmSession -from mirrors_qa_backend.db import models, worker +from mirrors_qa_backend.db.models import Country +from mirrors_qa_backend.db.worker import create_worker -def test_create_worker(dbsession: OrmSession, tmp_path: Path): +def test_create_worker(dbsession: OrmSession, tmp_path: Path, private_key_bytes: bytes): worker_id = "test" countries = [ - models.Country(code="ng", name="Nigeria"), - models.Country(code="fr", name="France"), + Country(code="ng", name="Nigeria"), + Country(code="fr", name="France"), ] dbsession.add_all(countries) - private_key_filename = tmp_path / "key.pem" - new_worker = worker.create_worker( + private_key_fpath = tmp_path / "key.pem" + with private_key_fpath.open("wb") as key_file: + key_file.write(private_key_bytes) + + new_worker = create_worker( dbsession, worker_id=worker_id, - countries=[country.name for country in countries], - private_key_filename=private_key_filename, + country_codes=[country.code for country in countries], + private_key_fpath=private_key_fpath, ) assert new_worker.id == worker_id assert new_worker.pubkey_fingerprint != "" assert len(new_worker.countries) == len(countries) assert "BEGIN PUBLIC KEY" in new_worker.pubkey_pkcs8 assert "END PUBLIC KEY" in new_worker.pubkey_pkcs8 - assert private_key_filename.exists() - contents = private_key_filename.read_text() + assert private_key_fpath.exists() + contents = private_key_fpath.read_text() assert "BEGIN PRIVATE KEY" in contents assert "END PRIVATE KEY" in contents diff --git a/backend/tests/routes/test_auth_endpoints.py b/backend/tests/routes/test_auth_endpoints.py index 543b977..04e595e 100644 --- a/backend/tests/routes/test_auth_endpoints.py +++ b/backend/tests/routes/test_auth_endpoints.py @@ -7,14 +7,14 @@ from fastapi.testclient import TestClient from mirrors_qa_backend.cryptography import sign_message -from mirrors_qa_backend.db import models +from mirrors_qa_backend.db.models import Worker @pytest.mark.parametrize( ["datetime_str", "expected_status", "expected_response_contents"], [ ( - datetime.datetime(1970, 1, 1, tzinfo=datetime.UTC).isoformat(), + datetime.datetime.fromtimestamp(0, tz=datetime.UTC).isoformat(), status.HTTP_401_UNAUTHORIZED, [], ), @@ -32,7 +32,7 @@ ) def test_authenticate_worker( client: TestClient, - worker: models.Worker, + worker: Worker, private_key: RSAPrivateKey, datetime_str: str, expected_status: int, diff --git a/dev/docker-compose.yaml b/dev/docker-compose.yaml index bc162d3..eaf7097 100644 --- a/dev/docker-compose.yaml +++ b/dev/docker-compose.yaml @@ -49,7 +49,6 @@ services: container_name: mirrors-qa-scheduler environment: POSTGRES_URI: postgresql+psycopg://mirrors_qa:mirrors_qa@postgresdb:5432/mirrors_qa - JWT_SECRET: DH8kSxcflUVfNRdkEiJJCn2dOOKI3qfw DEBUG: true command: mirrors-qa-scheduler networks: