Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move wrappers module into networks #215

Merged
merged 2 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/offline/bach.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trainer:
0: train
1: val
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/offline/crc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trainer:
0: train
1: val
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/offline/mhist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trainer:
0: train
1: test
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/offline/patch_camelyon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ trainer:
1: val
2: test
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/online/bach.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ model:
class_path: eva.HeadModule
init_args:
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/online/crc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ model:
class_path: eva.HeadModule
init_args:
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/online/mhist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ model:
class_path: eva.HeadModule
init_args:
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/online/patch_camelyon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ model:
class_path: eva.HeadModule
init_args:
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/owkin/phikon/offline/bach.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trainer:
0: train
1: val
backbone:
class_path: eva.models.wrappers.HuggingFaceModel
class_path: eva.models.networks.wrappers.HuggingFaceModel
init_args:
model_name_or_path: owkin/phikon
tensor_transforms:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/owkin/phikon/offline/crc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trainer:
0: train
1: val
backbone:
class_path: eva.models.wrappers.HuggingFaceModel
class_path: eva.models.networks.wrappers.HuggingFaceModel
init_args:
model_name_or_path: owkin/phikon
tensor_transforms:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/owkin/phikon/offline/mhist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trainer:
0: train
1: test
backbone:
class_path: eva.models.wrappers.HuggingFaceModel
class_path: eva.models.networks.wrappers.HuggingFaceModel
init_args:
model_name_or_path: owkin/phikon
tensor_transforms:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/owkin/phikon/offline/patch_camelyon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ trainer:
1: val
2: test
backbone:
class_path: eva.models.wrappers.HuggingFaceModel
class_path: eva.models.networks.wrappers.HuggingFaceModel
init_args:
model_name_or_path: owkin/phikon
tensor_transforms:
Expand Down
4 changes: 2 additions & 2 deletions configs/vision/tests/offline/patch_camelyon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ trainer:
1: val
2: test
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
repo_or_dir: facebookresearch/dino:main
model: dino_vits16
pretrained: true
pretrained: false
checkpoint_path: &CHECKPOINT_PATH ${oc.env:CHECKPOINT_PATH, null}
- class_path: pytorch_lightning.callbacks.LearningRateMonitor
init_args:
Expand Down
3 changes: 1 addition & 2 deletions src/eva/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Models API."""

from eva.models.modules import HeadModule, InferenceModule
from eva.models.networks import ModelFromFunction

__all__ = ["HeadModule", "ModelFromFunction", "InferenceModule"]
__all__ = ["HeadModule", "InferenceModule"]
4 changes: 2 additions & 2 deletions src/eva/models/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Networks API."""

from eva.models.networks import wrappers
from eva.models.networks.mlp import MLP
from eva.models.wrappers.from_function import ModelFromFunction

__all__ = ["ModelFromFunction", "MLP"]
__all__ = ["wrappers", "MLP"]
7 changes: 7 additions & 0 deletions src/eva/models/networks/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Model Wrappers API."""

from eva.models.networks.wrappers.from_function import ModelFromFunction
from eva.models.networks.wrappers.huggingface import HuggingFaceModel
from eva.models.networks.wrappers.onnx import ONNXModel

__all__ = ["ModelFromFunction", "HuggingFaceModel", "ONNXModel"]
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def __init__(self, tensor_transforms: Callable | None = None) -> None:
"""Initializes the model.

Args:
tensor_transforms: The transforms to apply to the output tensor produced by the model.
tensor_transforms: The transforms to apply to the output tensor
produced by the model.
"""
super().__init__()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from torch import nn
from typing_extensions import override

from eva.models import wrappers
from eva.models.networks import _utils
from eva.models.networks.wrappers import base


class ModelFromFunction(wrappers.BaseModel):
class ModelFromFunction(base.BaseModel):
"""Wrapper class for models which are initialized from functions.

This is helpful for initializing models in a `.yaml` configuration file.
Expand All @@ -29,10 +29,12 @@ def __init__(
Args:
path: The path to the callable object (class or function).
arguments: The extra callable function / class arguments.
checkpoint_path: The path to the checkpoint to load the model weights from. This is
currently only supported for torch model checkpoints. For other formats, the
checkpoint loading should be handled within the provided callable object in <path>.
tensor_transforms: The transforms to apply to the output tensor produced by the model.
checkpoint_path: The path to the checkpoint to load the model
weights from. This is currently only supported for torch
model checkpoints. For other formats, the checkpoint loading
should be handled within the provided callable object in <path>.
tensor_transforms: The transforms to apply to the output tensor
produced by the model.
"""
super().__init__()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@
import transformers
from typing_extensions import override

from eva.models import wrappers
from eva.models.networks.wrappers import base


class HuggingFaceModel(wrappers.BaseModel):
class HuggingFaceModel(base.BaseModel):
"""Wrapper class for loading HuggingFace `transformers` models."""

def __init__(self, model_name_or_path: str, tensor_transforms: Callable | None = None) -> None:
"""Initializes the model.

Args:
model_name_or_path: The model name or path to load the model from. This can be a local
path or a model name from the HuggingFace model hub.
tensor_transforms: The transforms to apply to the output tensor produced by the model.
model_name_or_path: The model name or path to load the model from.
This can be a local path or a model name from the `HuggingFace`
model hub.
tensor_transforms: The transforms to apply to the output tensor
produced by the model.
"""
super().__init__(tensor_transforms=tensor_transforms)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import torch
from typing_extensions import override

from eva.models import wrappers
from eva.models.networks.wrappers import base


class ONNXModel(wrappers.BaseModel):
class ONNXModel(base.BaseModel):
"""Wrapper class for loading ONNX models."""

def __init__(
Expand Down
8 changes: 0 additions & 8 deletions src/eva/models/wrappers/__init__.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from torch import nn

from eva.models import networks
from eva.models.networks import wrappers


@pytest.mark.parametrize(
Expand All @@ -17,7 +17,7 @@
],
)
def test_model_from_function(
model_from_function: networks.ModelFromFunction,
model_from_function: wrappers.ModelFromFunction,
) -> None:
"""Tests the model_from_function network."""
input_tensor = torch.Tensor(4, 10)
Expand All @@ -38,13 +38,13 @@ def test_error_model_from_function(
) -> None:
"""Tests the model_from_function network."""
with pytest.raises(TypeError):
networks.ModelFromFunction(path=path, arguments=arguments)
wrappers.ModelFromFunction(path=path, arguments=arguments)


@pytest.fixture(scope="function")
def model_from_function(
path: Callable[..., nn.Module],
arguments: Dict[str, Any] | None,
) -> networks.ModelFromFunction:
) -> wrappers.ModelFromFunction:
"""ModelFromFunction fixture."""
return networks.ModelFromFunction(path=path, arguments=arguments)
return wrappers.ModelFromFunction(path=path, arguments=arguments)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from transformers import modeling_outputs

from eva.models import wrappers
from eva.models.networks import wrappers
from eva.vision.data.transforms import ExtractCLSFeatures


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
from pytorch_lightning.demos import boring_classes

from eva.models.wrappers import ONNXModel
from eva.models.networks import wrappers


@pytest.mark.parametrize(
Expand All @@ -22,7 +22,7 @@ def test_onnx_model(
model_path: str, input_shape: Tuple[int, ...], expected_output_shape: Tuple[int, ...]
) -> None:
"""Tests the forward pass using the ONNXModel wrapper."""
model = ONNXModel(path=model_path)
model = wrappers.ONNXModel(path=model_path)
model.eval()

input_tensor = torch.rand(1, 32)
Expand Down