Skip to content

Commit

Permalink
move wrappers to networks
Browse files Browse the repository at this point in the history
  • Loading branch information
ioangatop committed Mar 10, 2024
1 parent 6fcd41d commit 225b277
Show file tree
Hide file tree
Showing 25 changed files with 38 additions and 40 deletions.
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/offline/bach.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trainer:
0: train
1: val
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/offline/crc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trainer:
0: train
1: val
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/offline/mhist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trainer:
0: train
1: test
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/offline/patch_camelyon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ trainer:
1: val
2: test
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/online/bach.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ model:
class_path: eva.HeadModule
init_args:
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/online/crc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ model:
class_path: eva.HeadModule
init_args:
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/online/mhist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ model:
class_path: eva.HeadModule
init_args:
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/dino_vit/online/patch_camelyon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ model:
class_path: eva.HeadModule
init_args:
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/owkin/phikon/offline/bach.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trainer:
0: train
1: val
backbone:
class_path: eva.models.wrappers.HuggingFaceModel
class_path: eva.models.networks.wrappers.HuggingFaceModel
init_args:
model_name_or_path: owkin/phikon
tensor_transforms:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/owkin/phikon/offline/crc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trainer:
0: train
1: val
backbone:
class_path: eva.models.wrappers.HuggingFaceModel
class_path: eva.models.networks.wrappers.HuggingFaceModel
init_args:
model_name_or_path: owkin/phikon
tensor_transforms:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/owkin/phikon/offline/mhist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trainer:
0: train
1: test
backbone:
class_path: eva.models.wrappers.HuggingFaceModel
class_path: eva.models.networks.wrappers.HuggingFaceModel
init_args:
model_name_or_path: owkin/phikon
tensor_transforms:
Expand Down
2 changes: 1 addition & 1 deletion configs/vision/owkin/phikon/offline/patch_camelyon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ trainer:
1: val
2: test
backbone:
class_path: eva.models.wrappers.HuggingFaceModel
class_path: eva.models.networks.wrappers.HuggingFaceModel
init_args:
model_name_or_path: owkin/phikon
tensor_transforms:
Expand Down
4 changes: 2 additions & 2 deletions configs/vision/tests/offline/patch_camelyon.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ trainer:
1: val
2: test
backbone:
class_path: eva.models.ModelFromFunction
class_path: eva.models.networks.wrappers.ModelFromFunction
init_args:
path: torch.hub.load
arguments:
repo_or_dir: facebookresearch/dino:main
model: dino_vits16
pretrained: true
pretrained: false
checkpoint_path: &CHECKPOINT_PATH ${oc.env:CHECKPOINT_PATH, null}
- class_path: pytorch_lightning.callbacks.LearningRateMonitor
init_args:
Expand Down
3 changes: 1 addition & 2 deletions src/eva/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Models API."""

from eva.models.modules import HeadModule, InferenceModule
from eva.models.networks import ModelFromFunction

__all__ = ["HeadModule", "ModelFromFunction", "InferenceModule"]
__all__ = ["HeadModule", "InferenceModule"]
4 changes: 2 additions & 2 deletions src/eva/models/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Networks API."""

from eva.models.networks import wrappers
from eva.models.networks.mlp import MLP
from eva.models.wrappers.from_function import ModelFromFunction

__all__ = ["ModelFromFunction", "MLP"]
__all__ = ["wrappers", "MLP"]
7 changes: 7 additions & 0 deletions src/eva/models/networks/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Model Wrappers API."""

from eva.models.networks.wrappers.from_function import ModelFromFunction
from eva.models.networks.wrappers.huggingface import HuggingFaceModel
from eva.models.networks.wrappers.onnx import ONNXModel

__all__ = ["ModelFromFunction", "HuggingFaceModel", "ONNXModel"]
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from torch import nn
from typing_extensions import override

from eva.models import wrappers
from eva.models.networks import _utils
from eva.models.networks.wrappers import base


class ModelFromFunction(wrappers.BaseModel):
class ModelFromFunction(base.BaseModel):
"""Wrapper class for models which are initialized from functions.
This is helpful for initializing models in a `.yaml` configuration file.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import transformers
from typing_extensions import override

from eva.models import wrappers
from eva.models.networks.wrappers import base


class HuggingFaceModel(wrappers.BaseModel):
class HuggingFaceModel(base.BaseModel):
"""Wrapper class for loading HuggingFace `transformers` models."""

def __init__(self, model_name_or_path: str, tensor_transforms: Callable | None = None) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import torch
from typing_extensions import override

from eva.models import wrappers
from eva.models.networks.wrappers import base


class ONNXModel(wrappers.BaseModel):
class ONNXModel(base.BaseModel):
"""Wrapper class for loading ONNX models."""

def __init__(
Expand Down
8 changes: 0 additions & 8 deletions src/eva/models/wrappers/__init__.py

This file was deleted.

File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from torch import nn

from eva.models import networks
from eva.models.networks import wrappers


@pytest.mark.parametrize(
Expand All @@ -17,7 +17,7 @@
],
)
def test_model_from_function(
model_from_function: networks.ModelFromFunction,
model_from_function: wrappers.ModelFromFunction,
) -> None:
"""Tests the model_from_function network."""
input_tensor = torch.Tensor(4, 10)
Expand All @@ -38,13 +38,13 @@ def test_error_model_from_function(
) -> None:
"""Tests the model_from_function network."""
with pytest.raises(TypeError):
networks.ModelFromFunction(path=path, arguments=arguments)
wrappers.ModelFromFunction(path=path, arguments=arguments)


@pytest.fixture(scope="function")
def model_from_function(
path: Callable[..., nn.Module],
arguments: Dict[str, Any] | None,
) -> networks.ModelFromFunction:
) -> wrappers.ModelFromFunction:
"""ModelFromFunction fixture."""
return networks.ModelFromFunction(path=path, arguments=arguments)
return wrappers.ModelFromFunction(path=path, arguments=arguments)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from transformers import modeling_outputs

from eva.models import wrappers
from eva.models.networks import wrappers
from eva.vision.data.transforms import ExtractCLSFeatures


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
from pytorch_lightning.demos import boring_classes

from eva.models.wrappers import ONNXModel
from eva.models.networks import wrappers


@pytest.mark.parametrize(
Expand All @@ -22,7 +22,7 @@ def test_onnx_model(
model_path: str, input_shape: Tuple[int, ...], expected_output_shape: Tuple[int, ...]
) -> None:
"""Tests the forward pass using the ONNXModel wrapper."""
model = ONNXModel(path=model_path)
model = wrappers.ONNXModel(path=model_path)
model.eval()

input_tensor = torch.rand(1, 32)
Expand Down

0 comments on commit 225b277

Please sign in to comment.