From f7148c870fc85e0ccf3c8b853759658c10b64f92 Mon Sep 17 00:00:00 2001
From: Harrison Cook <harrison.cook@ecmwf.int>
Date: Sat, 14 Dec 2024 15:52:38 +0000
Subject: [PATCH] Add tests for huggingface loading

---
 src/anemoi/inference/checkpoint.py   | 26 +++++-----
 tests/__init__.py                    |  0
 tests/checkpoint/__init__.py         |  0
 tests/checkpoint/test_huggingface.py | 78 ++++++++++++++++++++++++++++
 tests/metadata/fake_metadata.py      |  3 ++
 5 files changed, 94 insertions(+), 13 deletions(-)
 create mode 100644 tests/__init__.py
 create mode 100644 tests/checkpoint/__init__.py
 create mode 100644 tests/checkpoint/test_huggingface.py
 create mode 100644 tests/metadata/fake_metadata.py

diff --git a/src/anemoi/inference/checkpoint.py b/src/anemoi/inference/checkpoint.py
index fea51f3..2d62f8d 100644
--- a/src/anemoi/inference/checkpoint.py
+++ b/src/anemoi/inference/checkpoint.py
@@ -22,7 +22,7 @@
 LOG = logging.getLogger(__name__)
 
 
-def _download_huggingfacehub(huggingface_config):
+def _download_huggingfacehub(huggingface_config) -> str:
     """Download model from huggingface"""
     try:
         from huggingface_hub import hf_hub_download
@@ -34,17 +34,17 @@ def _download_huggingfacehub(huggingface_config):
         huggingface_config = {"repo_id": huggingface_config}
 
     if "filename" in huggingface_config:
-        config_path = hf_hub_download(**huggingface_config)
+        return str(hf_hub_download(**huggingface_config))
+
+    repo_path = Path(snapshot_download(**huggingface_config))
+    ckpt_files = list(repo_path.glob("*.ckpt"))
+
+    if len(ckpt_files) == 1:
+        return str(ckpt_files[0])
     else:
-        repo_path = Path(snapshot_download(**huggingface_config))
-        ckpt_files = list(repo_path.glob("*.ckpt"))
-        if len(ckpt_files) == 1:
-            return str(ckpt_files[0])
-        else:
-            ValueError(
-                f"Multiple ckpt files found in repo, {ckpt_files}.\nCannot pick one to load, please specify `filename`."
-            )
-    return config_path
+        raise ValueError(
+            f"None or Multiple ckpt files found in repo, {ckpt_files}.\nCannot pick one to load, please specify `filename`."
+        )
 
 
 class Checkpoint:
@@ -58,7 +58,7 @@ def __repr__(self):
         return f"Checkpoint({self.path})"
 
     @cached_property
-    def path(self):
+    def path(self) -> str:
         import json
 
         try:
@@ -67,7 +67,7 @@ def path(self):
             path = self._path
 
         if isinstance(path, (Path, str)):
-            return path
+            return str(path)
         elif isinstance(path, dict):
             if "huggingface" in path:
                 return _download_huggingfacehub(path["huggingface"])
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/checkpoint/__init__.py b/tests/checkpoint/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/checkpoint/test_huggingface.py b/tests/checkpoint/test_huggingface.py
new file mode 100644
index 0000000..65f5e54
--- /dev/null
+++ b/tests/checkpoint/test_huggingface.py
@@ -0,0 +1,78 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+from unittest.mock import patch
+
+import pytest
+
+import anemoi.inference.checkpoint
+from anemoi.inference.runner import Runner
+
+from ..metadata.fake_metadata import FakeMetadata
+
+
+@pytest.fixture(scope="session")
+def fake_huggingface_repo(tmp_path_factory):
+    """Build a temporary directory holding one ``model.ckpt`` and return the directory"""
+    repo_dir = tmp_path_factory.mktemp("repo")
+    checkpoint_file = repo_dir.joinpath("model.ckpt")
+    checkpoint_file.write_text("TESTING", encoding="utf-8")
+    return repo_dir
+
+
+@pytest.fixture(scope="session")
+def fake_huggingface_ckpt(tmp_path_factory):
+    """Write a single ``model.ckpt`` into a temporary directory and return the file path"""
+    download_dir = tmp_path_factory.mktemp("repo")
+    ckpt_path = download_dir.joinpath("model.ckpt")
+    ckpt_path.write_text("TESTING", encoding="utf-8")
+    return ckpt_path
+
+
+@patch("huggingface_hub.snapshot_download")
+@pytest.mark.parametrize("ckpt", ["organisation/test_repo"])
+def test_huggingface_repo_download_str(huggingface_mock, monkeypatch, ckpt, fake_huggingface_repo):
+    """A bare repo-id string resolves to ``model.ckpt`` inside the mocked snapshot directory"""
+    monkeypatch.setattr(anemoi.inference.checkpoint.Checkpoint, "_metadata", FakeMetadata())
+    huggingface_mock.return_value = fake_huggingface_repo
+
+    runner = Runner({"huggingface": ckpt})
+    assert runner.checkpoint.path == str(fake_huggingface_repo / "model.ckpt")
+
+    assert huggingface_mock.called
+    huggingface_mock.assert_called_once_with(repo_id=ckpt)
+
+
+@patch("huggingface_hub.snapshot_download")
+@pytest.mark.parametrize("ckpt", [{"repo_id": "organisation/test_repo"}])
+def test_huggingface_repo_download_dict(huggingface_mock, monkeypatch, ckpt, fake_huggingface_repo):
+    """A dict without ``filename`` is passed as keyword arguments to the mocked snapshot download"""
+    monkeypatch.setattr(anemoi.inference.checkpoint.Checkpoint, "_metadata", FakeMetadata())
+    huggingface_mock.return_value = fake_huggingface_repo
+
+    runner = Runner({"huggingface": ckpt})
+    assert runner.checkpoint.path == str(fake_huggingface_repo / "model.ckpt")
+
+    assert huggingface_mock.called
+    huggingface_mock.assert_called_once_with(**ckpt)
+
+
+@patch("huggingface_hub.hf_hub_download")
+@pytest.mark.parametrize("ckpt", [{"repo_id": "organisation/test_repo", "filename": "model.ckpt"}])
+def test_huggingface_file_download(huggingface_mock, monkeypatch, ckpt, fake_huggingface_ckpt):
+    """A dict with ``filename`` goes through the mocked ``hf_hub_download`` and yields its path as a str"""
+    monkeypatch.setattr(anemoi.inference.checkpoint.Checkpoint, "_metadata", FakeMetadata())
+    huggingface_mock.return_value = fake_huggingface_ckpt
+
+    runner = Runner({"huggingface": ckpt})
+    assert runner.checkpoint.path == str(fake_huggingface_ckpt)
+
+    assert huggingface_mock.called
+    huggingface_mock.assert_called_once_with(**ckpt)
diff --git a/tests/metadata/fake_metadata.py b/tests/metadata/fake_metadata.py
new file mode 100644
index 0000000..6fa8d0c
--- /dev/null
+++ b/tests/metadata/fake_metadata.py
@@ -0,0 +1,3 @@
+class FakeMetadata:  # test stub: any attribute not otherwise defined reads as None
+    def __getattr__(self, name):  # only invoked when normal attribute lookup fails
+        return None