Skip to content

Commit

Permalink
split backend settings, use fully qualified import names
Browse files Browse the repository at this point in the history
  • Loading branch information
elfkuzco committed Jun 20, 2024
1 parent c80736d commit f195d64
Show file tree
Hide file tree
Showing 30 changed files with 359 additions and 245 deletions.
4 changes: 3 additions & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ ignore = [
"S603",
# Ignore complexity
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
# Ignore warnings on missing timezone info
"DTZ005", "DTZ001", "DTZ006",
]
unfixable = [
# Don't touch unused imports
Expand All @@ -216,7 +218,7 @@ testpaths = ["tests"]
pythonpath = [".", "src"]
addopts = "--strict-markers"
markers = [
"num_tests(num=10, *, status=..., country=...): create num tests in the database using status and/or country. Random data is chosen for country or status if either is not set",
"num_tests(num=10, *, status=..., country_code=...): create num tests in the database using status and/or country_code. Random data is chosen for country_code or status if either is not set",
]

[tool.coverage.paths]
Expand Down
33 changes: 7 additions & 26 deletions backend/src/mirrors_qa_backend/cryptography.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
# pyright: strict, reportGeneralTypeIssues=false
import datetime
from pathlib import Path

import jwt
import paramiko
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey

from mirrors_qa_backend.exceptions import PEMPublicKeyLoadError
from mirrors_qa_backend.settings import Settings


def verify_signed_message(public_key: bytes, signature: bytes, message: bytes) -> bool:
Expand Down Expand Up @@ -45,16 +43,11 @@ def sign_message(private_key: RSAPrivateKey, message: bytes) -> bytes:
)


def generate_private_key(key_size: int = 2048) -> RSAPrivateKey:
return rsa.generate_private_key(public_exponent=65537, key_size=key_size)


def serialize_private_key(private_key: RSAPrivateKey) -> bytes:
return private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
def load_private_key_from_path(private_key_fpath: Path) -> RSAPrivateKey:
with private_key_fpath.open("rb") as key_file:
return serialization.load_pem_private_key(
key_file.read(), password=None
) # pyright: ignore[reportReturnType]


def generate_public_key(private_key: RSAPrivateKey) -> RSAPublicKey:
Expand All @@ -73,15 +66,3 @@ def get_public_key_fingerprint(public_key: RSAPublicKey) -> str:
return paramiko.RSAKey(
key=public_key
).fingerprint # pyright: ignore[reportUnknownMemberType, UnknownVariableType]


def generate_access_token(worker_id: str) -> str:
issue_time = datetime.datetime.now(datetime.UTC)
expire_time = issue_time + datetime.timedelta(hours=Settings.TOKEN_EXPIRY)
payload = {
"iss": "mirrors-qa-backend", # issuer
"exp": expire_time.timestamp(), # expiration time
"iat": issue_time.timestamp(), # issued at
"subject": worker_id,
}
return jwt.encode(payload, key=Settings.JWT_SECRET, algorithm="HS256")
7 changes: 4 additions & 3 deletions backend/src/mirrors_qa_backend/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from sqlalchemy.orm import sessionmaker

from mirrors_qa_backend import logger
from mirrors_qa_backend.db import mirrors, models
from mirrors_qa_backend.db import models
from mirrors_qa_backend.db.mirrors import create_or_update_mirror_status
from mirrors_qa_backend.extract import get_current_mirrors
from mirrors_qa_backend.settings import Settings

Expand Down Expand Up @@ -46,14 +47,14 @@ def initialize_mirrors() -> None:
if not current_mirrors:
logger.info(f"No mirrors were found on {Settings.MIRRORS_URL!r}")
return
result = mirrors.create_or_update_status(session, current_mirrors)
result = create_or_update_mirror_status(session, current_mirrors)
logger.info(
f"Registered {result.nb_mirrors_added} mirrors "
f"from {Settings.MIRRORS_URL!r}"
)
else:
logger.info(f"Found {nb_mirrors} mirrors in database.")
result = mirrors.create_or_update_status(session, current_mirrors)
result = create_or_update_mirror_status(session, current_mirrors)
logger.info(
f"Added {result.nb_mirrors_added} mirrors. "
f"Disabled {result.nb_mirrors_disabled} mirrors."
Expand Down
14 changes: 9 additions & 5 deletions backend/src/mirrors_qa_backend/db/country.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from sqlalchemy import select
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db import models
from mirrors_qa_backend.db.models import Country


def get_countries_by_name(session: OrmSession, *countries: str) -> list[models.Country]:
def get_countries(session: OrmSession, *country_codes: str) -> list[Country]:
return list(
session.scalars(
select(models.Country).where(models.Country.name.in_(countries))
).all()
session.scalars(select(Country).where(Country.code.in_(country_codes))).all()
)


def get_country_or_none(session: OrmSession, country_code: str) -> Country | None:
return session.scalars(
select(Country).where(Country.code == country_code)
).one_or_none()
2 changes: 1 addition & 1 deletion backend/src/mirrors_qa_backend/db/mirrors.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def create_mirrors(session: OrmSession, mirrors: list[schemas.Mirror]) -> int:
return nb_created


def create_or_update_status(
def create_or_update_mirror_status(
session: OrmSession, mirrors: list[schemas.Mirror]
) -> MirrorsUpdateResult:
"""Updates the status of mirrors in the database and creates any new mirrors.
Expand Down
10 changes: 9 additions & 1 deletion backend/src/mirrors_qa_backend/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class Country(Base):
cascade="all, delete-orphan",
)

tests: Mapped[list[Test]] = relationship(back_populates="country", init=False)

__table_args__ = (UniqueConstraint("name", "code"),)


Expand Down Expand Up @@ -131,7 +133,11 @@ class Test(Base):
ip_address: Mapped[IPv4Address | None] = mapped_column(default=None)
# autonomous system based on IP
asn: Mapped[str | None] = mapped_column(default=None)
country: Mapped[str | None] = mapped_column(default=None) # country based on IP
country_code: Mapped[str | None] = mapped_column(
ForeignKey("country.code"),
init=False,
default=None,
)
location: Mapped[str | None] = mapped_column(default=None) # city based on IP
latency: Mapped[int | None] = mapped_column(default=None) # milliseconds
download_size: Mapped[int | None] = mapped_column(default=None) # bytes
Expand All @@ -142,3 +148,5 @@ class Test(Base):
)

worker: Mapped[Worker | None] = relationship(back_populates="tests", init=False)

country: Mapped[Country | None] = relationship(back_populates="tests", init=False)
22 changes: 12 additions & 10 deletions backend/src/mirrors_qa_backend/db/tests.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# ruff: noqa: DTZ005, DTZ001
import datetime
from dataclasses import dataclass
from ipaddress import IPv4Address
Expand All @@ -8,6 +7,7 @@
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db import models
from mirrors_qa_backend.db.country import get_country_or_none
from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError
from mirrors_qa_backend.enums import SortDirectionEnum, StatusEnum, TestSortColumnEnum
from mirrors_qa_backend.settings import Settings
Expand All @@ -25,7 +25,7 @@ def filter_test(
test: models.Test,
*,
worker_id: str | None = None,
country: str | None = None,
country_code: str | None = None,
statuses: list[StatusEnum] | None = None,
) -> bool:
"""Checks if a test has the same attribute as the provided attribute.
Expand All @@ -35,7 +35,7 @@ def filter_test(
"""
if worker_id is not None and test.worker_id != worker_id:
return False
if country is not None and test.country != country:
if country_code is not None and test.country_code != country_code:
return False
if statuses is not None and test.status not in statuses:
return False
Expand All @@ -52,7 +52,7 @@ def list_tests(
session: OrmSession,
*,
worker_id: str | None = None,
country: str | None = None,
country_code: str | None = None,
statuses: list[StatusEnum] | None = None,
page_num: int = 1,
page_size: int = Settings.MAX_PAGE_SIZE,
Expand Down Expand Up @@ -88,7 +88,7 @@ def list_tests(
select(func.count().over().label("total_records"), models.Test)
.where(
(models.Test.worker_id == worker_id) | (worker_id is None),
(models.Test.country == country) | (country is None),
(models.Test.country_code == country_code) | (country_code is None),
(models.Test.status.in_(statuses)),
)
.order_by(*order_by)
Expand All @@ -114,7 +114,7 @@ def create_or_update_test(
error: str | None = None,
ip_address: IPv4Address | None = None,
asn: str | None = None,
country: str | None = None,
country_code: str | None = None,
location: str | None = None,
latency: int | None = None,
download_size: int | None = None,
Expand All @@ -136,7 +136,9 @@ def create_or_update_test(
test.error = error if error else test.error
test.ip_address = ip_address if ip_address else test.ip_address
test.asn = asn if asn else test.asn
test.country = country if country else test.country
test.country = (
get_country_or_none(session, country_code) if country_code else test.country
)
test.location = location if location else test.location
test.latency = latency if latency else test.latency
test.download_size = download_size if download_size else test.download_size
Expand All @@ -158,7 +160,7 @@ def create_test(
error: str | None = None,
ip_address: IPv4Address | None = None,
asn: str | None = None,
country: str | None = None,
country_code: str | None = None,
location: str | None = None,
latency: int | None = None,
download_size: int | None = None,
Expand All @@ -174,7 +176,7 @@ def create_test(
error=error,
ip_address=ip_address,
asn=asn,
country=country,
country_code=country_code,
location=location,
latency=latency,
download_size=download_size,
Expand All @@ -189,7 +191,7 @@ def expire_tests(
) -> list[models.Test]:
"""Change the status of PENDING tests created before the interval to MISSED"""
end = datetime.datetime.now() - interval
begin = datetime.datetime(1970, 1, 1)
begin = datetime.datetime.fromtimestamp(0)
return list(
session.scalars(
update(models.Test)
Expand Down
70 changes: 31 additions & 39 deletions backend/src/mirrors_qa_backend/db/worker.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,51 @@
# ruff: noqa: DTZ005, DTZ001
import datetime
from pathlib import Path

from sqlalchemy import select
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend import cryptography
from mirrors_qa_backend.db import country, models
from mirrors_qa_backend.cryptography import (
generate_public_key,
get_public_key_fingerprint,
load_private_key_from_path,
serialize_public_key,
)
from mirrors_qa_backend.db.country import get_countries
from mirrors_qa_backend.db.exceptions import DuplicatePrimaryKeyError
from mirrors_qa_backend.db.models import Worker
from mirrors_qa_backend.exceptions import PEMPrivateKeyLoadError


def get_worker(session: OrmSession, worker_id: str) -> models.Worker | None:
return session.scalars(
select(models.Worker).where(models.Worker.id == worker_id)
).one_or_none()
def get_worker(session: OrmSession, worker_id: str) -> Worker | None:
return session.scalars(select(Worker).where(Worker.id == worker_id)).one_or_none()


def create_worker(
session: OrmSession,
worker_id: str,
countries: list[str],
private_key_filename: str | Path | None = None,
) -> models.Worker:
"""Creates a worker and writes private key contents to private_key_filename.
If no private_key_filename is provided, defaults to {worker_id}.pem.
"""
country_codes: list[str],
private_key_fpath: Path,
) -> Worker:
"""Creates a worker using RSA private key."""
if get_worker(session, worker_id) is not None:
raise DuplicatePrimaryKeyError(
f"A worker with id {worker_id!r} already exists."
)

if private_key_filename is None:
private_key_filename = f"{worker_id}.pem"

private_key = cryptography.generate_private_key()
public_key = cryptography.generate_public_key(private_key)
public_key_pkcs8 = cryptography.serialize_public_key(public_key).decode(
encoding="ascii"
)
with open(private_key_filename, "wb") as fp:
fp.write(cryptography.serialize_private_key(private_key))

worker = models.Worker(
try:
private_key = load_private_key_from_path(private_key_fpath)
except Exception as exc:
raise PEMPrivateKeyLoadError("unable to load private key from file") from exc

public_key = generate_public_key(private_key)
public_key_pkcs8 = serialize_public_key(public_key).decode(encoding="ascii")
worker = Worker(
id=worker_id,
pubkey_pkcs8=public_key_pkcs8,
pubkey_fingerprint=cryptography.get_public_key_fingerprint(public_key),
pubkey_fingerprint=get_public_key_fingerprint(public_key),
)
session.add(worker)

for db_country in country.get_countries_by_name(session, *countries):
for db_country in get_countries(session, *country_codes):
db_country.worker_id = worker_id
session.add(db_country)

Expand All @@ -58,36 +54,32 @@ def create_worker(

def get_workers_last_seen_in_range(
session: OrmSession, begin: datetime.datetime, end: datetime.datetime
) -> list[models.Worker]:
) -> list[Worker]:
"""Get workers whose last_seen_on falls between begin and end dates"""
return list(
session.scalars(
select(models.Worker).where(
models.Worker.last_seen_on.between(begin, end),
select(Worker).where(
Worker.last_seen_on.between(begin, end),
)
).all()
)


def get_idle_workers(
session: OrmSession, interval: datetime.timedelta
) -> list[models.Worker]:
def get_idle_workers(session: OrmSession, interval: datetime.timedelta) -> list[Worker]:
end = datetime.datetime.now() - interval
begin = datetime.datetime(1970, 1, 1)
return get_workers_last_seen_in_range(session, begin, end)


def get_active_workers(
session: OrmSession, interval: datetime.timedelta
) -> list[models.Worker]:
) -> list[Worker]:
end = datetime.datetime.now()
begin = end - interval
return get_workers_last_seen_in_range(session, begin, end)


def update_worker_last_seen(
session: OrmSession, worker: models.Worker
) -> models.Worker:
def update_worker_last_seen(session: OrmSession, worker: Worker) -> Worker:
worker.last_seen_on = datetime.datetime.now()
session.add(worker)
return worker
Loading

0 comments on commit f195d64

Please sign in to comment.