From f7148c870fc85e0ccf3c8b853759658c10b64f92 Mon Sep 17 00:00:00 2001 From: Harrison Cook <harrison.cook@ecmwf.int> Date: Sat, 14 Dec 2024 15:52:38 +0000 Subject: [PATCH] Add tests to huggingface loading --- src/anemoi/inference/checkpoint.py | 26 +++++----- tests/__init__.py | 0 tests/checkpoint/__init__.py | 0 tests/checkpoint/test_huggingface.py | 78 ++++++++++++++++++++++++++++ tests/metadata/fake_metadata.py | 3 ++ 5 files changed, 94 insertions(+), 13 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/checkpoint/__init__.py create mode 100644 tests/checkpoint/test_huggingface.py create mode 100644 tests/metadata/fake_metadata.py diff --git a/src/anemoi/inference/checkpoint.py b/src/anemoi/inference/checkpoint.py index fea51f3..2d62f8d 100644 --- a/src/anemoi/inference/checkpoint.py +++ b/src/anemoi/inference/checkpoint.py @@ -22,7 +22,7 @@ LOG = logging.getLogger(__name__) -def _download_huggingfacehub(huggingface_config): +def _download_huggingfacehub(huggingface_config) -> str: """Download model from huggingface""" try: from huggingface_hub import hf_hub_download @@ -34,17 +34,17 @@ def _download_huggingfacehub(huggingface_config): huggingface_config = {"repo_id": huggingface_config} if "filename" in huggingface_config: - config_path = hf_hub_download(**huggingface_config) + return str(hf_hub_download(**huggingface_config)) + + repo_path = Path(snapshot_download(**huggingface_config)) + ckpt_files = list(repo_path.glob("*.ckpt")) + + if len(ckpt_files) == 1: + return str(ckpt_files[0]) else: - repo_path = Path(snapshot_download(**huggingface_config)) - ckpt_files = list(repo_path.glob("*.ckpt")) - if len(ckpt_files) == 1: - return str(ckpt_files[0]) - else: - ValueError( - f"Multiple ckpt files found in repo, {ckpt_files}.\nCannot pick one to load, please specify `filename`." - ) - return config_path + raise ValueError( + f"None or Multiple ckpt files found in repo, {ckpt_files}.\nCannot pick one to load, please specify `filename`." + ) class Checkpoint: @@ -58,7 +58,7 @@ def __repr__(self): return f"Checkpoint({self.path})" @cached_property - def path(self): + def path(self) -> str: import json try: @@ -67,7 +67,7 @@ def path(self): path = self._path if isinstance(path, (Path, str)): - return path + return str(path) elif isinstance(path, dict): if "huggingface" in path: return _download_huggingfacehub(path["huggingface"]) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/checkpoint/__init__.py b/tests/checkpoint/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/checkpoint/test_huggingface.py b/tests/checkpoint/test_huggingface.py new file mode 100644 index 0000000..65f5e54 --- /dev/null +++ b/tests/checkpoint/test_huggingface.py @@ -0,0 +1,78 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from unittest.mock import patch + +import pytest + +import anemoi.inference.checkpoint +from anemoi.inference.runner import Runner + +from ..metadata.fake_metadata import FakeMetadata + + +@pytest.fixture(scope="session") +def fake_huggingface_repo(tmp_path_factory): + """Create a fake huggingface repo download""" + tmp_dir = tmp_path_factory.mktemp("repo") + fn = tmp_dir / "model.ckpt" + fn.write_text("TESTING", encoding="utf-8") + return tmp_dir + + +@pytest.fixture(scope="session") +def fake_huggingface_ckpt(tmp_path_factory): + """Create a fake huggingface ckpt download""" + tmp_dir = tmp_path_factory.mktemp("repo") + fn = tmp_dir / "model.ckpt" + fn.write_text("TESTING", encoding="utf-8") + return fn + + +@patch("huggingface_hub.snapshot_download") +@pytest.mark.parametrize("ckpt", ["organisation/test_repo"]) +def test_huggingface_repo_download_str(huggingface_mock, monkeypatch, ckpt, fake_huggingface_repo): + + monkeypatch.setattr(anemoi.inference.checkpoint.Checkpoint, "_metadata", FakeMetadata()) + huggingface_mock.return_value = fake_huggingface_repo + + runner = Runner({"huggingface": ckpt}) + assert runner.checkpoint.path == str(fake_huggingface_repo / "model.ckpt") + + assert huggingface_mock.called + huggingface_mock.assert_called_once_with(repo_id=ckpt) + + +@patch("huggingface_hub.snapshot_download") +@pytest.mark.parametrize("ckpt", [{"repo_id": "organisation/test_repo"}]) +def test_huggingface_repo_download_dict(huggingface_mock, monkeypatch, ckpt, fake_huggingface_repo): + + monkeypatch.setattr(anemoi.inference.checkpoint.Checkpoint, "_metadata", FakeMetadata()) + huggingface_mock.return_value = fake_huggingface_repo + + runner = Runner({"huggingface": ckpt}) + assert runner.checkpoint.path == str(fake_huggingface_repo / "model.ckpt") + + assert huggingface_mock.called + huggingface_mock.assert_called_once_with(**ckpt) + + +@patch("huggingface_hub.hf_hub_download") +@pytest.mark.parametrize("ckpt", [{"repo_id": "organisation/test_repo", "filename": "model.ckpt"}]) +def test_huggingface_file_download(huggingface_mock, monkeypatch, ckpt, fake_huggingface_ckpt): + + monkeypatch.setattr(anemoi.inference.checkpoint.Checkpoint, "_metadata", FakeMetadata()) + huggingface_mock.return_value = fake_huggingface_ckpt + + runner = Runner({"huggingface": ckpt}) + assert runner.checkpoint.path == str(fake_huggingface_ckpt) + + assert huggingface_mock.called + huggingface_mock.assert_called_once_with(**ckpt) diff --git a/tests/metadata/fake_metadata.py b/tests/metadata/fake_metadata.py new file mode 100644 index 0000000..6fa8d0c --- /dev/null +++ b/tests/metadata/fake_metadata.py @@ -0,0 +1,3 @@ +class FakeMetadata: + def __getattr__(self, name): + return None