diff --git a/configs/vision/dino_vit/offline/bach.yaml b/configs/vision/dino_vit/offline/bach.yaml index 610338ed..db0ca4bd 100644 --- a/configs/vision/dino_vit/offline/bach.yaml +++ b/configs/vision/dino_vit/offline/bach.yaml @@ -29,7 +29,7 @@ trainer: 0: train 1: val backbone: - class_path: eva.models.ModelFromFunction + class_path: eva.models.networks.wrappers.ModelFromFunction init_args: path: torch.hub.load arguments: diff --git a/configs/vision/dino_vit/offline/crc.yaml b/configs/vision/dino_vit/offline/crc.yaml index 523e2818..a5eec1bd 100644 --- a/configs/vision/dino_vit/offline/crc.yaml +++ b/configs/vision/dino_vit/offline/crc.yaml @@ -29,7 +29,7 @@ trainer: 0: train 1: val backbone: - class_path: eva.models.ModelFromFunction + class_path: eva.models.networks.wrappers.ModelFromFunction init_args: path: torch.hub.load arguments: diff --git a/configs/vision/dino_vit/offline/mhist.yaml b/configs/vision/dino_vit/offline/mhist.yaml index 45e01a0e..2445af22 100644 --- a/configs/vision/dino_vit/offline/mhist.yaml +++ b/configs/vision/dino_vit/offline/mhist.yaml @@ -29,7 +29,7 @@ trainer: 0: train 1: test backbone: - class_path: eva.models.ModelFromFunction + class_path: eva.models.networks.wrappers.ModelFromFunction init_args: path: torch.hub.load arguments: diff --git a/configs/vision/dino_vit/offline/patch_camelyon.yaml b/configs/vision/dino_vit/offline/patch_camelyon.yaml index b695a1b2..da21c399 100644 --- a/configs/vision/dino_vit/offline/patch_camelyon.yaml +++ b/configs/vision/dino_vit/offline/patch_camelyon.yaml @@ -30,7 +30,7 @@ trainer: 1: val 2: test backbone: - class_path: eva.models.ModelFromFunction + class_path: eva.models.networks.wrappers.ModelFromFunction init_args: path: torch.hub.load arguments: diff --git a/configs/vision/dino_vit/online/bach.yaml b/configs/vision/dino_vit/online/bach.yaml index 32648c55..b89164d4 100644 --- a/configs/vision/dino_vit/online/bach.yaml +++ b/configs/vision/dino_vit/online/bach.yaml @@ -30,7 +30,7 @@ model: class_path: eva.HeadModule init_args: backbone: - class_path: eva.models.ModelFromFunction + class_path: eva.models.networks.wrappers.ModelFromFunction init_args: path: torch.hub.load arguments: diff --git a/configs/vision/dino_vit/online/crc.yaml b/configs/vision/dino_vit/online/crc.yaml index b94a272c..7aa965fc 100644 --- a/configs/vision/dino_vit/online/crc.yaml +++ b/configs/vision/dino_vit/online/crc.yaml @@ -30,7 +30,7 @@ model: class_path: eva.HeadModule init_args: backbone: - class_path: eva.models.ModelFromFunction + class_path: eva.models.networks.wrappers.ModelFromFunction init_args: path: torch.hub.load arguments: diff --git a/configs/vision/dino_vit/online/mhist.yaml b/configs/vision/dino_vit/online/mhist.yaml index a53056c1..9ce77d38 100644 --- a/configs/vision/dino_vit/online/mhist.yaml +++ b/configs/vision/dino_vit/online/mhist.yaml @@ -30,7 +30,7 @@ model: class_path: eva.HeadModule init_args: backbone: - class_path: eva.models.ModelFromFunction + class_path: eva.models.networks.wrappers.ModelFromFunction init_args: path: torch.hub.load arguments: diff --git a/configs/vision/dino_vit/online/patch_camelyon.yaml b/configs/vision/dino_vit/online/patch_camelyon.yaml index e27501d3..7d26b25b 100644 --- a/configs/vision/dino_vit/online/patch_camelyon.yaml +++ b/configs/vision/dino_vit/online/patch_camelyon.yaml @@ -30,7 +30,7 @@ model: class_path: eva.HeadModule init_args: backbone: - class_path: eva.models.ModelFromFunction + class_path: eva.models.networks.wrappers.ModelFromFunction init_args: path: torch.hub.load arguments: diff --git a/configs/vision/owkin/phikon/offline/bach.yaml b/configs/vision/owkin/phikon/offline/bach.yaml index 14badfc0..bc5a51c6 100644 --- a/configs/vision/owkin/phikon/offline/bach.yaml +++ b/configs/vision/owkin/phikon/offline/bach.yaml @@ -29,7 +29,7 @@ trainer: 0: train 1: val backbone: - class_path: eva.models.wrappers.HuggingFaceModel + class_path: eva.models.networks.wrappers.HuggingFaceModel init_args: model_name_or_path: owkin/phikon tensor_transforms: diff --git a/configs/vision/owkin/phikon/offline/crc.yaml b/configs/vision/owkin/phikon/offline/crc.yaml index 80b41400..d71c0515 100644 --- a/configs/vision/owkin/phikon/offline/crc.yaml +++ b/configs/vision/owkin/phikon/offline/crc.yaml @@ -29,7 +29,7 @@ trainer: 0: train 1: val backbone: - class_path: eva.models.wrappers.HuggingFaceModel + class_path: eva.models.networks.wrappers.HuggingFaceModel init_args: model_name_or_path: owkin/phikon tensor_transforms: diff --git a/configs/vision/owkin/phikon/offline/mhist.yaml b/configs/vision/owkin/phikon/offline/mhist.yaml index 2ce63457..8e6da687 100644 --- a/configs/vision/owkin/phikon/offline/mhist.yaml +++ b/configs/vision/owkin/phikon/offline/mhist.yaml @@ -29,7 +29,7 @@ trainer: 0: train 1: test backbone: - class_path: eva.models.wrappers.HuggingFaceModel + class_path: eva.models.networks.wrappers.HuggingFaceModel init_args: model_name_or_path: owkin/phikon tensor_transforms: diff --git a/configs/vision/owkin/phikon/offline/patch_camelyon.yaml b/configs/vision/owkin/phikon/offline/patch_camelyon.yaml index 53dc5358..640dec10 100644 --- a/configs/vision/owkin/phikon/offline/patch_camelyon.yaml +++ b/configs/vision/owkin/phikon/offline/patch_camelyon.yaml @@ -30,7 +30,7 @@ trainer: 1: val 2: test backbone: - class_path: eva.models.wrappers.HuggingFaceModel + class_path: eva.models.networks.wrappers.HuggingFaceModel init_args: model_name_or_path: owkin/phikon tensor_transforms: diff --git a/configs/vision/tests/offline/patch_camelyon.yaml b/configs/vision/tests/offline/patch_camelyon.yaml index 013ebc79..6ab2de9f 100644 --- a/configs/vision/tests/offline/patch_camelyon.yaml +++ b/configs/vision/tests/offline/patch_camelyon.yaml @@ -15,13 +15,13 @@ trainer: 1: val 2: test backbone: - class_path: eva.models.ModelFromFunction + class_path: eva.models.networks.wrappers.ModelFromFunction init_args: path: torch.hub.load arguments: repo_or_dir: facebookresearch/dino:main model: dino_vits16 - pretrained: true + pretrained: false checkpoint_path: &CHECKPOINT_PATH ${oc.env:CHECKPOINT_PATH, null} - class_path: pytorch_lightning.callbacks.LearningRateMonitor init_args: diff --git a/src/eva/models/__init__.py b/src/eva/models/__init__.py index ea0acbab..c06811ed 100644 --- a/src/eva/models/__init__.py +++ b/src/eva/models/__init__.py @@ -1,6 +1,5 @@ """Models API.""" from eva.models.modules import HeadModule, InferenceModule -from eva.models.networks import ModelFromFunction -__all__ = ["HeadModule", "ModelFromFunction", "InferenceModule"] +__all__ = ["HeadModule", "InferenceModule"] diff --git a/src/eva/models/networks/__init__.py b/src/eva/models/networks/__init__.py index 54bf968d..bb113d55 100644 --- a/src/eva/models/networks/__init__.py +++ b/src/eva/models/networks/__init__.py @@ -1,6 +1,6 @@ """Networks API.""" +from eva.models.networks import wrappers from eva.models.networks.mlp import MLP -from eva.models.wrappers.from_function import ModelFromFunction -__all__ = ["ModelFromFunction", "MLP"] +__all__ = ["wrappers", "MLP"] diff --git a/src/eva/models/networks/wrappers/__init__.py b/src/eva/models/networks/wrappers/__init__.py new file mode 100644 index 00000000..c86f1f04 --- /dev/null +++ b/src/eva/models/networks/wrappers/__init__.py @@ -0,0 +1,7 @@ +"""Model Wrappers API.""" + +from eva.models.networks.wrappers.from_function import ModelFromFunction +from eva.models.networks.wrappers.huggingface import HuggingFaceModel +from eva.models.networks.wrappers.onnx import ONNXModel + +__all__ = ["ModelFromFunction", "HuggingFaceModel", "ONNXModel"] diff --git a/src/eva/models/wrappers/base.py b/src/eva/models/networks/wrappers/base.py similarity index 100% rename from src/eva/models/wrappers/base.py rename to src/eva/models/networks/wrappers/base.py diff --git a/src/eva/models/wrappers/from_function.py b/src/eva/models/networks/wrappers/from_function.py similarity index 95% rename from src/eva/models/wrappers/from_function.py rename to src/eva/models/networks/wrappers/from_function.py index 7295d5c1..97a5087e 100644 --- a/src/eva/models/wrappers/from_function.py +++ b/src/eva/models/networks/wrappers/from_function.py @@ -7,11 +7,11 @@ from torch import nn from typing_extensions import override -from eva.models import wrappers from eva.models.networks import _utils +from eva.models.networks.wrappers import base -class ModelFromFunction(wrappers.BaseModel): +class ModelFromFunction(base.BaseModel): """Wrapper class for models which are initialized from functions. This is helpful for initializing models in a `.yaml` configuration file. diff --git a/src/eva/models/wrappers/huggingface.py b/src/eva/models/networks/wrappers/huggingface.py similarity index 93% rename from src/eva/models/wrappers/huggingface.py rename to src/eva/models/networks/wrappers/huggingface.py index 5fdb652a..f3ac014e 100644 --- a/src/eva/models/wrappers/huggingface.py +++ b/src/eva/models/networks/wrappers/huggingface.py @@ -6,10 +6,10 @@ import transformers from typing_extensions import override -from eva.models import wrappers +from eva.models.networks.wrappers import base -class HuggingFaceModel(wrappers.BaseModel): +class HuggingFaceModel(base.BaseModel): """Wrapper class for loading HuggingFace `transformers` models.""" def __init__(self, model_name_or_path: str, tensor_transforms: Callable | None = None) -> None: diff --git a/src/eva/models/wrappers/onnx.py b/src/eva/models/networks/wrappers/onnx.py similarity index 95% rename from src/eva/models/wrappers/onnx.py rename to src/eva/models/networks/wrappers/onnx.py index a96e0b43..cd973ea2 100644 --- a/src/eva/models/wrappers/onnx.py +++ b/src/eva/models/networks/wrappers/onnx.py @@ -6,10 +6,10 @@ import torch from typing_extensions import override -from eva.models import wrappers +from eva.models.networks.wrappers import base -class ONNXModel(wrappers.BaseModel): +class ONNXModel(base.BaseModel): """Wrapper class for loading ONNX models.""" def __init__( diff --git a/src/eva/models/wrappers/__init__.py b/src/eva/models/wrappers/__init__.py deleted file mode 100644 index dbe6d4de..00000000 --- a/src/eva/models/wrappers/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Model Wrappers API.""" - -from eva.models.wrappers.base import BaseModel -from eva.models.wrappers.from_function import ModelFromFunction -from eva.models.wrappers.huggingface import HuggingFaceModel -from eva.models.wrappers.onnx import ONNXModel - -__all__ = ["BaseModel", "HuggingFaceModel", "ONNXModel", "ModelFromFunction"] diff --git a/tests/eva/models/wrappers/__init__.py b/tests/eva/models/networks/wrappers/__init__.py similarity index 100% rename from tests/eva/models/wrappers/__init__.py rename to tests/eva/models/networks/wrappers/__init__.py diff --git a/tests/eva/models/wrappers/test_from_function.py b/tests/eva/models/networks/wrappers/test_from_function.py similarity index 80% rename from tests/eva/models/wrappers/test_from_function.py rename to tests/eva/models/networks/wrappers/test_from_function.py index 935eba75..b1f0cb3a 100644 --- a/tests/eva/models/wrappers/test_from_function.py +++ b/tests/eva/models/networks/wrappers/test_from_function.py @@ -6,7 +6,7 @@ import torch from torch import nn -from eva.models import networks +from eva.models.networks import wrappers @pytest.mark.parametrize( @@ -17,7 +17,7 @@ ], ) def test_model_from_function( - model_from_function: networks.ModelFromFunction, + model_from_function: wrappers.ModelFromFunction, ) -> None: """Tests the model_from_function network.""" input_tensor = torch.Tensor(4, 10) @@ -38,13 +38,13 @@ def test_error_model_from_function( ) -> None: """Tests the model_from_function network.""" with pytest.raises(TypeError): - networks.ModelFromFunction(path=path, arguments=arguments) + wrappers.ModelFromFunction(path=path, arguments=arguments) @pytest.fixture(scope="function") def model_from_function( path: Callable[..., nn.Module], arguments: Dict[str, Any] | None, -) -> networks.ModelFromFunction: +) -> wrappers.ModelFromFunction: """ModelFromFunction fixture.""" - return networks.ModelFromFunction(path=path, arguments=arguments) + return wrappers.ModelFromFunction(path=path, arguments=arguments) diff --git a/tests/eva/models/wrappers/test_huggingface.py b/tests/eva/models/networks/wrappers/test_huggingface.py similarity index 96% rename from tests/eva/models/wrappers/test_huggingface.py rename to tests/eva/models/networks/wrappers/test_huggingface.py index 7c71e6af..be5d5f37 100644 --- a/tests/eva/models/wrappers/test_huggingface.py +++ b/tests/eva/models/networks/wrappers/test_huggingface.py @@ -6,7 +6,7 @@ import torch from transformers import modeling_outputs -from eva.models import wrappers +from eva.models.networks import wrappers from eva.vision.data.transforms import ExtractCLSFeatures diff --git a/tests/eva/models/wrappers/test_onnx.py b/tests/eva/models/networks/wrappers/test_onnx.py similarity index 92% rename from tests/eva/models/wrappers/test_onnx.py rename to tests/eva/models/networks/wrappers/test_onnx.py index 3a0f8a76..656daad0 100644 --- a/tests/eva/models/wrappers/test_onnx.py +++ b/tests/eva/models/networks/wrappers/test_onnx.py @@ -8,7 +8,7 @@ import torch from pytorch_lightning.demos import boring_classes -from eva.models.wrappers import ONNXModel +from eva.models.networks import wrappers @pytest.mark.parametrize( @@ -22,7 +22,7 @@ def test_onnx_model( model_path: str, input_shape: Tuple[int, ...], expected_output_shape: Tuple[int, ...] ) -> None: """Tests the forward pass using the ONNXModel wrapper.""" - model = ONNXModel(path=model_path) + model = wrappers.ONNXModel(path=model_path) model.eval() input_tensor = torch.rand(1, 32)