From 7b0c1c4615bc2b88de72d03ab24dd811d940a6fb Mon Sep 17 00:00:00 2001 From: Uchechukwu Orji Date: Mon, 10 Jun 2024 15:46:33 +0100 Subject: [PATCH] use mirrors data to update db instead of countries --- backend/pyproject.toml | 3 +- backend/src/mirrors_qa_backend/__main__.py | 6 - backend/src/mirrors_qa_backend/cli.py | 23 -- backend/src/mirrors_qa_backend/db/__init__.py | 27 ++- backend/src/mirrors_qa_backend/db/mirrors.py | 214 ++++++++---------- backend/src/mirrors_qa_backend/db/models.py | 6 +- backend/src/mirrors_qa_backend/entrypoint.py | 31 +++ backend/src/mirrors_qa_backend/exceptions.py | 2 + backend/src/mirrors_qa_backend/extract.py | 57 +++++ .../0c273daa1ab0_set_up_database_models.py | 4 +- backend/src/mirrors_qa_backend/schemas.py | 12 +- backend/src/mirrors_qa_backend/settings.py | 11 +- backend/tests/db/test_mirrors.py | 178 +++++++-------- 13 files changed, 307 insertions(+), 267 deletions(-) delete mode 100644 backend/src/mirrors_qa_backend/__main__.py delete mode 100644 backend/src/mirrors_qa_backend/cli.py create mode 100644 backend/src/mirrors_qa_backend/entrypoint.py create mode 100644 backend/src/mirrors_qa_backend/exceptions.py create mode 100644 backend/src/mirrors_qa_backend/extract.py diff --git a/backend/pyproject.toml b/backend/pyproject.toml index dd9482c..075fb8c 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "psycopg[binary,pool] == 3.1.19", "beautifulsoup4 == 4.12.3", "requests == 2.32.3", + "pycountry == 24.6.1", ] license = {text = "GPL-3.0-or-later"} classifiers = [ @@ -33,7 +34,7 @@ dynamic = ["version"] Homepage = "https://github.com/kiwix/mirrors-qa" [project.scripts] -mirrors-qa-backend = "mirrors_qa_backend.cli:main" +mirrors-qa-backend = "mirrors_qa_backend.entrypoint:main" [project.optional-dependencies] scripts = [ diff --git a/backend/src/mirrors_qa_backend/__main__.py b/backend/src/mirrors_qa_backend/__main__.py deleted file mode 100644 index 5545828..0000000 --- a/backend/src/mirrors_qa_backend/__main__.py +++ /dev/null @@ -1,6 +0,0 @@ -import sys - -if __name__ == "__main__": - from mirrors_qa_backend.cli import main - - sys.exit(main()) diff --git a/backend/src/mirrors_qa_backend/cli.py b/backend/src/mirrors_qa_backend/cli.py deleted file mode 100644 index b8bda05..0000000 --- a/backend/src/mirrors_qa_backend/cli.py +++ /dev/null @@ -1,23 +0,0 @@ -import argparse - -from mirrors_qa_backend import Settings, db -from mirrors_qa_backend.db import mirrors - - -def main(): - parser = argparse.ArgumentParser(prog="mirrors-qa-backend") - parser.add_argument( - "--update-mirrors", - action="store_true", - help=f"Update the list of mirrors from {Settings.mirrors_url}", - ) - - args = parser.parse_args() - - if args.update_mirrors: - with db.Session.begin() as session: - mirrors.update_mirrors(session, mirrors.get_current_mirror_countries()) - - -if __name__ == "__main__": - main() diff --git a/backend/src/mirrors_qa_backend/db/__init__.py b/backend/src/mirrors_qa_backend/db/__init__.py index 694a92f..20fdf15 100644 --- a/backend/src/mirrors_qa_backend/db/__init__.py +++ b/backend/src/mirrors_qa_backend/db/__init__.py @@ -8,6 +8,7 @@ from mirrors_qa_backend import logger from mirrors_qa_backend.db import mirrors, models +from mirrors_qa_backend.extract import get_current_mirrors from mirrors_qa_backend.settings import Settings Session = sessionmaker( @@ -38,16 +39,22 @@ def count_from_stmt(session: OrmSession, stmt: SelectBase) -> int: def initialize_mirrors() -> None: with Session.begin() as session: - count = count_from_stmt(session, select(models.Mirror)) - countries = mirrors.get_current_mirror_countries() - if count == 0: + nb_mirrors = count_from_stmt(session, select(models.Mirror)) + current_mirrors = get_current_mirrors() + if nb_mirrors == 0: logger.info("No mirrors exist in database.") - # update mirrors from https://download.kiwix.org/mirrors.html - if not countries: - logger.info(f"No mirrors were found on {Settings.mirrors_url}") + if not current_mirrors: + logger.info(f"No mirrors were found on {Settings.mirrors_url!r}") return - mirrors.create_mirrors(session, countries) + results = mirrors.update_mirrors(session, current_mirrors) + logger.info( + f"Registered {results.nb_mirrors_added} mirrors " + f"from {Settings.mirrors_url!r}" + ) else: - logger.info(f"Found {count} mirrors in database.") - # Update the list of enabled mirrors - mirrors.update_mirrors(session, countries) + logger.info(f"Found {nb_mirrors} mirrors in database.") + result = mirrors.update_mirrors(session, current_mirrors) + logger.info( + f"Added {result.nb_mirrors_added} mirrors. " + f"Disabled {result.nb_mirrors_disabled} mirrors." + ) diff --git a/backend/src/mirrors_qa_backend/db/mirrors.py b/backend/src/mirrors_qa_backend/db/mirrors.py index db0bf13..33a246c 100644 --- a/backend/src/mirrors_qa_backend/db/mirrors.py +++ b/backend/src/mirrors_qa_backend/db/mirrors.py @@ -1,140 +1,110 @@ -from typing import Any -from urllib.parse import urlsplit +from dataclasses import dataclass -import requests -from bs4 import BeautifulSoup, NavigableString -from bs4.element import Tag from sqlalchemy import select from sqlalchemy.orm import Session as OrmSession from sqlalchemy.orm import selectinload from mirrors_qa_backend import logger, schemas from mirrors_qa_backend.db import models -from mirrors_qa_backend.settings import Settings +from mirrors_qa_backend.exceptions import EmptyMirrorsError -def create_mirrors(session: OrmSession, countries: list[schemas.Country]) -> None: - for country in countries: - c = models.Country(code=country.code, name=country.name) - c.mirrors = [models.Mirror(**m.model_dump()) for m in country.mirrors] - session.add(c) +@dataclass +class UpdateMirrorsResult: + """Represents the results of an update to the list of mirrors in the database""" + nb_mirrors_added: int = 0 + nb_mirrors_disabled: int = 0 -def update_mirrors(session: OrmSession, countries: list[schemas.Country]) -> None: + +def create_mirrors(session: OrmSession, mirrors: list[schemas.Mirror]) -> int: + """ + Given a list of schemas.Mirror, saves all the mirrors + to the database. + Returns the total number of mirrors created. + + Assumes that each mirror does not exist on the database. """ - Updates the status of mirrors in the database. Any mirrors in the database - that do not exist in the current mirrors obtained from `countries` are - marked as disabled. New mirrors are saved accordingly. + total = 0 + for mirror in mirrors: + db_mirror = models.Mirror( + id=mirror.id, + base_url=mirror.base_url, + enabled=mirror.enabled, + region=mirror.region, + asn=mirror.asn, + score=mirror.score, + latitude=mirror.latitude, + longitude=mirror.longitude, + country_only=mirror.country_only, + region_only=mirror.country_only, + as_only=mirror.as_only, + other_countries=mirror.other_countries, + ) + # Ensure the country exists for the mirror + country = session.scalars( + select(models.Country).where(models.Country.code == mirror.country.code) + ).one_or_none() + + if country is None: + country = models.Country(code=mirror.country.code, name=mirror.country.name) + session.add(country) + + db_mirror.country = country + session.add(db_mirror) + logger.debug( + f"Registered new mirror: {db_mirror.id!r} for country: {country.name!r}" + ) + total += 1 + return total + + +def update_mirrors( + session: OrmSession, mirrors: list[schemas.Mirror] +) -> UpdateMirrorsResult: """ + Given a list of current_mirrors, compares the list with the existing mirrors + in the database and disables mirrors in the database that are not in the list. + New mirrors from the list that are not in the database are created in the + database. + + Returns UpdateMirrorsResult showing the total mirrors added and updated. + """ + result = UpdateMirrorsResult() # If there are no countries, disable all mirrors - if not countries: - for mirror in session.scalars(select(models.Mirror)).all(): - mirror.enabled = False - session.add(mirror) - return - - query = select(models.Country).options(selectinload(models.Country.mirrors)) - # Map the country codes to each country from the database. To be used - # to compare against the list of current countries - db_countries: dict[str, models.Country] = { - country.code: country for country in session.scalars(query).all() + if not mirrors: + raise EmptyMirrorsError("mirrors list must not be empty") + + # Map the id (hostname) of each mirror from the mirrors list for comparison + # against the id of mirrors from the database. To be used in determining + # if this mirror is a new mirror, in which case it should be added + current_mirrors: dict[str, schemas.Mirror] = { + mirror.id: mirror for mirror in mirrors } - # Map the country codes to each country from the current list of coutnries. - # To be used in determining if a country is to be newly registered - current_countries: dict[str, schemas.Country] = { - country.code: country for country in countries + + # Map the id (hostname) of each mirror from the database for comparison + # against the id of mirrors in current_mirrors. To be used in determining + # if this mirror should be disabled + query = select(models.Mirror).options(selectinload(models.Mirror.country)) + db_mirrors: dict[str, models.Mirror] = { + mirror.id: mirror for mirror in session.scalars(query).all() } - for country_code, country in current_countries.items(): - if country_code not in db_countries: - # Register all of the country's mirrors as the country is - # a new country - logger.debug("Registering new mirrors for {country_code!r}") - c = models.Country(code=country.code, name=country.name) - c.mirrors = [models.Mirror(**m.model_dump()) for m in country.mirrors] - session.add(c) - - for code, db_country in db_countries.items(): - if code in current_countries: - # Even though the db_country is "current", ensure it's mirrors - # are in sync with the current mirrors - current_mirrors: dict[str, schemas.Mirror] = { - m.id: m for m in current_countries[code].mirrors - } - db_mirrors: dict[str, models.Mirror] = {m.id: m for m in db_country.mirrors} - - for db_mirror in db_mirrors.values(): - if db_mirror.id not in current_mirrors: - logger.debug(f"Disabling mirror {db_mirror.id!r}") - db_mirror.enabled = False - session.add(db_mirror) - - for mirror_id, mirror in current_mirrors.items(): - if mirror_id not in db_mirrors: - logger.debug( - f"Registering new mirror {mirror.id!r} for " - "country: {db_country.name!r}" - ) - db_country.mirrors.append(models.Mirror(**mirror.model_dump())) - session.add(db_country) - else: - # disable all of the country's mirrors as they have been removed - for db_mirror in db_country.mirrors: - logger.debug(f"Disabling mirror {db_mirror.id!r}") - db_mirror.enabled = False - session.add(db_mirror) - - -def get_current_mirror_countries() -> list[schemas.Country]: - def find_country_rows(tag: Tag) -> bool: - """ - Filters out table rows that do not contain mirror - data from the table body. - """ - return tag.name == "tr" and tag.findChild("td", class_="newregion") is None - - r = requests.get(Settings.mirrors_url, timeout=Settings.requests_timeout) - r.raise_for_status() - - soup = BeautifulSoup(r.text, features="html.parser") - body = soup.find("tbody") - - if body is None or isinstance(body, NavigableString): - raise ValueError - # Given a country might have more than one mirror, set up a dictionary - # of country_code to the country's data. If it is the first time we - # are seeing the country, we save it along with its mirror, else, - # we simply update its mirrors list. - countries: dict[str, schemas.Country] = {} - rows = body.find_all(find_country_rows) - for row in rows: - country_name = row.find("img").next_sibling.text.strip() - if country_name in Settings.mirrors_exclusion_list: - continue - country_code = row.find("img")["alt"] - base_url = row.find("a", string="HTTP")["href"] - hostname: Any = urlsplit( - base_url - ).netloc # pyright: ignore [reportUnknownMemberType] - - if country_code not in countries: - countries[country_code] = schemas.Country( - code=country_code, - name=country_name, - mirrors=[ - schemas.Mirror( - id=hostname, - base_url=base_url, - enabled=True, - ) - ], - ) - else: - countries[country_code].mirrors.append( - schemas.Mirror( - id=hostname, - base_url=base_url, - enabled=True, - ) + # Create any mirror that doesn't exist on the database + for mirror_id, mirror in current_mirrors.items(): + if mirror_id not in db_mirrors: + # Create the mirror as it doesn't exists on the database. + result.nb_mirrors_added += create_mirrors(session, [mirror]) + + # Disable any mirror in the database that doesn't exist on the current + # list of mirrors + for db_mirror_id, db_mirror in db_mirrors.items(): + if db_mirror_id not in current_mirrors: + logger.debug( + f"Disabling mirror: {db_mirror.id!r} for " + f"country: {db_mirror.country.name!r}" ) - return list(countries.values()) + db_mirror.enabled = False + session.add(db_mirror) + result.nb_mirrors_disabled += 1 + return result diff --git a/backend/src/mirrors_qa_backend/db/models.py b/backend/src/mirrors_qa_backend/db/models.py index a36098a..db41a66 100644 --- a/backend/src/mirrors_qa_backend/db/models.py +++ b/backend/src/mirrors_qa_backend/db/models.py @@ -97,9 +97,9 @@ class Worker(Base): # RSA public key in PKCS8 format for generating access tokens required # to make requests to the web server pubkey_pkcs8: Mapped[str] - pubkey_fingerprint: Mapped[str | None] = mapped_column(default=None) + pubkey_fingerprint: Mapped[str] - last_seen_on: Mapped[datetime | None] = mapped_column(default=None) + last_seen_on: Mapped[datetime] = mapped_column(default_factory=datetime.now) countries: Mapped[list[Country]] = relationship(back_populates="worker", init=False) @@ -108,7 +108,7 @@ class Test(Base): id: Mapped[UUID] = mapped_column( init=False, primary_key=True, server_default=text("uuid_generate_v4()") ) - requested_on: Mapped[datetime] + requested_on: Mapped[datetime] = mapped_column(default_factory=datetime.now) started_on: Mapped[datetime | None] = mapped_column(default=None) status: Mapped[StatusEnum] = mapped_column( Enum( diff --git a/backend/src/mirrors_qa_backend/entrypoint.py b/backend/src/mirrors_qa_backend/entrypoint.py new file mode 100644 index 0000000..673cae6 --- /dev/null +++ b/backend/src/mirrors_qa_backend/entrypoint.py @@ -0,0 +1,31 @@ +import argparse +import logging + +from mirrors_qa_backend import Settings, db, logger +from mirrors_qa_backend.db import mirrors +from mirrors_qa_backend.extract import get_current_mirrors + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--update-mirrors", + action="store_true", + help=f"Update the list of mirrors from {Settings.mirrors_url}", + ) + parser.add_argument( + "--verbose", "-v", help="Show verbose output", action="store_true" + ) + + args = parser.parse_args() + + if args.verbose: + logger.setLevel(logging.DEBUG) + + if args.update_mirrors: + with db.Session.begin() as session: + mirrors.update_mirrors(session, get_current_mirrors()) + + +if __name__ == "__main__": + main() diff --git a/backend/src/mirrors_qa_backend/exceptions.py b/backend/src/mirrors_qa_backend/exceptions.py new file mode 100644 index 0000000..38be96b --- /dev/null +++ b/backend/src/mirrors_qa_backend/exceptions.py @@ -0,0 +1,2 @@ +class EmptyMirrorsError(Exception): + pass diff --git a/backend/src/mirrors_qa_backend/extract.py b/backend/src/mirrors_qa_backend/extract.py new file mode 100644 index 0000000..093ddc9 --- /dev/null +++ b/backend/src/mirrors_qa_backend/extract.py @@ -0,0 +1,57 @@ +from typing import Any +from urllib.parse import urlsplit + +import pycountry +import requests +from bs4 import BeautifulSoup, NavigableString +from bs4.element import Tag + +from mirrors_qa_backend import logger, schemas +from mirrors_qa_backend.settings import Settings + + +def get_current_mirrors() -> list[schemas.Mirror]: + def find_country_rows(tag: Tag) -> bool: + """ + Filters out table rows that do not contain mirror + data from the table body. + """ + return tag.name == "tr" and tag.findChild("td", class_="newregion") is None + + resp = requests.get(Settings.mirrors_url, timeout=Settings.requests_timeout) + resp.raise_for_status() + + soup = BeautifulSoup(resp.text, features="html.parser") + body = soup.find("tbody") + + if body is None or isinstance(body, NavigableString | int): + raise ValueError + + mirrors: list[schemas.Mirror] = [] + + for row in body.find_all(find_country_rows): + base_url = row.find("a", string="HTTP")["href"] + hostname: Any = urlsplit( + base_url + ).netloc # pyright: ignore [reportUnknownMemberType] + if hostname in Settings.mirrors_exclusion_list: + continue + country_name = row.find("img").next_sibling.text.strip() + try: + country: Any = pycountry.countries.search_fuzzy(country_name)[0] + except LookupError: + logger.warning(f"Could not get information for country: {country_name!r}") + continue + else: + mirrors.append( + schemas.Mirror( + id=hostname, + base_url=base_url, + enabled=True, + country=schemas.Country( + code=country.alpha_2.lower(), + name=country.name, + ), + ) + ) + return mirrors diff --git a/backend/src/mirrors_qa_backend/migrations/versions/0c273daa1ab0_set_up_database_models.py b/backend/src/mirrors_qa_backend/migrations/versions/0c273daa1ab0_set_up_database_models.py index 44576c9..164b552 100644 --- a/backend/src/mirrors_qa_backend/migrations/versions/0c273daa1ab0_set_up_database_models.py +++ b/backend/src/mirrors_qa_backend/migrations/versions/0c273daa1ab0_set_up_database_models.py @@ -58,8 +58,8 @@ def upgrade() -> None: "worker", sa.Column("id", sa.String(), nullable=False), sa.Column("pubkey_pkcs8", sa.String(), nullable=False), - sa.Column("pubkey_fingerprint", sa.String(), nullable=True), - sa.Column("last_seen_on", sa.DateTime(), nullable=True), + sa.Column("pubkey_fingerprint", sa.String(), nullable=False), + sa.Column("last_seen_on", sa.DateTime(), nullable=False), sa.PrimaryKeyConstraint("id", name=op.f("pk_worker")), ) op.create_table( diff --git a/backend/src/mirrors_qa_backend/schemas.py b/backend/src/mirrors_qa_backend/schemas.py index 9cf9025..ce8bac3 100644 --- a/backend/src/mirrors_qa_backend/schemas.py +++ b/backend/src/mirrors_qa_backend/schemas.py @@ -6,6 +6,11 @@ class BaseModel(pydantic.BaseModel): model_config = ConfigDict(use_enum_values=True) +class Country(BaseModel): + code: str # two-letter country codes as defined in ISO 3166-1 + name: str # full name of country (in English) + + class Mirror(BaseModel): id: str # hostname of a mirror URL base_url: str @@ -19,9 +24,4 @@ class Mirror(BaseModel): region_only: bool | None = None as_only: bool | None = None other_countries: list[str] | None = None - - -class Country(BaseModel): - code: str # two-letter country codes as defined in ISO 3166-1 - name: str # full name of country (in English) - mirrors: list[Mirror] + country: Country diff --git a/backend/src/mirrors_qa_backend/settings.py b/backend/src/mirrors_qa_backend/settings.py index cd6d5a9..12243d4 100644 --- a/backend/src/mirrors_qa_backend/settings.py +++ b/backend/src/mirrors_qa_backend/settings.py @@ -15,9 +15,14 @@ class Settings: """Shared backend configuration""" database_url: str = getenv("POSTGRES_URI", mandatory=True) - mirrors_url = "https://download.kiwix.org/mirrors.html" - # comma-seperated list of mirror country names to exclude - mirrors_exclusion_list = getenv("EXCLUDED_MIRRORS", default="Israel").split(",") + # url to fetch the list of mirrors + mirrors_url: str = getenv( + "MIRRORS_LIST_URL", default="https://download.kiwix.org/mirrors.html" + ) + # comma-seperated list of mirror hostnames to exclude + mirrors_exclusion_list = getenv( + "EXCLUDED_MIRRORS", default="mirror.isoc.org.il" + ).split(",") debug = bool(getenv("DEBUG", default=False)) # number of seconds before requests time out requests_timeout = int(getenv("REQUESTS_TIMEOUT", default=5)) diff --git a/backend/tests/db/test_mirrors.py b/backend/tests/db/test_mirrors.py index 878350f..0de16bd 100644 --- a/backend/tests/db/test_mirrors.py +++ b/backend/tests/db/test_mirrors.py @@ -8,117 +8,113 @@ @pytest.fixture(scope="session") -def schema_country() -> schemas.Country: - return schemas.Country( - code="in", - name="India", - mirrors=[ - schemas.Mirror( - id="mirror-sites-in.mblibrary.info", - base_url="https://mirror-sites-in.mblibrary.info/mirror-sites/download.kiwix.org/", - enabled=True, - region=None, - asn=None, - score=None, - latitude=None, - longitude=None, - country_only=None, - region_only=None, - as_only=None, - other_countries=None, - ) - ], +def schema_mirror() -> schemas.Mirror: + return schemas.Mirror( + id="mirror-sites-in.mblibrary.info", + base_url="https://mirror-sites-in.mblibrary.info/mirror-sites/download.kiwix.org/", + enabled=True, + region=None, + asn=None, + score=None, + latitude=None, + longitude=None, + country_only=None, + region_only=None, + as_only=None, + other_countries=None, + country=schemas.Country( + code="in", + name="India", + ), ) -@pytest.fixture(scope="session") -def new_schema_country() -> schemas.Country: - return schemas.Country( - code="dk", - name="Denmark", - mirrors=[ - schemas.Mirror( - id="mirrors.dotsrc.org", - base_url="https://mirrors.dotsrc.org/kiwix/", - enabled=True, - region=None, - asn=None, - score=None, - latitude=None, - longitude=None, - country_only=None, - region_only=None, - as_only=None, - other_countries=None, - ) - ], +@pytest.fixture +def db_mirror() -> models.Mirror: + mirror = models.Mirror( + id="mirror-sites-in.mblibrary.info", + base_url="https://mirror-sites-in.mblibrary.info/mirror-sites/download.kiwix.org/", + enabled=True, + region=None, + asn=None, + score=None, + latitude=None, + longitude=None, + country_only=None, + region_only=None, + as_only=None, + other_countries=None, ) + mirror.country = models.Country(code="in", name="India") + return mirror -@pytest.fixture -def db_mirror_country() -> models.Country: - c = models.Country(code="in", name="India") - c.mirrors = [ - models.Mirror( - id="mirror-sites-in.mblibrary.info", - base_url="https://mirror-sites-in.mblibrary.info/mirror-sites/download.kiwix.org/", - enabled=True, - region=None, - asn=None, - score=None, - latitude=None, - longitude=None, - country_only=None, - region_only=None, - as_only=None, - other_countries=None, - ) - ] - return c +@pytest.fixture(scope="session") +def new_schema_mirror() -> schemas.Mirror: + return schemas.Mirror( + id="mirrors.dotsrc.org", + base_url="https://mirrors.dotsrc.org/kiwix/", + enabled=True, + region=None, + asn=None, + score=None, + latitude=None, + longitude=None, + country_only=None, + region_only=None, + as_only=None, + other_countries=None, + country=schemas.Country( + code="dk", + name="Denmark", + ), + ) def test_db_empty(dbsession: OrmSession): - count = db.count_from_stmt(dbsession, select(models.Country)) - assert count == 0 + assert db.count_from_stmt(dbsession, select(models.Country)) == 0 -def test_create_mirrors(dbsession: OrmSession, schema_country: schemas.Country): - mirrors.create_mirrors(dbsession, [schema_country]) - assert db.count_from_stmt(dbsession, select(models.Country)) == 1 +def test_create_no_mirrors(dbsession: OrmSession): + assert mirrors.create_mirrors(dbsession, []) == 0 -def test_all_mirrors_disabled(dbsession: OrmSession, db_mirror_country: models.Country): - dbsession.add(db_mirror_country) - mirrors.update_mirrors(dbsession, []) - assert ( - db.count_from_stmt( - dbsession, select(models.Mirror).where(models.Mirror.enabled == True) - ) - == 0 - ) +def test_create_mirrors(dbsession: OrmSession, schema_mirror: schemas.Mirror): + assert mirrors.create_mirrors(dbsession, [schema_mirror]) == 1 def test_register_new_country_mirror( dbsession: OrmSession, - schema_country: schemas.Country, - db_mirror_country: models.Country, - new_schema_country: schemas.Country, + schema_mirror: schemas.Mirror, + db_mirror: models.Mirror, + new_schema_mirror: schemas.Mirror, ): - dbsession.add(db_mirror_country) - mirrors.update_mirrors(dbsession, [schema_country, new_schema_country]) - assert db.count_from_stmt(dbsession, select(models.Mirror)) == 2 + dbsession.add(db_mirror) + result = mirrors.update_mirrors(dbsession, [schema_mirror, new_schema_mirror]) + assert result.nb_mirrors_added == 1 def test_disable_old_mirror( dbsession: OrmSession, - db_mirror_country: models.Country, - new_schema_country: schemas.Country, + db_mirror: models.Mirror, + new_schema_mirror: schemas.Mirror, ): - dbsession.add(db_mirror_country) - mirrors.update_mirrors(dbsession, [new_schema_country]) - assert ( - db.count_from_stmt( - dbsession, select(models.Mirror).where(models.Mirror.enabled == True) - ) - == 1 - ) + dbsession.add(db_mirror) + result = mirrors.update_mirrors(dbsession, [new_schema_mirror]) + assert result.nb_mirrors_disabled == 1 + + +def test_no_mirrors_disabled( + dbsession: OrmSession, db_mirror: models.Mirror, schema_mirror: schemas.Mirror +): + dbsession.add(db_mirror) + result = mirrors.update_mirrors(dbsession, [schema_mirror]) + assert result.nb_mirrors_disabled == 0 + + +def test_no_mirrors_added( + dbsession: OrmSession, db_mirror: models.Mirror, schema_mirror: schemas.Mirror +): + dbsession.add(db_mirror) + result = mirrors.update_mirrors(dbsession, [schema_mirror]) + assert result.nb_mirrors_added == 0