diff --git a/configs/vision/pathology/online/segmentation/consep.yaml b/configs/vision/pathology/online/segmentation/consep.yaml index 06f181df..4935515b 100644 --- a/configs/vision/pathology/online/segmentation/consep.yaml +++ b/configs/vision/pathology/online/segmentation/consep.yaml @@ -3,8 +3,8 @@ trainer: class_path: eva.Trainer init_args: n_runs: &N_RUNS ${oc.env:N_RUNS, 1} - default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/consep} - max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 513} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/consep} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} log_every_n_steps: 6 callbacks: - class_path: eva.callbacks.ConfigurationLogger @@ -26,7 +26,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: 34 monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: @@ -45,10 +45,10 @@ model: out_indices: ${oc.env:OUT_INDICES, 1} model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} decoder: - class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS + class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderWithImage init_args: in_features: ${oc.env:IN_FEATURES, 384} - num_classes: &NUM_CLASSES 5 + num_classes: &NUM_CLASSES 5 criterion: class_path: eva.vision.losses.DiceLoss init_args: @@ -58,7 +58,7 @@ model: optimizer: class_path: torch.optim.AdamW init_args: - lr: ${oc.env:LR_VALUE, 0.002} + lr: ${oc.env:LR_VALUE, 0.0001} lr_scheduler: class_path: torch.optim.lr_scheduler.PolynomialLR init_args: diff --git a/configs/vision/pathology/online/segmentation/monusac.yaml b/configs/vision/pathology/online/segmentation/monusac.yaml index b7f7ec21..acf8d9e1 100644 --- a/configs/vision/pathology/online/segmentation/monusac.yaml +++ b/configs/vision/pathology/online/segmentation/monusac.yaml @@ -3,9 +3,9 @@ trainer: class_path: eva.Trainer init_args: n_runs: &N_RUNS ${oc.env:N_RUNS, 1} - default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/monusac} - max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 550} - log_every_n_steps: 4 + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/monusac} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} + log_every_n_steps: 6 callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -26,7 +26,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: 50 monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: @@ -45,10 +45,10 @@ model: out_indices: ${oc.env:OUT_INDICES, 1} model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} decoder: - class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS + class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderWithImage init_args: in_features: ${oc.env:IN_FEATURES, 384} - num_classes: &NUM_CLASSES 5 + num_classes: &NUM_CLASSES 5 criterion: class_path: eva.vision.losses.DiceLoss init_args: @@ -59,7 +59,7 @@ model: optimizer: class_path: torch.optim.AdamW init_args: - lr: ${oc.env:LR_VALUE, 0.002} + lr: ${oc.env:LR_VALUE, 0.0001} lr_scheduler: class_path: torch.optim.lr_scheduler.PolynomialLR init_args: diff --git a/docs/images/leaderboard.svg b/docs/images/leaderboard.svg index 2031c979..f447a356 100644 --- 
a/docs/images/leaderboard.svg
+++ b/docs/images/leaderboard.svg
@@ -6,7 +6,7 @@
- 2024-10-18T15:48:36.884888
+ 2024-11-21T11:04:41.708790
  image/svg+xml
@@ -40,532 +40,628 @@
[Figure diff omitted: the remainder of this file's hunks regenerate the leaderboard plot (SVG path, clip-path and label markup for the updated results) and carry no reviewable text content.]
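[Editor's note on the configuration hunks above: each `${oc.env:VAR, default}` value is an OmegaConf environment-variable resolver, so settings such as MODEL_NAME, MAX_STEPS and LR_VALUE can be overridden at run time without editing the YAML. A minimal standalone sketch of that resolution behaviour follows; it is not part of this diff, and the keys "lr" and "max_steps" are illustrative only.

    import os

    from omegaconf import OmegaConf

    # Mirrors the `${oc.env:VAR, default}` pattern used in the segmentation configs above.
    config = OmegaConf.create(
        {
            "lr": "${oc.env:LR_VALUE, 0.0001}",
            "max_steps": "${oc.env:MAX_STEPS, 2000}",
        }
    )

    print(config.lr)  # -> 0.0001, the default, because LR_VALUE is not set

    os.environ["LR_VALUE"] = "0.002"
    print(config.lr)  # -> 0.002; oc.env re-reads the environment on each access

In practice these variables are typically exported in the shell before launching the corresponding eva run.]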
diff --git a/docs/images/starplot.png b/docs/images/starplot.png
index 5f600e9c..e2a5bd73 100644
Binary files a/docs/images/starplot.png and b/docs/images/starplot.png differ
diff --git a/pdm.lock b/pdm.lock
index eab0992b..21faa341 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@ groups = ["default", "all", "dev", "docs", "lint", "test", "typecheck", "vision"]
 strategy = ["inherit_metadata"]
 lock_version = "4.5.0"
-content_hash = "sha256:cc23a9652ade7a78ab86d2663e963e4d41f1c30cdb485f6459f4fefc9fcea7e0"
+content_hash = "sha256:b8df35bf60e5573e36c31c4ad4f324d7693f16b31cadcd27e48b352ae6c0235b"

 [[metadata.targets]]
 requires_python = ">=3.10"
@@ -1405,18 +1405,18 @@ files = [

 [[package]]
 name = "nibabel"
-version = "5.2.1"
-requires_python = ">=3.8"
+version = "4.0.2"
+requires_python = ">=3.7"
 summary = "Access a multitude of neuroimaging data formats"
 groups = ["all", "vision"]
 dependencies = [
-    "importlib-resources>=1.3; python_version < \"3.9\"",
-    "numpy>=1.20",
-    "packaging>=17",
+    "numpy>=1.17",
+    "packaging>=17.0",
+    "setuptools",
 ]
 files = [
-    {file = "nibabel-5.2.1-py3-none-any.whl", hash = "sha256:2cbbc22985f7f9d39d050df47249771dfb8d48447f5e7a993177e4cabfe047f0"},
-    {file = "nibabel-5.2.1.tar.gz", hash = "sha256:b6c80b2e728e4bc2b65f1142d9b8d2287a9102a8bf8477e115ef0d8334559975"},
+    {file = "nibabel-4.0.2-py3-none-any.whl", hash = "sha256:c4fe76348aa865f8300beaaf2a69d31624964c861853ef80c06e33d5f244413c"},
+    {file = "nibabel-4.0.2.tar.gz", hash = "sha256:45c49b5349351b45f6c045a91aa02b4f0d367686ff3284632ef95ac65b930786"},
 ]
@@ -1994,7 +1994,7 @@ name = "pyreadline3"
 version = "3.4.1"
 summary = "A python implementation of GNU readline."
groups = ["default"] -marker = "sys_platform == \"win32\"" +marker = "sys_platform == \"win32\" and python_version >= \"3.8\"" files = [ {file = "pyreadline3-3.4.1-py3-none-any.whl", hash = "sha256:b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb"}, {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"}, @@ -2465,7 +2465,7 @@ name = "setuptools" version = "75.1.0" requires_python = ">=3.8" summary = "Easily download, build, install, upgrade, and uninstall Python packages" -groups = ["default", "dev", "docs"] +groups = ["default", "all", "dev", "docs", "vision"] files = [ {file = "setuptools-75.1.0-py3-none-any.whl", hash = "sha256:35ab7fd3bcd95e6b7fd704e4a1539513edad446c097797f2985e0e4b960772f2"}, {file = "setuptools-75.1.0.tar.gz", hash = "sha256:d59a21b17a275fb872a9c3dae73963160ae079f1049ed956880cd7c09b120538"}, diff --git a/pyproject.toml b/pyproject.toml index a36af45a..7c52ba40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "pdm.backend" [project] name = "kaiko-eva" -version = "0.1.2" +version = "0.1.5" description = "Evaluation Framework for oncology foundation models." keywords = [ "machine-learning", @@ -34,14 +34,14 @@ maintainers = [ ] requires-python = ">=3.10" dependencies = [ - "torch==2.3.0", - "lightning>=2.2.2", - "jsonargparse[omegaconf]==4.31.0", + "torch>=2.3.0", + "lightning>=2.2.0", + "jsonargparse[omegaconf]>=4.30.0", "tensorboard>=2.16.2", "loguru>=0.7.2", - "pandas>=2.2.0", + "pandas>=2.0.0", "transformers>=4.38.2", - "onnxruntime>=1.17.1", + "onnxruntime>=1.15.1", "onnx>=1.16.0", "toolz>=0.12.1", "rich>=13.7.1", @@ -59,7 +59,7 @@ file = "LICENSE" [project.optional-dependencies] vision = [ "h5py>=3.10.0", - "nibabel>=5.2.0", + "nibabel>=4.0.1", "opencv-python-headless>=4.9.0.80", "timm>=1.0.9", "torchvision>=0.17.0", @@ -72,7 +72,7 @@ vision = [ ] all = [ "h5py>=3.10.0", - "nibabel>=5.2.0", + "nibabel>=4.0.1", "opencv-python-headless>=4.9.0.80", "timm>=1.0.9", "torchvision>=0.17.0", diff --git a/src/eva/core/models/wrappers/__init__.py b/src/eva/core/models/wrappers/__init__.py index 95ab6101..979577bd 100644 --- a/src/eva/core/models/wrappers/__init__.py +++ b/src/eva/core/models/wrappers/__init__.py @@ -2,12 +2,14 @@ from eva.core.models.wrappers.base import BaseModel from eva.core.models.wrappers.from_function import ModelFromFunction +from eva.core.models.wrappers.from_torchhub import TorchHubModel from eva.core.models.wrappers.huggingface import HuggingFaceModel from eva.core.models.wrappers.onnx import ONNXModel __all__ = [ "BaseModel", - "ModelFromFunction", "HuggingFaceModel", + "ModelFromFunction", "ONNXModel", + "TorchHubModel", ] diff --git a/src/eva/core/models/wrappers/from_torchhub.py b/src/eva/core/models/wrappers/from_torchhub.py new file mode 100644 index 00000000..cb424d01 --- /dev/null +++ b/src/eva/core/models/wrappers/from_torchhub.py @@ -0,0 +1,87 @@ +"""Model wrapper for torch.hub models.""" + +from typing import Any, Callable, Dict, Tuple + +import torch +import torch.nn as nn +from typing_extensions import override + +from eva.core.models import wrappers +from eva.core.models.wrappers import _utils + + +class TorchHubModel(wrappers.BaseModel): + """Model wrapper for `torch.hub` models.""" + + def __init__( + self, + model_name: str, + repo_or_dir: str, + pretrained: bool = True, + checkpoint_path: str = "", + out_indices: int | Tuple[int, ...] 
| None = None, + norm: bool = False, + trust_repo: bool = True, + model_kwargs: Dict[str, Any] | None = None, + tensor_transforms: Callable | None = None, + ) -> None: + """Initializes the encoder. + + Args: + model_name: Name of model to instantiate. + repo_or_dir: The torch.hub repository or local directory to load the model from. + pretrained: If set to `True`, load pretrained ImageNet-1k weights. + checkpoint_path: Path of checkpoint to load. + out_indices: Returns last n blocks if `int`, all if `None`, select + matching indices if sequence. + norm: Wether to apply norm layer to all intermediate features. Only + used when `out_indices` is not `None`. + trust_repo: If set to `False`, a prompt will ask the user whether the + repo should be trusted. + model_kwargs: Extra model arguments. + tensor_transforms: The transforms to apply to the output tensor + produced by the model. + """ + super().__init__(tensor_transforms=tensor_transforms) + + self._model_name = model_name + self._repo_or_dir = repo_or_dir + self._pretrained = pretrained + self._checkpoint_path = checkpoint_path + self._out_indices = out_indices + self._norm = norm + self._trust_repo = trust_repo + self._model_kwargs = model_kwargs or {} + + self.load_model() + + @override + def load_model(self) -> None: + """Builds and loads the torch.hub model.""" + self._model: nn.Module = torch.hub.load( + repo_or_dir=self._repo_or_dir, + model=self._model_name, + trust_repo=self._trust_repo, + pretrained=self._pretrained, + **self._model_kwargs, + ) # type: ignore + + if self._checkpoint_path: + _utils.load_model_weights(self._model, self._checkpoint_path) + + TorchHubModel.__name__ = self._model_name + + @override + def model_forward(self, tensor: torch.Tensor) -> torch.Tensor: + if self._out_indices is not None: + if not hasattr(self._model, "get_intermediate_layers"): + raise ValueError( + "Only models with `get_intermediate_layers` are supported " + "when using `out_indices`." 
+ ) + + return self._model.get_intermediate_layers( + tensor, self._out_indices, reshape=True, return_class_token=False, norm=self._norm + ) + + return self._model(tensor) diff --git a/src/eva/vision/models/networks/backbones/__init__.py b/src/eva/vision/models/networks/backbones/__init__.py index 0fdf2963..1ef7bc85 100644 --- a/src/eva/vision/models/networks/backbones/__init__.py +++ b/src/eva/vision/models/networks/backbones/__init__.py @@ -1,6 +1,6 @@ """Vision Model Backbones API.""" -from eva.vision.models.networks.backbones import pathology, timm, universal +from eva.vision.models.networks.backbones import pathology, timm, torchhub, universal from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model -__all__ = ["pathology", "timm", "universal", "BackboneModelRegistry", "register_model"] +__all__ = ["pathology", "timm", "torchhub", "universal", "BackboneModelRegistry", "register_model"] diff --git a/src/eva/vision/models/networks/backbones/torchhub/__init__.py b/src/eva/vision/models/networks/backbones/torchhub/__init__.py new file mode 100644 index 00000000..6acd9797 --- /dev/null +++ b/src/eva/vision/models/networks/backbones/torchhub/__init__.py @@ -0,0 +1,5 @@ +"""torch.hub backbones API.""" + +from eva.vision.models.networks.backbones.torchhub.backbones import torch_hub_model + +__all__ = ["torch_hub_model"] diff --git a/src/eva/vision/models/networks/backbones/torchhub/backbones.py b/src/eva/vision/models/networks/backbones/torchhub/backbones.py new file mode 100644 index 00000000..d1503a80 --- /dev/null +++ b/src/eva/vision/models/networks/backbones/torchhub/backbones.py @@ -0,0 +1,61 @@ +"""torch.hub backbones.""" + +import functools +from typing import Tuple + +import torch +from loguru import logger +from torch import nn + +from eva.core.models import wrappers +from eva.vision.models.networks.backbones.registry import BackboneModelRegistry + +HUB_REPOS = ["facebookresearch/dinov2:main", "kaiko-ai/towards_large_pathology_fms"] +"""List of torch.hub repositories for which to add the models to the registry.""" + + +def torch_hub_model( + model_name: str, + repo_or_dir: str, + checkpoint_path: str | None = None, + pretrained: bool = False, + out_indices: int | Tuple[int, ...] | None = None, + **kwargs, +) -> nn.Module: + """Initializes any ViT model from torch.hub with weights from a specified checkpoint. + + Args: + model_name: The name of the model to load. + repo_or_dir: The torch.hub repository or local directory to load the model from. + checkpoint_path: The path to the checkpoint file. + pretrained: If set to `True`, load pretrained model weights if available. + out_indices: Whether and which multi-level patch embeddings to return. + **kwargs: Additional arguments to pass to the model + + Returns: + The VIT model instance. 
+ """ + logger.info( + f"Loading torch.hub model {model_name} from {repo_or_dir}" + + (f"using checkpoint {checkpoint_path}" if checkpoint_path else "") + ) + + return wrappers.TorchHubModel( + model_name=model_name, + repo_or_dir=repo_or_dir, + pretrained=pretrained, + checkpoint_path=checkpoint_path or "", + out_indices=out_indices, + model_kwargs=kwargs, + ) + + +BackboneModelRegistry._registry.update( + { + f"torchhub/{repo}:{model_name}": functools.partial( + torch_hub_model, model_name=model_name, repo_or_dir=repo + ) + for repo in HUB_REPOS + for model_name in torch.hub.list(repo, verbose=False) + } +) diff --git a/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py b/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py index c43b351c..ce242713 100644 --- a/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py +++ b/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py @@ -52,7 +52,7 @@ def _forward_features(self, features: torch.Tensor | List[torch.Tensor]) -> torc """ if isinstance(features, torch.Tensor): features = [features] - if not isinstance(features, list) or features[0].ndim != 4: + if not isinstance(features, (list, tuple)) or features[0].ndim != 4: raise ValueError( "Input features should be a list of four (4) dimensional inputs of " "shape (batch_size, hidden_size, n_patches_height, n_patches_width)." diff --git a/src/eva/vision/models/wrappers/__init__.py b/src/eva/vision/models/wrappers/__init__.py index 14d63b68..d2f84de4 100644 --- a/src/eva/vision/models/wrappers/__init__.py +++ b/src/eva/vision/models/wrappers/__init__.py @@ -3,4 +3,4 @@ from eva.vision.models.wrappers.from_registry import ModelFromRegistry from eva.vision.models.wrappers.from_timm import TimmModel -__all__ = ["TimmModel", "ModelFromRegistry"] +__all__ = ["ModelFromRegistry", "TimmModel"] diff --git a/src/eva/vision/utils/io/nifti.py b/src/eva/vision/utils/io/nifti.py index e90f919d..49ca8fda 100644 --- a/src/eva/vision/utils/io/nifti.py +++ b/src/eva/vision/utils/io/nifti.py @@ -54,7 +54,6 @@ def save_array_as_nifti( dtype: The data type to save the image. 
""" nifti_image = nib.Nifti1Image(array, affine=np.eye(4), dtype=dtype) # type: ignore - nifti_image.header.get_xyzt_units() nifti_image.to_filename(filename) diff --git a/tests/eva/core/models/wrappers/test_from_torchub.py b/tests/eva/core/models/wrappers/test_from_torchub.py new file mode 100644 index 00000000..bf275234 --- /dev/null +++ b/tests/eva/core/models/wrappers/test_from_torchub.py @@ -0,0 +1,76 @@ +"""TorchHubModel tests.""" + +from typing import Any, Dict, Tuple + +import pytest +import torch + +from eva.core.models import wrappers + + +@pytest.mark.parametrize( + "model_name, repo_or_dir, out_indices, model_kwargs, " + "input_tensor, expected_len, expected_shape", + [ + ( + "dinov2_vits14", + "facebookresearch/dinov2:main", + None, + None, + torch.Tensor(2, 3, 224, 224), + None, + torch.Size([2, 384]), + ), + ( + "dinov2_vits14", + "facebookresearch/dinov2:main", + 1, + None, + torch.Tensor(2, 3, 224, 224), + 1, + torch.Size([2, 384, 16, 16]), + ), + ( + "dinov2_vits14", + "facebookresearch/dinov2:main", + 3, + None, + torch.Tensor(2, 3, 224, 224), + 3, + torch.Size([2, 384, 16, 16]), + ), + ], +) +def test_torchhub_model( + torchhub_model: wrappers.TorchHubModel, + input_tensor: torch.Tensor, + expected_len: int | None, + expected_shape: torch.Size, +) -> None: + """Tests the torch.hub model wrapper.""" + outputs = torchhub_model(input_tensor) + if torchhub_model._out_indices is not None: + assert isinstance(outputs, list) or isinstance(outputs, tuple) + assert len(outputs) == expected_len + assert isinstance(outputs[0], torch.Tensor) + assert outputs[0].shape == expected_shape + else: + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == expected_shape + + +@pytest.fixture(scope="function") +def torchhub_model( + model_name: str, + repo_or_dir: str, + out_indices: int | Tuple[int, ...] 
| None, + model_kwargs: Dict[str, Any] | None, +) -> wrappers.TorchHubModel: + """TorchHubModel fixture.""" + return wrappers.TorchHubModel( + model_name=model_name, + repo_or_dir=repo_or_dir, + out_indices=out_indices, + model_kwargs=model_kwargs, + pretrained=False, + ) diff --git a/tools/data/leaderboard.csv b/tools/data/leaderboard.csv index f30a3a33..678d4bca 100644 --- a/tools/data/leaderboard.csv +++ b/tools/data/leaderboard.csv @@ -1,12 +1,14 @@ bach,crc,mhist,patch_camelyon,camelyon16_small,panda_small,consep,monusac,model -0.783,0.94,0.773,0.901,0.767,0.625,0.63,0.537,dino_vits16_lunit -0.722,0.936,0.799,0.922,0.797,0.64,0.68,0.54,owkin_phikon -0.797,0.947,0.844,0.936,0.834,0.656,0.662,0.554,dino_vitl16_uni -0.758,0.958,0.839,0.942,0.82,0.645,0.69,0.588,bioptimus_h_optimus_0 -0.761,0.952,0.829,0.945,0.814,0.664,0.661,0.558,prov_gigapath -0.816,0.931,0.826,0.951,0.832,0.633,0.69,0.586,histai_hibou_l -0.802,0.938,0.829,0.904,0.789,0.618,0.611,0.549,dino_vits16_kaiko -0.829,0.952,0.814,0.885,0.814,0.654,0.688,0.599,dino_vits8_kaiko -0.835,0.958,0.835,0.907,0.816,0.621,0.636,0.551,dino_vitb16_kaiko -0.858,0.957,0.823,0.918,0.818,0.638,0.703,0.641,dino_vitb8_kaiko -0.864,0.936,0.828,0.908,0.812,0.65,0.679,0.59,dino_vitl14_kaiko +0.88,0.966,0.858,0.936,0.864,0.642,0.723,0.713,paige_virchow2 +0.758,0.958,0.839,0.942,0.82,0.645,0.726,0.725,bioptimus_h_optimus_0 +0.797,0.947,0.844,0.936,0.834,0.656,0.711,0.708,dino_vitl16_uni +0.761,0.952,0.829,0.945,0.814,0.664,0.709,0.724,prov_gigapath +0.816,0.931,0.826,0.951,0.832,0.633,0.725,0.728,histai_hibou_l +0.858,0.957,0.823,0.918,0.818,0.638,0.723,0.736,dino_vitb8_kaiko +0.864,0.936,0.828,0.908,0.812,0.65,0.716,0.727,dino_vitl14_kaiko +0.829,0.952,0.814,0.885,0.814,0.654,0.716,0.712,dino_vits8_kaiko +0.835,0.958,0.835,0.907,0.816,0.621,0.69,0.69,dino_vitb16_kaiko +0.722,0.936,0.799,0.922,0.797,0.64,0.708,0.709,owkin_phikon +0.802,0.938,0.829,0.904,0.789,0.618,0.683,0.694,dino_vits16_kaiko +0.727,0.939,0.775,0.893,0.808,0.635,0.711,0.689,owkin_phikon_v2 +0.783,0.94,0.773,0.901,0.767,0.625,0.68,0.69,dino_vits16_lunit \ No newline at end of file diff --git a/tools/generate_leaderboard_plots.py b/tools/generate_leaderboard_plots.py index 3468fa76..5e1f02f9 100644 --- a/tools/generate_leaderboard_plots.py +++ b/tools/generate_leaderboard_plots.py @@ -28,8 +28,10 @@ "monusac": "GeneralizedDiceScore", } _fm_name_map = { - "dino_vits16_lunit": "Lunit - ViT-S16 | TCGA", - "owkin_phikon": "Owkin (Phikon) - iBOT ViT-B16 | TCGA", + "paige_virchow2": "Virchow2 - DINOv2 ViT-H14 | 3.1M slides", + "dino_vits16_lunit": "Lunit - DINO ViT-S16 | TCGA", + "owkin_phikon": "Phikon - iBOT ViT-B16 | TCGA", + "owkin_phikon_v2": "Phikon-v2 - DINOv2 ViT-L16 | PANCAN-XL", "dino_vitl16_uni": "UNI - DINOv2 ViT-L16 | Mass-100k", "bioptimus_h_optimus_0": "H-optimus-0 - ViT-G14 | 500k slides", "prov_gigapath": "Prov-GigaPath - DINOv2 ViT-G14 | 181k slides",