Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

set up scheduler to create tasks for idle workers #20

Merged
merged 3 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"pycountry==24.6.1",
"cryptography==42.0.8",
"PyJWT==2.8.0",
"paramiko==3.4.0",
]
license = {text = "GPL-3.0-or-later"}
classifiers = [
Expand All @@ -37,6 +38,7 @@ Homepage = "https://github.com/kiwix/mirrors-qa"

[project.scripts]
update-mirrors = "mirrors_qa_backend.entrypoint:main"
mirrors-qa-scheduler = "mirrors_qa_backend.scheduler:main"

[project.optional-dependencies]
scripts = [
Expand All @@ -53,7 +55,6 @@ test = [
"pytest==8.0.0",
"coverage==7.4.1",
"Faker==25.8.0",
"paramiko==3.4.0",
"httpx==0.27.0",
]
dev = [
Expand Down Expand Up @@ -215,7 +216,7 @@ testpaths = ["tests"]
pythonpath = [".", "src"]
addopts = "--strict-markers"
markers = [
"num_tests: number of tests to create in the database (default: 10)",
"num_tests(num=10, *, status=..., country=...): create num tests in the database using status and/or country. Random data is chosen for country or status if either is not set",
rgaudin marked this conversation as resolved.
Show resolved Hide resolved
]

[tool.coverage.paths]
Expand Down
35 changes: 33 additions & 2 deletions backend/src/mirrors_qa_backend/cryptography.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import datetime

import jwt
import paramiko
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey

from mirrors_qa_backend.exceptions import PEMPublicKeyLoadError
from mirrors_qa_backend.settings import Settings
Expand Down Expand Up @@ -44,6 +45,36 @@ def sign_message(private_key: RSAPrivateKey, message: bytes) -> bytes:
)


def generate_private_key(key_size: int = 2048) -> RSAPrivateKey:
    """Generate a fresh RSA private key.

    The key is `key_size` bits long (2048 by default) and uses the public
    exponent 65537.
    """
    public_exponent = 65537
    return rsa.generate_private_key(
        public_exponent=public_exponent,
        key_size=key_size,
    )


def serialize_private_key(private_key: RSAPrivateKey) -> bytes:
    """Return `private_key` as unencrypted, PEM-encoded PKCS8 bytes."""
    pem_encoding = serialization.Encoding.PEM
    pkcs8_format = serialization.PrivateFormat.PKCS8
    # The key material is written without a passphrase.
    no_encryption = serialization.NoEncryption()
    return private_key.private_bytes(
        encoding=pem_encoding,
        format=pkcs8_format,
        encryption_algorithm=no_encryption,
    )


def generate_public_key(private_key: RSAPrivateKey) -> RSAPublicKey:
    """Return the public key that belongs to `private_key`."""
    public_key = private_key.public_key()
    return public_key


def serialize_public_key(public_key: RSAPublicKey) -> bytes:
    """Return `public_key` as PEM-encoded SubjectPublicKeyInfo bytes."""
    pem_encoding = serialization.Encoding.PEM
    spki_format = serialization.PublicFormat.SubjectPublicKeyInfo
    return public_key.public_bytes(encoding=pem_encoding, format=spki_format)


def get_public_key_fingerprint(public_key: RSAPublicKey) -> str:
    """Return the SHA256 fingerprint of `public_key`.

    The key is wrapped in a paramiko `RSAKey` and the value of its
    `fingerprint` attribute is returned.
    """
    rsa_key = paramiko.RSAKey(key=public_key)
    return rsa_key.fingerprint  # pyright: ignore[reportUnknownMemberType, UnknownVariableType]


def generate_access_token(worker_id: str) -> str:
issue_time = datetime.datetime.now(datetime.UTC)
expire_time = issue_time + datetime.timedelta(hours=Settings.TOKEN_EXPIRY)
Expand Down
12 changes: 12 additions & 0 deletions backend/src/mirrors_qa_backend/db/country.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from sqlalchemy import select
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db import models


def get_countries_by_name(session: OrmSession, *countries: str) -> list[models.Country]:
    """Fetch the Country rows whose name is one of `countries`.

    Names that match no row are simply absent from the result.
    """
    query = select(models.Country).where(models.Country.name.in_(countries))
    matching_rows = session.scalars(query).all()
    return list(matching_rows)
4 changes: 4 additions & 0 deletions backend/src/mirrors_qa_backend/db/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ def __init__(self, message: str, *args: object) -> None:

class EmptyMirrorsError(Exception):
    """Signals that the mirrors in the database were updated with an empty list."""


class DuplicatePrimaryKeyError(Exception):
    """Signals that a database record with the same primary key already exists."""
57 changes: 56 additions & 1 deletion backend/src/mirrors_qa_backend/db/tests.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# ruff: noqa: DTZ005, DTZ001
elfkuzco marked this conversation as resolved.
Show resolved Hide resolved
import datetime
from dataclasses import dataclass
from ipaddress import IPv4Address
from uuid import UUID

from sqlalchemy import UnaryExpression, asc, desc, func, select
from sqlalchemy import UnaryExpression, asc, desc, func, select, update
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db import models
Expand Down Expand Up @@ -144,5 +145,59 @@ def create_or_update_test(
test.started_on = started_on if started_on else test.started_on

session.add(test)
session.flush()

return test


def create_test(
    session: OrmSession,
    *,
    worker_id: str | None = None,
    status: StatusEnum = StatusEnum.PENDING,
    error: str | None = None,
    ip_address: IPv4Address | None = None,
    asn: str | None = None,
    country: str | None = None,
    location: str | None = None,
    latency: int | None = None,
    download_size: int | None = None,
    duration: int | None = None,
    speed: float | None = None,
    started_on: datetime.datetime | None = None,
) -> models.Test:
    """Create a new test record.

    Thin wrapper around `create_or_update_test` that always passes
    `test_id=None` and forwards every other argument unchanged. The status
    defaults to PENDING.
    """
    # test_id=None is the only value this wrapper fixes; everything else is
    # handed through as received.
    new_test = create_or_update_test(
        session,
        test_id=None,
        worker_id=worker_id,
        status=status,
        error=error,
        ip_address=ip_address,
        asn=asn,
        country=country,
        location=location,
        latency=latency,
        download_size=download_size,
        duration=duration,
        speed=speed,
        started_on=started_on,
    )
    return new_test


def expire_tests(
    session: OrmSession, interval: datetime.timedelta
) -> list[models.Test]:
    """Mark PENDING tests requested more than `interval` ago as MISSED.

    Returns the tests whose status was changed by the update.
    """
    # Window runs from the Unix epoch up to `interval` before the current
    # (naive) local time.
    window_start = datetime.datetime(1970, 1, 1)
    window_end = datetime.datetime.now() - interval

    statement = (
        update(models.Test)
        .where(
            models.Test.requested_on.between(window_start, window_end),
            models.Test.status == StatusEnum.PENDING,
        )
        .values(status=StatusEnum.MISSED)
        .returning(models.Test)
    )
    return list(session.scalars(statement).all())
85 changes: 84 additions & 1 deletion backend/src/mirrors_qa_backend/db/worker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,93 @@
# ruff: noqa: DTZ005, DTZ001
import datetime
from pathlib import Path

from sqlalchemy import select
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db import models
from mirrors_qa_backend import cryptography
from mirrors_qa_backend.db import country, models
from mirrors_qa_backend.db.exceptions import DuplicatePrimaryKeyError


def get_worker(session: OrmSession, worker_id: str) -> models.Worker | None:
    """Return the worker whose id equals `worker_id`, or None if there is none."""
    query = select(models.Worker).where(models.Worker.id == worker_id)
    return session.scalars(query).one_or_none()


def create_worker(
    session: OrmSession,
    worker_id: str,
    countries: list[str],
    private_key_filename: str | Path | None = None,
) -> models.Worker:
    """Create a worker and write its new private key to private_key_filename.

    A new RSA key pair is generated for the worker. The serialized private key
    is written to `private_key_filename` (defaults to `{worker_id}.pem`), while
    the serialized public key and its fingerprint are stored on the worker
    record. Every country in `countries` that is found in the database is
    assigned to the new worker.

    Args:
        session: database session used to look up and add records.
        worker_id: identifier of the worker to create.
        countries: names of the countries to assign to the worker.
        private_key_filename: where to write the private key.

    Returns:
        The newly created worker, added to the session.

    Raises:
        DuplicatePrimaryKeyError: a worker with `worker_id` already exists.
    """
    if get_worker(session, worker_id) is not None:
        raise DuplicatePrimaryKeyError(
            f"A worker with id {worker_id!r} already exists."
        )

    if private_key_filename is None:
        private_key_filename = f"{worker_id}.pem"

    private_key = cryptography.generate_private_key()
    public_key = cryptography.generate_public_key(private_key)
    public_key_pkcs8 = cryptography.serialize_public_key(public_key).decode(
        encoding="ascii"
    )

    # The file holds secret key material: when it does not exist yet, create it
    # readable/writable by the owner only instead of with the default
    # umask-derived permissions. An already existing file keeps its mode.
    private_key_path = Path(private_key_filename)
    private_key_path.touch(mode=0o600)
    private_key_path.write_bytes(cryptography.serialize_private_key(private_key))

    worker = models.Worker(
        id=worker_id,
        pubkey_pkcs8=public_key_pkcs8,
        pubkey_fingerprint=cryptography.get_public_key_fingerprint(public_key),
    )
    session.add(worker)

    # Country names that match no database row are skipped.
    for db_country in country.get_countries_by_name(session, *countries):
        db_country.worker_id = worker_id
        session.add(db_country)

    return worker


def get_workers_last_seen_in_range(
    session: OrmSession, begin: datetime.datetime, end: datetime.datetime
) -> list[models.Worker]:
    """Return the workers whose last_seen_on lies between `begin` and `end`."""
    # NOTE(review): a worker whose last_seen_on is NULL would not match a SQL
    # BETWEEN — confirm whether the column can be NULL for new workers.
    in_range = models.Worker.last_seen_on.between(begin, end)
    query = select(models.Worker).where(in_range)
    return list(session.scalars(query).all())


def get_idle_workers(
    session: OrmSession, interval: datetime.timedelta
) -> list[models.Worker]:
    """Return the workers last seen at least `interval` ago.

    The lookup window runs from 1970-01-01 up to `interval` before the
    current (naive) local time.
    """
    earliest = datetime.datetime(1970, 1, 1)
    latest = datetime.datetime.now() - interval
    return get_workers_last_seen_in_range(session, earliest, latest)


def get_active_workers(
    session: OrmSession, interval: datetime.timedelta
) -> list[models.Worker]:
    """Return the workers last seen within the past `interval`.

    The lookup window ends at the current (naive) local time and starts
    `interval` before it.
    """
    now = datetime.datetime.now()
    window_start = now - interval
    return get_workers_last_seen_in_range(session, window_start, now)


def update_worker_last_seen(
    session: OrmSession, worker: models.Worker
) -> models.Worker:
    """Stamp `worker.last_seen_on` with the current time and add it to the session.

    Returns the same worker object that was passed in.
    """
    seen_at = datetime.datetime.now()  # naive local time
    worker.last_seen_on = seen_at
    session.add(worker)
    return worker
4 changes: 3 additions & 1 deletion backend/src/mirrors_qa_backend/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def is_country_row(tag: Tag) -> bool:
resp = requests.get(Settings.MIRRORS_URL, timeout=Settings.REQUESTS_TIMEOUT)
resp.raise_for_status()
except requests.RequestException as exc:
raise MirrorsRequestError from exc
raise MirrorsRequestError(
"network error while fetching mirrors from url"
) from exc

soup = BeautifulSoup(resp.text, features="html.parser")
body = soup.find("tbody")
Expand Down
8 changes: 4 additions & 4 deletions backend/src/mirrors_qa_backend/routes/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import status as status_codes

from mirrors_qa_backend import schemas, serializer
from mirrors_qa_backend.db import tests
from mirrors_qa_backend.db import tests, worker
from mirrors_qa_backend.enums import SortDirectionEnum, StatusEnum, TestSortColumnEnum
from mirrors_qa_backend.routes.dependencies import (
CurrentWorker,
Expand Down Expand Up @@ -78,7 +78,7 @@ def get_test(test: RetrievedTest) -> schemas.Test:
)
def update_test(
session: DbSession,
worker: CurrentWorker,
current_worker: CurrentWorker,
test: RetrievedTest,
update: schemas.UpdateTestModel,
) -> schemas.Test:
Expand All @@ -87,7 +87,7 @@ def update_test(
updated_test = tests.create_or_update_test(
session,
test_id=test.id,
worker_id=worker.id,
worker_id=current_worker.id,
status=body.status,
error=body.error,
ip_address=body.ip_address,
Expand All @@ -99,5 +99,5 @@ def update_test(
duration=body.duration,
speed=body.speed,
)

worker.update_worker_last_seen(session, current_worker)
return serializer.serialize_test(updated_test)
75 changes: 75 additions & 0 deletions backend/src/mirrors_qa_backend/scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import datetime
import time

from mirrors_qa_backend import logger
from mirrors_qa_backend.db import Session, tests, worker
elfkuzco marked this conversation as resolved.
Show resolved Hide resolved
from mirrors_qa_backend.enums import StatusEnum
from mirrors_qa_backend.settings import Settings


def main():
    """Run the scheduler loop forever.

    Each iteration expires tests that were never reported, creates a PENDING
    test for every country of every idle worker (unless that worker still has
    a pending test for that country), then sleeps for
    Settings.SCHEDULER_SLEEP_INTERVAL hours.
    """
    while True:
        with Session.begin() as session:
            # expire tests whose results have not been reported
            expiry_interval = datetime.timedelta(hours=Settings.EXPIRE_TEST_INTERVAL)
            for expired_test in tests.expire_tests(session, interval=expiry_interval):
                logger.info(
                    f"Expired test {expired_test.id}, "
                    f"country: {expired_test.country}, "
                    f"worker: {expired_test.worker_id}"
                )

            idle_interval = datetime.timedelta(hours=Settings.IDLE_WORKER_INTERVAL)
            idle_workers = worker.get_idle_workers(session, interval=idle_interval)
            if not idle_workers:
                logger.info("No idle workers found.")

            # Create tests for the countries the worker is responsible for..
            for idle_worker in idle_workers:
                if not idle_worker.countries:
                    logger.info(
                        f"No countries registered for idle worker {idle_worker.id}"
                    )
                    continue

                for country in idle_worker.countries:
                    # While we have expired "unreported" tests, it is possible that
                    # a test for a country might still be PENDING as the interval
                    # for expiration and that of the scheduler might overlap.
                    # In such scenarios, we skip creating a test for that country.
                    pending_tests = tests.list_tests(
                        session,
                        worker_id=idle_worker.id,
                        statuses=[StatusEnum.PENDING],
                        country=country.name,
                    )

                    if not pending_tests.nb_tests:
                        new_test = tests.create_test(
                            session=session,
                            worker_id=idle_worker.id,
                            country=country.name,
                            status=StatusEnum.PENDING,
                        )
                        logger.info(
                            f"Created new test {new_test.id} for worker "
                            f"{idle_worker.id} in country {country.name}"
                        )
                    else:
                        logger.info(
                            "Skipping creation of new test entries for "
                            f"{idle_worker.id} as {pending_tests.nb_tests} "
                            "tests are still pending."
                        )

        # NOTE(review): sleeping after the session block so the transaction is
        # closed before the long wait — confirm this matches the intended nesting.
        sleep_interval = datetime.timedelta(
            hours=Settings.SCHEDULER_SLEEP_INTERVAL
        ).total_seconds()

        logger.info(f"Sleeping for {sleep_interval} seconds.")
        time.sleep(sleep_interval)
Loading
Loading