Skip to content

Commit

Permalink
Add tests to huggingface loading
Browse files Browse the repository at this point in the history
  • Loading branch information
HCookie committed Dec 14, 2024
1 parent 29aab25 commit f7148c8
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/anemoi/inference/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
LOG = logging.getLogger(__name__)


def _download_huggingfacehub(huggingface_config):
def _download_huggingfacehub(huggingface_config) -> str:
"""Download model from huggingface"""
try:
from huggingface_hub import hf_hub_download
Expand All @@ -34,17 +34,17 @@ def _download_huggingfacehub(huggingface_config):
huggingface_config = {"repo_id": huggingface_config}

if "filename" in huggingface_config:
config_path = hf_hub_download(**huggingface_config)
return str(hf_hub_download(**huggingface_config))

repo_path = Path(snapshot_download(**huggingface_config))
ckpt_files = list(repo_path.glob("*.ckpt"))

if len(ckpt_files) == 1:
return str(ckpt_files[0])
else:
repo_path = Path(snapshot_download(**huggingface_config))
ckpt_files = list(repo_path.glob("*.ckpt"))
if len(ckpt_files) == 1:
return str(ckpt_files[0])
else:
ValueError(
f"Multiple ckpt files found in repo, {ckpt_files}.\nCannot pick one to load, please specify `filename`."
)
return config_path
raise ValueError(
f"None or Multiple ckpt files found in repo, {ckpt_files}.\nCannot pick one to load, please specify `filename`."
)


class Checkpoint:
Expand All @@ -58,7 +58,7 @@ def __repr__(self):
return f"Checkpoint({self.path})"

@cached_property
def path(self):
def path(self) -> str:
import json

try:
Expand All @@ -67,7 +67,7 @@ def path(self):
path = self._path

if isinstance(path, (Path, str)):
return path
return str(path)
elif isinstance(path, dict):
if "huggingface" in path:
return _download_huggingfacehub(path["huggingface"])
Expand Down
Empty file added tests/__init__.py
Empty file.
Empty file added tests/checkpoint/__init__.py
Empty file.
78 changes: 78 additions & 0 deletions tests/checkpoint/test_huggingface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


from unittest.mock import patch

import pytest

import anemoi.inference.checkpoint
from anemoi.inference.runner import Runner

from ..metadata.fake_metadata import FakeMetadata


@pytest.fixture(scope="session")
def fake_huggingface_repo(tmp_path_factory):
    """Provide a directory that stands in for a downloaded huggingface repo.

    The directory is created through ``tmp_path_factory`` and holds a single
    ``model.ckpt`` file containing placeholder text.

    Returns
    -------
    pathlib.Path
        The directory containing ``model.ckpt``.
    """
    repo_dir = tmp_path_factory.mktemp("repo")
    (repo_dir / "model.ckpt").write_text("TESTING", encoding="utf-8")
    return repo_dir


@pytest.fixture(scope="session")
def fake_huggingface_ckpt(tmp_path_factory):
    """Provide a file that stands in for a single downloaded checkpoint.

    A ``model.ckpt`` file with placeholder text is written into a fresh
    directory created through ``tmp_path_factory``.

    Returns
    -------
    pathlib.Path
        The path of the ``model.ckpt`` file itself (not its directory).
    """
    ckpt_file = tmp_path_factory.mktemp("repo") / "model.ckpt"
    ckpt_file.write_text("TESTING", encoding="utf-8")
    return ckpt_file


@patch("huggingface_hub.snapshot_download")
@pytest.mark.parametrize("ckpt", ["organisation/test_repo"])
def test_huggingface_repo_download_str(huggingface_mock, monkeypatch, ckpt, fake_huggingface_repo):
    """A bare repo-id string resolves to the single ckpt file in the snapshot.

    ``snapshot_download`` is mocked to return the fake repo directory; the
    checkpoint path must be the ``model.ckpt`` inside it, and the mock must
    have been called exactly once with ``repo_id`` set to the string.
    """
    # Swap the checkpoint metadata for a stub that answers None to everything
    # (presumably so no real checkpoint file has to be opened -- confirm).
    monkeypatch.setattr(anemoi.inference.checkpoint.Checkpoint, "_metadata", FakeMetadata())
    huggingface_mock.return_value = fake_huggingface_repo
    expected_path = str(fake_huggingface_repo / "model.ckpt")

    resolved_path = Runner({"huggingface": ckpt}).checkpoint.path

    assert resolved_path == expected_path
    assert huggingface_mock.called
    huggingface_mock.assert_called_once_with(repo_id=ckpt)


@patch("huggingface_hub.snapshot_download")
@pytest.mark.parametrize("ckpt", [{"repo_id": "organisation/test_repo"}])
def test_huggingface_repo_download_dict(huggingface_mock, monkeypatch, ckpt, fake_huggingface_repo):
    """A dict config without ``filename`` resolves via a repo snapshot.

    ``snapshot_download`` is mocked to return the fake repo directory; the
    checkpoint path must be the ``model.ckpt`` inside it, and the mock must
    have been called exactly once with the dict expanded as keyword arguments.
    """
    # Swap the checkpoint metadata for a stub that answers None to everything
    # (presumably so no real checkpoint file has to be opened -- confirm).
    monkeypatch.setattr(anemoi.inference.checkpoint.Checkpoint, "_metadata", FakeMetadata())
    huggingface_mock.return_value = fake_huggingface_repo
    expected_path = str(fake_huggingface_repo / "model.ckpt")

    resolved_path = Runner({"huggingface": ckpt}).checkpoint.path

    assert resolved_path == expected_path
    assert huggingface_mock.called
    huggingface_mock.assert_called_once_with(**ckpt)


@patch("huggingface_hub.hf_hub_download")
@pytest.mark.parametrize("ckpt", [{"repo_id": "organisation/test_repo", "filename": "model.ckpt"}])
def test_huggingface_file_download(huggingface_mock, monkeypatch, ckpt, fake_huggingface_ckpt):
    """A dict config with ``filename`` resolves via a single-file download.

    ``hf_hub_download`` is mocked to return the fake checkpoint file; the
    checkpoint path must be that file's path as a string, and the mock must
    have been called exactly once with the dict expanded as keyword arguments.
    """
    # Swap the checkpoint metadata for a stub that answers None to everything
    # (presumably so no real checkpoint file has to be opened -- confirm).
    monkeypatch.setattr(anemoi.inference.checkpoint.Checkpoint, "_metadata", FakeMetadata())
    huggingface_mock.return_value = fake_huggingface_ckpt
    expected_path = str(fake_huggingface_ckpt)

    resolved_path = Runner({"huggingface": ckpt}).checkpoint.path

    assert resolved_path == expected_path
    assert huggingface_mock.called
    huggingface_mock.assert_called_once_with(**ckpt)
3 changes: 3 additions & 0 deletions tests/metadata/fake_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class FakeMetadata:
    """Test stand-in for metadata: any attribute not otherwise found is ``None``."""

    def __getattr__(self, name):
        """Return ``None`` for the requested attribute, whatever its name.

        Python calls ``__getattr__`` only after normal attribute lookup has
        failed, so this answers for every attribute the class does not define.
        """
        return None

0 comments on commit f7148c8

Please sign in to comment.