diff --git a/configs/vision/pathology/offline/classification/bach.yaml b/configs/vision/pathology/offline/classification/bach.yaml index bf9494ba5..98ac45cc1 100644 --- a/configs/vision/pathology/offline/classification/bach.yaml +++ b/configs/vision/pathology/offline/classification/bach.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/bach} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -23,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 400 + patience: ${oc.env:PATIENCE, 400} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.ClassificationEmbeddingsWriter diff --git a/configs/vision/pathology/offline/classification/camelyon16.yaml b/configs/vision/pathology/offline/classification/camelyon16.yaml index 5a5f9a5c4..97ef1f6e5 100644 --- a/configs/vision/pathology/offline/classification/camelyon16.yaml +++ b/configs/vision/pathology/offline/classification/camelyon16.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/camelyon16} max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 100} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/offline/classification/camelyon16_small.yaml b/configs/vision/pathology/offline/classification/camelyon16_small.yaml index dd4be2af6..d3346dee4 100644 --- a/configs/vision/pathology/offline/classification/camelyon16_small.yaml +++ b/configs/vision/pathology/offline/classification/camelyon16_small.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/camelyon16} max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 100} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/offline/classification/crc.yaml b/configs/vision/pathology/offline/classification/crc.yaml index feca261e2..e54b095b3 100644 --- a/configs/vision/pathology/offline/classification/crc.yaml +++ b/configs/vision/pathology/offline/classification/crc.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/crc} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -23,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 24 + patience: ${oc.env:PATIENCE, 24} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.ClassificationEmbeddingsWriter diff --git a/configs/vision/pathology/offline/classification/mhist.yaml b/configs/vision/pathology/offline/classification/mhist.yaml index f96c1f151..ad7b5c36e 100644 --- a/configs/vision/pathology/offline/classification/mhist.yaml +++ b/configs/vision/pathology/offline/classification/mhist.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/mhist} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -23,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 70 + patience: ${oc.env:PATIENCE, 70} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.ClassificationEmbeddingsWriter diff --git a/configs/vision/pathology/offline/classification/panda.yaml b/configs/vision/pathology/offline/classification/panda.yaml index b88138c58..0ef3c3a4f 100644 --- a/configs/vision/pathology/offline/classification/panda.yaml +++ b/configs/vision/pathology/offline/classification/panda.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/panda} max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 49} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/offline/classification/panda_small.yaml b/configs/vision/pathology/offline/classification/panda_small.yaml index 53735a7cc..e4a4980ef 100644 --- a/configs/vision/pathology/offline/classification/panda_small.yaml +++ b/configs/vision/pathology/offline/classification/panda_small.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/panda} max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 49} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/offline/classification/patch_camelyon.yaml b/configs/vision/pathology/offline/classification/patch_camelyon.yaml index fc8450e79..4dfbd34fd 100644 --- a/configs/vision/pathology/offline/classification/patch_camelyon.yaml +++ b/configs/vision/pathology/offline/classification/patch_camelyon.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/patch_camelyon} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -23,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 9 + patience: ${oc.env:PATIENCE, 9} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.ClassificationEmbeddingsWriter diff --git a/configs/vision/pathology/offline/segmentation/bcss.yaml b/configs/vision/pathology/offline/segmentation/bcss.yaml index b7c0f6165..9ea0f1a65 100644 --- a/configs/vision/pathology/offline/segmentation/bcss.yaml +++ b/configs/vision/pathology/offline/segmentation/bcss.yaml @@ -2,10 +2,11 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/bcss} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -25,7 +26,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 8 + patience: ${oc.env:PATIENCE, 8} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.SegmentationEmbeddingsWriter @@ -65,7 +66,7 @@ model: optimizer: class_path: torch.optim.AdamW init_args: - lr: ${oc.env:LR_VALUE, 0.0001} + lr: ${oc.env:LR_VALUE, 0.002} lr_scheduler: class_path: torch.optim.lr_scheduler.PolynomialLR init_args: diff --git a/configs/vision/pathology/offline/segmentation/consep.yaml b/configs/vision/pathology/offline/segmentation/consep.yaml index 79af29627..6ceb085ce 100644 --- a/configs/vision/pathology/offline/segmentation/consep.yaml +++ b/configs/vision/pathology/offline/segmentation/consep.yaml @@ -2,10 +2,11 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/consep} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -25,7 +26,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 34 + patience: ${oc.env:PATIENCE, 34} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.SegmentationEmbeddingsWriter @@ -65,7 +66,7 @@ model: optimizer: class_path: torch.optim.AdamW init_args: - lr: ${oc.env:LR_VALUE, 0.0001} + lr: ${oc.env:LR_VALUE, 0.002} lr_scheduler: class_path: torch.optim.lr_scheduler.PolynomialLR init_args: diff --git a/configs/vision/pathology/offline/segmentation/monusac.yaml b/configs/vision/pathology/offline/segmentation/monusac.yaml index 587f99846..b89d4eb62 100644 --- a/configs/vision/pathology/offline/segmentation/monusac.yaml +++ b/configs/vision/pathology/offline/segmentation/monusac.yaml @@ -6,6 +6,7 @@ trainer: default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/monusac} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -28,7 +29,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 50 + patience: ${oc.env:PATIENCE, 50} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.SegmentationEmbeddingsWriter @@ -67,7 +68,7 @@ model: optimizer: class_path: torch.optim.AdamW init_args: - lr: ${oc.env:LR_VALUE, 0.0001} + lr: ${oc.env:LR_VALUE, 0.002} lr_scheduler: class_path: torch.optim.lr_scheduler.PolynomialLR init_args: diff --git a/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml b/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml index 250f18d69..dd34a78c9 100644 --- a/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml +++ b/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml @@ -5,6 +5,7 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 1} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/total_segmentator_2d} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 40000} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -25,7 +26,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 5 + patience: ${oc.env:PATIENCE, 5} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.SegmentationEmbeddingsWriter diff --git a/configs/vision/pathology/online/classification/bach.yaml b/configs/vision/pathology/online/classification/bach.yaml index 1719d821e..6e6f9bb88 100644 --- a/configs/vision/pathology/online/classification/bach.yaml +++ b/configs/vision/pathology/online/classification/bach.yaml @@ -2,8 +2,10 @@ trainer: class_path: eva.Trainer init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/online/bach} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -22,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 400 + patience: ${oc.env:PATIENCE, 400} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/pathology/online/classification/crc.yaml b/configs/vision/pathology/online/classification/crc.yaml index 5abe659e0..37bb52c61 100644 --- a/configs/vision/pathology/online/classification/crc.yaml +++ b/configs/vision/pathology/online/classification/crc.yaml @@ -2,8 +2,10 @@ trainer: class_path: eva.Trainer init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/online/crc} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -22,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 24 + patience: ${oc.env:PATIENCE, 24} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/pathology/online/classification/mhist.yaml b/configs/vision/pathology/online/classification/mhist.yaml index 25dcbc509..b2a23f13b 100644 --- a/configs/vision/pathology/online/classification/mhist.yaml +++ b/configs/vision/pathology/online/classification/mhist.yaml @@ -2,8 +2,10 @@ trainer: class_path: eva.Trainer init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &LIGHTNING_ROOT ${oc.env:LIGHTNING_ROOT, logs/dino_vits16/online/mhist} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -22,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 70 + patience: ${oc.env:PATIENCE, 70} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/pathology/online/classification/patch_camelyon.yaml b/configs/vision/pathology/online/classification/patch_camelyon.yaml index 13817a718..60800129c 100644 --- a/configs/vision/pathology/online/classification/patch_camelyon.yaml +++ b/configs/vision/pathology/online/classification/patch_camelyon.yaml @@ -2,8 +2,10 @@ trainer: class_path: eva.Trainer init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/online/patch_camelyon} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -22,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 9 + patience: ${oc.env:PATIENCE, 9} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/pathology/online/segmentation/bcss.yaml b/configs/vision/pathology/online/segmentation/bcss.yaml index 2c343f134..694936df7 100644 --- a/configs/vision/pathology/online/segmentation/bcss.yaml +++ b/configs/vision/pathology/online/segmentation/bcss.yaml @@ -2,10 +2,11 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/bcss} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 513} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -26,7 +27,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: ${oc.env:PATIENCE, 100} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/pathology/online/segmentation/consep.yaml b/configs/vision/pathology/online/segmentation/consep.yaml index 06f181df8..e17fcae21 100644 --- a/configs/vision/pathology/online/segmentation/consep.yaml +++ b/configs/vision/pathology/online/segmentation/consep.yaml @@ -2,10 +2,11 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} - default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/consep} - max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 513} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/consep} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -26,7 +27,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: ${oc.env:PATIENCE, 34} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: @@ -45,10 +46,10 @@ model: out_indices: ${oc.env:OUT_INDICES, 1} model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} decoder: - class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS + class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderWithImage init_args: in_features: ${oc.env:IN_FEATURES, 384} - num_classes: &NUM_CLASSES 5 + num_classes: &NUM_CLASSES 5 criterion: class_path: eva.vision.losses.DiceLoss init_args: diff --git a/configs/vision/pathology/online/segmentation/monusac.yaml b/configs/vision/pathology/online/segmentation/monusac.yaml index b7f7ec21c..6b0e9a508 100644 --- a/configs/vision/pathology/online/segmentation/monusac.yaml +++ b/configs/vision/pathology/online/segmentation/monusac.yaml @@ -2,10 +2,11 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} - default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/monusac} - max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 550} - log_every_n_steps: 4 + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/monusac} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} + log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -26,7 +27,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: ${oc.env:PATIENCE, 50} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: @@ -45,10 +46,10 @@ model: out_indices: ${oc.env:OUT_INDICES, 1} model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} decoder: - class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS + class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderWithImage init_args: in_features: ${oc.env:IN_FEATURES, 384} - num_classes: &NUM_CLASSES 5 + num_classes: &NUM_CLASSES 5 criterion: class_path: eva.vision.losses.DiceLoss init_args: diff --git a/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml b/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml index 8f584f50d..b3342006d 100644 --- a/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml +++ b/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml @@ -2,9 +2,10 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/total_segmentator_2d} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 40000} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -25,7 +26,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 5 + patience: ${oc.env:PATIENCE, 5} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/radiology/offline/segmentation/lits.yaml b/configs/vision/radiology/offline/segmentation/lits.yaml index d9e0c4903..1c4cfe498 100644 --- a/configs/vision/radiology/offline/segmentation/lits.yaml +++ b/configs/vision/radiology/offline/segmentation/lits.yaml @@ -2,7 +2,7 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000} callbacks: @@ -24,7 +24,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: ${oc.env:PATIENCE, 100} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.SegmentationEmbeddingsWriter diff --git a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml index a0059e34b..866a70333 100644 --- a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml +++ b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml @@ -2,9 +2,10 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits_balanced} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -24,7 +25,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: ${oc.env:PATIENCE, 100} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE - class_path: eva.callbacks.SegmentationEmbeddingsWriter diff --git a/configs/vision/radiology/online/segmentation/lits.yaml b/configs/vision/radiology/online/segmentation/lits.yaml index 3d8d2fc57..81ed0e2f7 100644 --- a/configs/vision/radiology/online/segmentation/lits.yaml +++ b/configs/vision/radiology/online/segmentation/lits.yaml @@ -6,6 +6,7 @@ trainer: default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -26,7 +27,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: ${oc.env:PATIENCE, 100} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/radiology/online/segmentation/lits_balanced.yaml b/configs/vision/radiology/online/segmentation/lits_balanced.yaml index cff4c88e8..b5224d6cb 100644 --- a/configs/vision/radiology/online/segmentation/lits_balanced.yaml +++ b/configs/vision/radiology/online/segmentation/lits_balanced.yaml @@ -6,6 +6,7 @@ trainer: default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits_balanced} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000} log_every_n_steps: 6 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar @@ -26,7 +27,7 @@ trainer: - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: min_delta: 0 - patience: 100 + patience: ${oc.env:PATIENCE, 100} monitor: *MONITOR_METRIC mode: *MONITOR_METRIC_MODE logger: diff --git a/configs/vision/tests/offline/panda.yaml b/configs/vision/tests/offline/panda.yaml index 4051b4edf..6bd0e958e 100644 --- a/configs/vision/tests/offline/panda.yaml +++ b/configs/vision/tests/offline/panda.yaml @@ -6,6 +6,7 @@ trainer: max_epochs: &MAX_EPOCHS 1 limit_train_batches: 2 limit_val_batches: 2 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} callbacks: - class_path: eva.callbacks.ClassificationEmbeddingsWriter init_args: diff --git a/configs/vision/tests/offline/patch_camelyon.yaml b/configs/vision/tests/offline/patch_camelyon.yaml index e09a44c6b..f17a24ad9 100644 --- a/configs/vision/tests/offline/patch_camelyon.yaml +++ b/configs/vision/tests/offline/patch_camelyon.yaml @@ -6,6 +6,7 @@ trainer: max_epochs: &MAX_EPOCHS 1 limit_train_batches: 2 limit_val_batches: 2 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, last} callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: eva.callbacks.ClassificationEmbeddingsWriter diff --git a/configs/vision/tests/online/patch_camelyon.yaml b/configs/vision/tests/online/patch_camelyon.yaml index 4b8709415..cf75e1888 100644 --- a/configs/vision/tests/online/patch_camelyon.yaml +++ b/configs/vision/tests/online/patch_camelyon.yaml @@ -6,6 +6,7 @@ trainer: max_epochs: &MAX_EPOCHS 1 limit_train_batches: 2 limit_val_batches: 2 + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} model: class_path: eva.HeadModule init_args: diff --git a/docs/images/leaderboard.svg b/docs/images/leaderboard.svg index 2031c979a..f447a356f 100644 --- a/docs/images/leaderboard.svg +++ b/docs/images/leaderboard.svg @@ -6,7 +6,7 @@ - 2024-10-18T15:48:36.884888 + 2024-11-21T11:04:41.708790 image/svg+xml @@ -40,532 +40,628 @@ z +" clip-path="url(#pe67134274c)" style="fill: #0000ff"/> +" clip-path="url(#pe67134274c)" style="fill: #0000ff"/> +" clip-path="url(#pe67134274c)" style="fill: #0000ff"/> +" clip-path="url(#pe67134274c)" style="fill: #3a3aff"/> +" clip-path="url(#pe67134274c)" style="fill: #0000ff"/> +" clip-path="url(#pe67134274c)" style="fill: #7a7aff"/> +" clip-path="url(#pe67134274c)" style="fill: #1010ff"/> - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - + - + - + - + - + - + +L 631.971875 300.838362 +" clip-path="url(#pe67134274c)" style="fill: #fafaff"/> @@ -1106,30 +1202,26 @@ z - - + - - + + - - + + + + @@ -1181,53 +1363,33 @@ Q 359 3434 948 4092 Q 1538 4750 2522 4750 z " transform="scale(0.015625)"/> - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + + - + - + - + - + @@ -1823,15 +1850,163 @@ z + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - + + @@ -1852,27 +2027,27 @@ z - - - - - - - - - + + + + + + + + + - + - + @@ -1911,12 +2086,53 @@ z - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + @@ -1950,32 +2166,16 @@ z - - + + - + - - - + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + + - + - + - + @@ -2103,15 +2269,82 @@ z - - + + - + - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -2120,37 +2353,42 @@ z - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + - - - + + + - - + + - - - + + + - - - + + - - - - - - + + + - - + + - - - + + + - - + + - - - + + + - - + + - - - + + + - - + + - - - + + + + + + - + - - - + + + - - - + + + - + - + @@ -2304,9 +2510,9 @@ z - + - + @@ -2314,9 +2520,9 @@ z - + - + @@ -2324,9 +2530,9 @@ z - + - + @@ -2334,9 +2540,9 @@ z - + - + @@ -2344,9 +2550,9 @@ z - + - + @@ -2354,109 +2560,109 @@ z - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - + + - - - + + + - + - - - + + + - - + + - - - + + + - - + + - - - + + + - - + + - - - + + + - - - + + + - - - + + + - - - + + + - + - + @@ -2464,9 +2670,9 @@ z - + - + @@ -2474,9 +2680,9 @@ z - + - + @@ -2484,9 +2690,9 @@ z - + - + @@ -2494,9 +2700,9 @@ z - + - + @@ -2504,9 +2710,9 @@ z - + - + @@ -2514,269 +2720,349 @@ z - - - + + + - - - + + + - - - + + + - - - + + + - - - + + + - - + + + + + + + + + + + + - - - + + + + + + + + + + + + + + + + + + + + + + + - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - + + + - - + + - - - + + + - + - - - + + + - - + + - - - + + + - - - + + + - - - + + + - - - + + + - - + + - - + + - - - + + + - - + + - - + + - - + + - - + + - - - + + + - - + + - + - - + + - + - - + + - - - + + + - - + + - - - + + + - - - + + + - - + + - - - + + + - - + + - - - + + + - - + + - - - + + + - - - + + + - - - + + + - + - - - + + + - + - - - + + + - - - + + + - - - + + + - - - + + + - + - + @@ -2784,9 +3070,9 @@ z - + - + @@ -2794,9 +3080,9 @@ z - + - + @@ -2804,9 +3090,9 @@ z - + - + @@ -2814,9 +3100,9 @@ z - + - + @@ -2824,9 +3110,9 @@ z - + - + @@ -2834,29 +3120,29 @@ z - - - + + + - - + + - - - + + + - - - + + + - + - + @@ -2864,9 +3150,9 @@ z - + - + @@ -2874,9 +3160,9 @@ z - + - + @@ -2884,9 +3170,9 @@ z - + - + @@ -2894,9 +3180,9 @@ z - + - + @@ -2904,9 +3190,9 @@ z - + - + @@ -2914,29 +3200,29 @@ z - - - + + + - - - + + + - - - + + + - - - + + + - + - + @@ -2944,9 +3230,9 @@ z - + - + @@ -2954,9 +3240,9 @@ z - + - + @@ -2964,9 +3250,9 @@ z - + - + @@ -2974,9 +3260,9 @@ z - + - + @@ -2984,9 +3270,9 @@ z - + - + @@ -2994,29 +3280,109 @@ z - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - + + + - - + + - + - + @@ -3024,9 +3390,9 @@ z - + - + @@ -3034,9 +3400,9 @@ z - + - + @@ -3044,9 +3410,9 @@ z - + - + @@ -3054,9 +3420,9 @@ z - + - + @@ -3064,9 +3430,9 @@ z - + - + @@ -3074,30 +3440,30 @@ z - - - + + + - + - - - + + + - - - + + + - + diff --git a/docs/images/starplot.png b/docs/images/starplot.png index 5f600e9cd..e2a5bd73b 100644 Binary files a/docs/images/starplot.png and b/docs/images/starplot.png differ diff --git a/docs/leaderboards.md b/docs/leaderboards.md index 66e53f4ee..c0570c15c 100644 --- a/docs/leaderboards.md +++ b/docs/leaderboards.md @@ -40,7 +40,7 @@ We selected this approach to prioritize reliable, robust and fair FM-evaluation | **Output activation function** | none | none | none | | **Number of steps** | 12,500 | 12,500 (1) | 2,000 | | **Base batch size** | 256 | 32 | 64 | -| **Base learning rate** | 0.0003 | 0.001 | 0.0001 | +| **Base learning rate** | 0.0003 | 0.001 | 0.002 | | **Early stopping** | 5% * [Max epochs] | 10% * [Max epochs] (2) | 10% * [Max epochs] (2) | | **Optimizer** | SGD | AdamW | AdamW | | **Momentum** | 0.9 | n/a | n/a | diff --git a/docs/user-guide/getting-started/how_to_use.md b/docs/user-guide/getting-started/how_to_use.md index cea6952d5..34288f416 100644 --- a/docs/user-guide/getting-started/how_to_use.md +++ b/docs/user-guide/getting-started/how_to_use.md @@ -59,5 +59,7 @@ To customize runs, without the need of creating custom config-files, you can ove | `MONITOR_METRIC_MODE` | `str` | "min" or "max", depending on the `MONITOR_METRIC` used | | `REPO_OR_DIR` | `str` | GitHub repo with format containing model implementation, e.g. "facebookresearch/dino:main" | | `TQDM_REFRESH_RATE` | `str` | Determines at which rate (in number of batches) the progress bars get updated. Set it to 0 to disable the progress bar. | -| `N_DATA_WORKERS` | `str` | How many subprocesses to use for the torch dataloaders. Set to `null` to use the number of cpu cores. | -| `METRICS_DEVICE` | `str` | Specifies the device on which to compute the metrics. If not set, will use the same device as used for training. | \ No newline at end of file +| `N_DATA_WORKERS` | `str` | How many subprocesses to use for the torch dataloaders. Set to `null` to use the number of cpu cores. | +| `METRICS_DEVICE` | `str` | Specifies the device on which to compute the metrics. If not set, will use the same device as used for training. | +| `CHECKPOINT_TYPE` | `str` | Set to "best" or "last", to select which checkpoint to load for evaluations on validation & test sets after training. | +| `PATIENCE` | `int` | Number of checks with no improvement after which training will be stopped (early stopping). | \ No newline at end of file diff --git a/pdm.lock b/pdm.lock index eab0992bb..21faa341d 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "all", "dev", "docs", "lint", "test", "typecheck", "vision"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:cc23a9652ade7a78ab86d2663e963e4d41f1c30cdb485f6459f4fefc9fcea7e0" +content_hash = "sha256:b8df35bf60e5573e36c31c4ad4f324d7693f16b31cadcd27e48b352ae6c0235b" [[metadata.targets]] requires_python = ">=3.10" @@ -1405,18 +1405,18 @@ files = [ [[package]] name = "nibabel" -version = "5.2.1" -requires_python = ">=3.8" +version = "4.0.2" +requires_python = ">=3.7" summary = "Access a multitude of neuroimaging data formats" groups = ["all", "vision"] dependencies = [ - "importlib-resources>=1.3; python_version < \"3.9\"", - "numpy>=1.20", - "packaging>=17", + "numpy>=1.17", + "packaging>=17.0", + "setuptools", ] files = [ - {file = "nibabel-5.2.1-py3-none-any.whl", hash = "sha256:2cbbc22985f7f9d39d050df47249771dfb8d48447f5e7a993177e4cabfe047f0"}, - {file = "nibabel-5.2.1.tar.gz", hash = "sha256:b6c80b2e728e4bc2b65f1142d9b8d2287a9102a8bf8477e115ef0d8334559975"}, + {file = "nibabel-4.0.2-py3-none-any.whl", hash = "sha256:c4fe76348aa865f8300beaaf2a69d31624964c861853ef80c06e33d5f244413c"}, + {file = "nibabel-4.0.2.tar.gz", hash = "sha256:45c49b5349351b45f6c045a91aa02b4f0d367686ff3284632ef95ac65b930786"}, ] [[package]] @@ -1994,7 +1994,7 @@ name = "pyreadline3" version = "3.4.1" summary = "A python implementation of GNU readline." groups = ["default"] -marker = "sys_platform == \"win32\"" +marker = "sys_platform == \"win32\" and python_version >= \"3.8\"" files = [ {file = "pyreadline3-3.4.1-py3-none-any.whl", hash = "sha256:b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb"}, {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"}, @@ -2465,7 +2465,7 @@ name = "setuptools" version = "75.1.0" requires_python = ">=3.8" summary = "Easily download, build, install, upgrade, and uninstall Python packages" -groups = ["default", "dev", "docs"] +groups = ["default", "all", "dev", "docs", "vision"] files = [ {file = "setuptools-75.1.0-py3-none-any.whl", hash = "sha256:35ab7fd3bcd95e6b7fd704e4a1539513edad446c097797f2985e0e4b960772f2"}, {file = "setuptools-75.1.0.tar.gz", hash = "sha256:d59a21b17a275fb872a9c3dae73963160ae079f1049ed956880cd7c09b120538"}, diff --git a/pyproject.toml b/pyproject.toml index a36af45a3..7c52ba407 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "pdm.backend" [project] name = "kaiko-eva" -version = "0.1.2" +version = "0.1.5" description = "Evaluation Framework for oncology foundation models." keywords = [ "machine-learning", @@ -34,14 +34,14 @@ maintainers = [ ] requires-python = ">=3.10" dependencies = [ - "torch==2.3.0", - "lightning>=2.2.2", - "jsonargparse[omegaconf]==4.31.0", + "torch>=2.3.0", + "lightning>=2.2.0", + "jsonargparse[omegaconf]>=4.30.0", "tensorboard>=2.16.2", "loguru>=0.7.2", - "pandas>=2.2.0", + "pandas>=2.0.0", "transformers>=4.38.2", - "onnxruntime>=1.17.1", + "onnxruntime>=1.15.1", "onnx>=1.16.0", "toolz>=0.12.1", "rich>=13.7.1", @@ -59,7 +59,7 @@ file = "LICENSE" [project.optional-dependencies] vision = [ "h5py>=3.10.0", - "nibabel>=5.2.0", + "nibabel>=4.0.1", "opencv-python-headless>=4.9.0.80", "timm>=1.0.9", "torchvision>=0.17.0", @@ -72,7 +72,7 @@ vision = [ ] all = [ "h5py>=3.10.0", - "nibabel>=5.2.0", + "nibabel>=4.0.1", "opencv-python-headless>=4.9.0.80", "timm>=1.0.9", "torchvision>=0.17.0", diff --git a/src/eva/core/data/dataloaders/dataloader.py b/src/eva/core/data/dataloaders/dataloader.py index 70040c618..65f0d6566 100644 --- a/src/eva/core/data/dataloaders/dataloader.py +++ b/src/eva/core/data/dataloaders/dataloader.py @@ -59,17 +59,20 @@ class DataLoader: prefetch_factor: int | None = 2 """Number of batches loaded in advance by each worker.""" - def __call__(self, dataset: datasets.TorchDataset) -> dataloader.DataLoader: + def __call__( + self, dataset: datasets.TorchDataset, sampler: samplers.Sampler | None = None + ) -> dataloader.DataLoader: """Returns the dataloader on the provided dataset. Args: dataset: dataset from which to load the data. + sampler: defines the strategy to draw samples from the dataset. """ return dataloader.DataLoader( dataset=dataset, batch_size=self.batch_size, shuffle=self.shuffle, - sampler=self.sampler, + sampler=sampler or self.sampler, batch_sampler=self.batch_sampler, num_workers=self.num_workers or multiprocessing.cpu_count(), collate_fn=self.collate_fn, diff --git a/src/eva/core/data/datamodules/datamodule.py b/src/eva/core/data/datamodules/datamodule.py index 1f050ec75..c9522c227 100644 --- a/src/eva/core/data/datamodules/datamodule.py +++ b/src/eva/core/data/datamodules/datamodule.py @@ -8,6 +8,7 @@ from eva.core.data import dataloaders as dataloaders_lib from eva.core.data import datasets as datasets_lib +from eva.core.data import samplers as samplers_lib from eva.core.data.datamodules import call, schemas @@ -24,17 +25,20 @@ def __init__( self, datasets: schemas.DatasetsSchema | None = None, dataloaders: schemas.DataloadersSchema | None = None, + samplers: schemas.SamplersSchema | None = None, ) -> None: """Initializes the datamodule. Args: datasets: The desired datasets. dataloaders: The desired dataloaders. + samplers: The desired samplers for the dataloaders. """ super().__init__() self.datasets = datasets or self.default_datasets self.dataloaders = dataloaders or self.default_dataloaders + self.samplers = samplers or self.default_samplers @property def default_datasets(self) -> schemas.DatasetsSchema: @@ -46,6 +50,11 @@ def default_dataloaders(self) -> schemas.DataloadersSchema: """Returns the default dataloader schema.""" return schemas.DataloadersSchema() + @property + def default_samplers(self) -> schemas.SamplersSchema: + """Returns the default samplers schema.""" + return schemas.SamplersSchema() + @override def prepare_data(self) -> None: call.call_method_if_exists(self.datasets.tolist(), "prepare_data") @@ -64,7 +73,12 @@ def train_dataloader(self) -> TRAIN_DATALOADERS: raise ValueError( "Train dataloader can not be initialized as `self.datasets.train` is `None`." ) - return self.dataloaders.train(self.datasets.train) + if isinstance(self.datasets.train, list) and len(self.datasets.train) > 1: + raise ValueError("Train dataloader can not be initialized with multiple datasets.") + + return self._initialize_dataloaders( + self.dataloaders.train, self.datasets.train, self.samplers.train + )[0] @override def val_dataloader(self) -> EVAL_DATALOADERS: @@ -72,7 +86,9 @@ def val_dataloader(self) -> EVAL_DATALOADERS: raise ValueError( "Validation dataloader can not be initialized as `self.datasets.val` is `None`." ) - return self._initialize_dataloaders(self.dataloaders.val, self.datasets.val) + return self._initialize_dataloaders( + self.dataloaders.val, self.datasets.val, self.samplers.val + ) @override def test_dataloader(self) -> EVAL_DATALOADERS: @@ -80,7 +96,9 @@ def test_dataloader(self) -> EVAL_DATALOADERS: raise ValueError( "Test dataloader can not be initialized as `self.datasets.test` is `None`." ) - return self._initialize_dataloaders(self.dataloaders.test, self.datasets.test) + return self._initialize_dataloaders( + self.dataloaders.test, self.datasets.test, self.samplers.test + ) @override def predict_dataloader(self) -> EVAL_DATALOADERS: @@ -88,21 +106,40 @@ def predict_dataloader(self) -> EVAL_DATALOADERS: raise ValueError( "Predict dataloader can not be initialized as `self.datasets.predict` is `None`." ) - return self._initialize_dataloaders(self.dataloaders.predict, self.datasets.predict) + if isinstance(self.datasets.predict, list) and len(self.datasets.predict) > 1: + # Only apply sampler to the first predict dataset (should correspond to train split) + train_dataloader = self._initialize_dataloaders( + self.dataloaders.predict, self.datasets.predict[0], self.samplers.predict + ) + return train_dataloader + self._initialize_dataloaders( + self.dataloaders.predict, self.datasets.predict[1:] + ) + + return self._initialize_dataloaders( + self.dataloaders.predict, self.datasets.predict, self.samplers.predict + ) def _initialize_dataloaders( self, dataloader: dataloaders_lib.DataLoader, datasets: datasets_lib.TorchDataset | List[datasets_lib.TorchDataset], + sampler: samplers_lib.Sampler | None = None, ) -> EVAL_DATALOADERS: """Initializes dataloaders from a given set of dataset. Args: dataloader: The dataloader to apply to the provided datasets. datasets: The desired dataset(s) to allocate dataloader(s). + sampler: The sampler to use for the dataloader. Returns: A list with the dataloaders of the provided dataset(s). """ datasets = datasets if isinstance(datasets, list) else [datasets] - return list(map(dataloader, datasets)) + + dataloaders = [] + for dataset in datasets: + if sampler is not None and isinstance(sampler, samplers_lib.SamplerWithDataSource): + sampler.set_dataset(dataset) # type: ignore + dataloaders.append(dataloader(dataset, sampler=sampler)) + return dataloaders diff --git a/src/eva/core/data/datamodules/schemas.py b/src/eva/core/data/datamodules/schemas.py index 8780ac61d..d19b342e3 100644 --- a/src/eva/core/data/datamodules/schemas.py +++ b/src/eva/core/data/datamodules/schemas.py @@ -3,7 +3,7 @@ import dataclasses from typing import List -from eva.core.data import dataloaders, datasets +from eva.core.data import dataloaders, datasets, samplers TRAIN_DATASET = datasets.TorchDataset | None """Train dataset.""" @@ -60,3 +60,20 @@ class DataloadersSchema: predict: dataloaders.DataLoader = dataclasses.field(default_factory=dataloaders.DataLoader) """Predict dataloader.""" + + +@dataclasses.dataclass(frozen=True) +class SamplersSchema: + """Samplers schema used in DataModule.""" + + train: samplers.Sampler | None = None + """Train sampler.""" + + val: samplers.Sampler | None = None + """Validation sampler.""" + + test: samplers.Sampler | None = None + """Test sampler.""" + + predict: samplers.Sampler | None = None + """Predict sampler.""" diff --git a/src/eva/core/data/datasets/__init__.py b/src/eva/core/data/datasets/__init__.py index ba4da0cff..c5e366827 100644 --- a/src/eva/core/data/datasets/__init__.py +++ b/src/eva/core/data/datasets/__init__.py @@ -1,15 +1,18 @@ """Datasets API.""" -from eva.core.data.datasets.base import Dataset +from eva.core.data.datasets.base import Dataset, MapDataset from eva.core.data.datasets.classification import ( EmbeddingsClassificationDataset, MultiEmbeddingsClassificationDataset, ) from eva.core.data.datasets.dataset import TorchDataset +from eva.core.data.datasets.typings import DataSample __all__ = [ "Dataset", + "MapDataset", "EmbeddingsClassificationDataset", "MultiEmbeddingsClassificationDataset", "TorchDataset", + "DataSample", ] diff --git a/src/eva/core/data/datasets/base.py b/src/eva/core/data/datasets/base.py index d83fa74ed..a03eaf736 100644 --- a/src/eva/core/data/datasets/base.py +++ b/src/eva/core/data/datasets/base.py @@ -1,5 +1,7 @@ """Base dataset class.""" +import abc + from eva.core.data.datasets import dataset @@ -51,3 +53,24 @@ def teardown(self) -> None: of fit (train + validate), validate, test, or predict and it will be called from every process (i.e. GPU) across all the nodes in DDP. """ + + +class MapDataset(Dataset): + """Abstract base class for all map-style datasets.""" + + @abc.abstractmethod + def __getitem__(self, index: int): + """Retrieves the item at the given index. + + Args: + index: Index of the item to retrieve. + + Returns: + The data at the given index. + """ + raise NotImplementedError + + @abc.abstractmethod + def __len__(self) -> int: + """Returns the length of the dataset.""" + raise NotImplementedError diff --git a/src/eva/core/data/datasets/typings.py b/src/eva/core/data/datasets/typings.py new file mode 100644 index 000000000..465b23e25 --- /dev/null +++ b/src/eva/core/data/datasets/typings.py @@ -0,0 +1,18 @@ +"""Typing definitions for the datasets module.""" + +from typing import Any, Dict, NamedTuple + +import torch + + +class DataSample(NamedTuple): + """The default input batch data scheme.""" + + data: torch.Tensor + """The data batch.""" + + targets: torch.Tensor | None = None + """The target batch.""" + + metadata: Dict[str, Any] | None = None + """The associated metadata.""" diff --git a/src/eva/core/data/samplers/__init__.py b/src/eva/core/data/samplers/__init__.py index 5cc3a852e..7586d533a 100644 --- a/src/eva/core/data/samplers/__init__.py +++ b/src/eva/core/data/samplers/__init__.py @@ -1,5 +1,7 @@ """Data samplers API.""" -from eva.core.data.samplers.sampler import Sampler +from eva.core.data.samplers.classification.balanced import BalancedSampler +from eva.core.data.samplers.random import RandomSampler +from eva.core.data.samplers.sampler import Sampler, SamplerWithDataSource -__all__ = ["Sampler"] +__all__ = ["Sampler", "SamplerWithDataSource", "RandomSampler", "BalancedSampler"] diff --git a/src/eva/core/data/samplers/classification/__init__.py b/src/eva/core/data/samplers/classification/__init__.py new file mode 100644 index 000000000..c68235bcc --- /dev/null +++ b/src/eva/core/data/samplers/classification/__init__.py @@ -0,0 +1,5 @@ +"""Classification data samplers API.""" + +from eva.core.data.samplers.classification.balanced import BalancedSampler + +__all__ = ["BalancedSampler"] diff --git a/src/eva/core/data/samplers/classification/balanced.py b/src/eva/core/data/samplers/classification/balanced.py new file mode 100644 index 000000000..ed3a19d39 --- /dev/null +++ b/src/eva/core/data/samplers/classification/balanced.py @@ -0,0 +1,96 @@ +"""Random class sampler for data loading.""" + +from collections import defaultdict +from typing import Dict, Iterator, List + +import numpy as np +from typing_extensions import override + +from eva.core.data import datasets +from eva.core.data.datasets.typings import DataSample +from eva.core.data.samplers.sampler import SamplerWithDataSource +from eva.core.utils.progress_bar import tqdm + + +class BalancedSampler(SamplerWithDataSource[int]): + """Balanced class sampler for data loading. + + The sampler ensures that: + 1. Each class has the same number of samples + 2. Samples within each class are randomly selected + 3. Samples of different classes appear in random order + """ + + def __init__(self, num_samples: int, replacement: bool = False, seed: int | None = 42): + """Initializes the balanced sampler. + + Args: + num_samples: The number of samples to draw per class. + replacement: samples are drawn on-demand with replacement if ``True``, default=``False`` + seed: Random seed for reproducibility. + """ + self._num_samples = num_samples + self._replacement = replacement + self._class_indices: Dict[int, List[int]] = defaultdict(list) + self._random_generator = np.random.default_rng(seed) + + def __len__(self) -> int: + """Returns the total number of samples.""" + return self._num_samples * len(self._class_indices) + + def __iter__(self) -> Iterator[int]: + """Creates an iterator that yields indices in a class balanced way. + + Returns: + Iterator yielding dataset indices. + """ + indices = [] + + for class_idx in self._class_indices: + class_indices = self._class_indices[class_idx] + sampled_indices = self._random_generator.choice( + class_indices, size=self._num_samples, replace=self._replacement + ).tolist() + indices.extend(sampled_indices) + + self._random_generator.shuffle(indices) + + return iter(indices) + + @override + def set_dataset(self, data_source: datasets.MapDataset): + """Sets the dataset and builds class indices. + + Args: + data_source: The dataset to sample from. + + Raises: + ValueError: If the dataset doesn't have targets or if any class has + fewer samples than `num_samples` and `replacement` is `False`. + """ + super().set_dataset(data_source) + self._make_indices() + + def _make_indices(self): + """Builds indices for each class in the dataset.""" + self._class_indices.clear() + + for idx in tqdm( + range(len(self.data_source)), desc="Fetching class indices for balanced sampler" + ): + _, target, _ = DataSample(*self.data_source[idx]) + if target is None: + raise ValueError("The dataset must return non-empty targets.") + if target.numel() != 1: + raise ValueError("The dataset must return a single & scalar target.") + + class_idx = int(target.item()) + self._class_indices[class_idx].append(idx) + + if not self._replacement: + for class_idx, indices in self._class_indices.items(): + if len(indices) < self._num_samples: + raise ValueError( + f"Class {class_idx} has only {len(indices)} samples, " + f"which is less than the required {self._num_samples} samples." + ) diff --git a/src/eva/core/data/samplers/random.py b/src/eva/core/data/samplers/random.py new file mode 100644 index 000000000..415b8ae3e --- /dev/null +++ b/src/eva/core/data/samplers/random.py @@ -0,0 +1,39 @@ +"""Random sampler for data loading.""" + +from typing import Optional + +from torch.utils import data +from typing_extensions import override + +from eva.core.data import datasets +from eva.core.data.samplers.sampler import SamplerWithDataSource + + +class RandomSampler(data.RandomSampler, SamplerWithDataSource[int]): + """Samples elements randomly.""" + + data_source: datasets.MapDataset # type: ignore + + def __init__( + self, replacement: bool = False, num_samples: Optional[int] = None, generator=None + ) -> None: + """Initializes the random sampler. + + Args: + data_source: dataset to sample from + replacement: samples are drawn on-demand with replacement if ``True``, default=``False`` + num_samples: number of samples to draw, default=`len(dataset)`. + generator: Generator used in sampling. + """ + self.replacement = replacement + self._num_samples = num_samples + self.generator = generator + + @override + def set_dataset(self, data_source: datasets.MapDataset) -> None: + super().__init__( + data_source, + replacement=self.replacement, + num_samples=self.num_samples, + generator=self.generator, + ) diff --git a/src/eva/core/data/samplers/sampler.py b/src/eva/core/data/samplers/sampler.py index 98b3124b2..ff878fa36 100644 --- a/src/eva/core/data/samplers/sampler.py +++ b/src/eva/core/data/samplers/sampler.py @@ -1,6 +1,33 @@ """Core data sampler.""" +from typing import Generic, TypeVar + from torch.utils import data +from eva.core.data import datasets + Sampler = data.Sampler """Core abstract data sampler class.""" + +T_co = TypeVar("T_co", covariant=True) + + +class SamplerWithDataSource(Sampler, Generic[T_co]): + """A sampler base class that enables to specify the data source after initialization. + + The `set_dataset` can also be overwritten to expand the functionality of the derived + sampler classes. + """ + + data_source: datasets.MapDataset + + def set_dataset(self, data_source: datasets.MapDataset) -> None: + """Sets the dataset to sample from. + + This is not done in the constructor because the dataset might not be + available at that time. + + Args: + data_source: The dataset to sample from. + """ + self.data_source = data_source diff --git a/src/eva/core/models/__init__.py b/src/eva/core/models/__init__.py index 16cfca96e..a5f81a151 100644 --- a/src/eva/core/models/__init__.py +++ b/src/eva/core/models/__init__.py @@ -2,7 +2,13 @@ from eva.core.models.modules import HeadModule, InferenceModule from eva.core.models.networks import MLP -from eva.core.models.wrappers import BaseModel, HuggingFaceModel, ModelFromFunction, ONNXModel +from eva.core.models.wrappers import ( + BaseModel, + HuggingFaceModel, + ModelFromFunction, + ONNXModel, + TorchHubModel, +) __all__ = [ "HeadModule", @@ -12,4 +18,5 @@ "HuggingFaceModel", "ModelFromFunction", "ONNXModel", + "TorchHubModel", ] diff --git a/src/eva/core/models/wrappers/__init__.py b/src/eva/core/models/wrappers/__init__.py index 95ab6101d..979577bd1 100644 --- a/src/eva/core/models/wrappers/__init__.py +++ b/src/eva/core/models/wrappers/__init__.py @@ -2,12 +2,14 @@ from eva.core.models.wrappers.base import BaseModel from eva.core.models.wrappers.from_function import ModelFromFunction +from eva.core.models.wrappers.from_torchhub import TorchHubModel from eva.core.models.wrappers.huggingface import HuggingFaceModel from eva.core.models.wrappers.onnx import ONNXModel __all__ = [ "BaseModel", - "ModelFromFunction", "HuggingFaceModel", + "ModelFromFunction", "ONNXModel", + "TorchHubModel", ] diff --git a/src/eva/core/models/wrappers/from_torchhub.py b/src/eva/core/models/wrappers/from_torchhub.py new file mode 100644 index 000000000..2a80aaf5f --- /dev/null +++ b/src/eva/core/models/wrappers/from_torchhub.py @@ -0,0 +1,93 @@ +"""Model wrapper for torch.hub models.""" + +from typing import Any, Callable, Dict, List, Tuple + +import torch +import torch.nn as nn +from typing_extensions import override + +from eva.core.models import wrappers +from eva.core.models.wrappers import _utils + + +class TorchHubModel(wrappers.BaseModel): + """Model wrapper for `torch.hub` models.""" + + def __init__( + self, + model_name: str, + repo_or_dir: str, + pretrained: bool = True, + checkpoint_path: str = "", + out_indices: int | Tuple[int, ...] | None = None, + norm: bool = False, + trust_repo: bool = True, + model_kwargs: Dict[str, Any] | None = None, + tensor_transforms: Callable | None = None, + ) -> None: + """Initializes the encoder. + + Args: + model_name: Name of model to instantiate. + repo_or_dir: The torch.hub repository or local directory to load the model from. + pretrained: If set to `True`, load pretrained ImageNet-1k weights. + checkpoint_path: Path of checkpoint to load. + out_indices: Returns last n blocks if `int`, all if `None`, select + matching indices if sequence. + norm: Wether to apply norm layer to all intermediate features. Only + used when `out_indices` is not `None`. + trust_repo: If set to `False`, a prompt will ask the user whether the + repo should be trusted. + model_kwargs: Extra model arguments. + tensor_transforms: The transforms to apply to the output tensor + produced by the model. + """ + super().__init__(tensor_transforms=tensor_transforms) + + self._model_name = model_name + self._repo_or_dir = repo_or_dir + self._pretrained = pretrained + self._checkpoint_path = checkpoint_path + self._out_indices = out_indices + self._norm = norm + self._trust_repo = trust_repo + self._model_kwargs = model_kwargs or {} + + self.load_model() + + @override + def load_model(self) -> None: + """Builds and loads the torch.hub model.""" + self._model: nn.Module = torch.hub.load( + repo_or_dir=self._repo_or_dir, + model=self._model_name, + trust_repo=self._trust_repo, + pretrained=self._pretrained, + **self._model_kwargs, + ) # type: ignore + + if self._checkpoint_path: + _utils.load_model_weights(self._model, self._checkpoint_path) + + TorchHubModel.__name__ = self._model_name + + @override + def model_forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tensor]: + if self._out_indices is not None: + if not hasattr(self._model, "get_intermediate_layers"): + raise ValueError( + "Only models with `get_intermediate_layers` are supported " + "when using `out_indices`." + ) + + return list( + self._model.get_intermediate_layers( + tensor, + self._out_indices, + reshape=True, + return_class_token=False, + norm=self._norm, + ) + ) + + return self._model(tensor) diff --git a/src/eva/core/trainers/functional.py b/src/eva/core/trainers/functional.py index 62229bf81..4d8bd5346 100644 --- a/src/eva/core/trainers/functional.py +++ b/src/eva/core/trainers/functional.py @@ -96,11 +96,13 @@ def fit_and_validate( A tuple of with the validation and the test metrics (if exists). """ trainer.fit(model, datamodule=datamodule) - validation_scores = trainer.validate(datamodule=datamodule, verbose=verbose) + validation_scores = trainer.validate( + datamodule=datamodule, verbose=verbose, ckpt_path=trainer.checkpoint_type + ) test_scores = ( None if datamodule.datasets.test is None - else trainer.test(datamodule=datamodule, verbose=verbose) + else trainer.test(datamodule=datamodule, verbose=verbose, ckpt_path=trainer.checkpoint_type) ) return validation_scores, test_scores diff --git a/src/eva/core/trainers/trainer.py b/src/eva/core/trainers/trainer.py index 006470339..beace9db3 100644 --- a/src/eva/core/trainers/trainer.py +++ b/src/eva/core/trainers/trainer.py @@ -1,7 +1,7 @@ """Core trainer module.""" import os -from typing import Any +from typing import Any, Literal import loguru from lightning.pytorch import loggers as pl_loggers @@ -28,6 +28,7 @@ def __init__( *args: Any, default_root_dir: str = "logs", n_runs: int = 1, + checkpoint_type: Literal["best", "last"] = "best", **kwargs: Any, ) -> None: """Initializes the trainer. @@ -40,11 +41,14 @@ def __init__( Unlike in ::class::`lightning.pytorch.Trainer`, this path would be the prioritized destination point. n_runs: The amount of runs (fit and evaluate) to perform in an evaluation session. + checkpoint_type: Wether to load the "best" or "last" checkpoint saved by the checkpoint + callback for evaluations on validation & test sets. kwargs: Kew-word arguments of ::class::`lightning.pytorch.Trainer`. """ super().__init__(*args, default_root_dir=default_root_dir, **kwargs) - self._n_runs = n_runs + self.checkpoint_type = checkpoint_type + self.n_runs = n_runs self._session_id: str = _logging.generate_session_id() self._log_dir: str = self.default_log_dir @@ -106,6 +110,6 @@ def run_evaluation_session( base_trainer=self, base_model=model, datamodule=datamodule, - n_runs=self._n_runs, - verbose=self._n_runs > 1, + n_runs=self.n_runs, + verbose=self.n_runs > 1, ) diff --git a/src/eva/vision/data/datasets/vision.py b/src/eva/vision/data/datasets/vision.py index 81b08f57d..ca3387651 100644 --- a/src/eva/vision/data/datasets/vision.py +++ b/src/eva/vision/data/datasets/vision.py @@ -9,7 +9,7 @@ """The data sample type.""" -class VisionDataset(base.Dataset, abc.ABC, Generic[DataSample]): +class VisionDataset(base.MapDataset, abc.ABC, Generic[DataSample]): """Base dataset class for vision tasks.""" @abc.abstractmethod @@ -24,20 +24,3 @@ def filename(self, index: int) -> str: Returns: The filename of the `index`'th data sample. """ - - @abc.abstractmethod - def __getitem__(self, index: int) -> DataSample: - """Returns the `index`'th data sample. - - Args: - index: The index of the data-sample to select. - - Returns: - A data sample and its target. - """ - raise NotImplementedError - - @abc.abstractmethod - def __len__(self) -> int: - """Returns the total length of the data.""" - raise NotImplementedError diff --git a/src/eva/vision/losses/dice.py b/src/eva/vision/losses/dice.py index 8e6133b34..d5d31d17a 100644 --- a/src/eva/vision/losses/dice.py +++ b/src/eva/vision/losses/dice.py @@ -45,9 +45,6 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: inputs, targets = _apply_ignore_index(inputs, targets, self.ignore_index) targets = _to_one_hot(targets, num_classes=inputs.shape[1]) - if targets.ndim == 3: - targets = one_hot(targets[:, None, ...], num_classes=inputs.shape[1]) - return super().forward(inputs, targets) diff --git a/src/eva/vision/models/networks/backbones/__init__.py b/src/eva/vision/models/networks/backbones/__init__.py index 0fdf2963a..1ef7bc855 100644 --- a/src/eva/vision/models/networks/backbones/__init__.py +++ b/src/eva/vision/models/networks/backbones/__init__.py @@ -1,6 +1,6 @@ """Vision Model Backbones API.""" -from eva.vision.models.networks.backbones import pathology, timm, universal +from eva.vision.models.networks.backbones import pathology, timm, torchhub, universal from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model -__all__ = ["pathology", "timm", "universal", "BackboneModelRegistry", "register_model"] +__all__ = ["pathology", "timm", "torchhub", "universal", "BackboneModelRegistry", "register_model"] diff --git a/src/eva/vision/models/networks/backbones/torchhub/__init__.py b/src/eva/vision/models/networks/backbones/torchhub/__init__.py new file mode 100644 index 000000000..6acd97978 --- /dev/null +++ b/src/eva/vision/models/networks/backbones/torchhub/__init__.py @@ -0,0 +1,5 @@ +"""torch.hub backbones API.""" + +from eva.vision.models.networks.backbones.torchhub.backbones import torch_hub_model + +__all__ = ["torch_hub_model"] diff --git a/src/eva/vision/models/networks/backbones/torchhub/backbones.py b/src/eva/vision/models/networks/backbones/torchhub/backbones.py new file mode 100644 index 000000000..d1503a801 --- /dev/null +++ b/src/eva/vision/models/networks/backbones/torchhub/backbones.py @@ -0,0 +1,61 @@ +"""torch.hub backbones.""" + +import functools +from typing import Tuple + +import torch +from loguru import logger +from torch import nn + +from eva.core.models import wrappers +from eva.vision.models.networks.backbones.registry import BackboneModelRegistry + +HUB_REPOS = ["facebookresearch/dinov2:main", "kaiko-ai/towards_large_pathology_fms"] +"""List of torch.hub repositories for which to add the models to the registry.""" + + +def torch_hub_model( + model_name: str, + repo_or_dir: str, + checkpoint_path: str | None = None, + pretrained: bool = False, + out_indices: int | Tuple[int, ...] | None = None, + **kwargs, +) -> nn.Module: + """Initializes any ViT model from torch.hub with weights from a specified checkpoint. + + Args: + model_name: The name of the model to load. + repo_or_dir: The torch.hub repository or local directory to load the model from. + checkpoint_path: The path to the checkpoint file. + pretrained: If set to `True`, load pretrained model weights if available. + out_indices: Whether and which multi-level patch embeddings to return. + **kwargs: Additional arguments to pass to the model + + Returns: + The VIT model instance. + """ + logger.info( + f"Loading torch.hub model {model_name} from {repo_or_dir}" + + (f"using checkpoint {checkpoint_path}" if checkpoint_path else "") + ) + + return wrappers.TorchHubModel( + model_name=model_name, + repo_or_dir=repo_or_dir, + pretrained=pretrained, + checkpoint_path=checkpoint_path or "", + out_indices=out_indices, + model_kwargs=kwargs, + ) + + +BackboneModelRegistry._registry.update( + { + f"torchhub/{repo}:{model_name}": functools.partial( + torch_hub_model, model_name=model_name, repo_or_dir=repo + ) + for repo in HUB_REPOS + for model_name in torch.hub.list(repo, verbose=False) + } +) diff --git a/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py b/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py index c43b351c4..ce242713f 100644 --- a/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py +++ b/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py @@ -52,7 +52,7 @@ def _forward_features(self, features: torch.Tensor | List[torch.Tensor]) -> torc """ if isinstance(features, torch.Tensor): features = [features] - if not isinstance(features, list) or features[0].ndim != 4: + if not isinstance(features, (list, tuple)) or features[0].ndim != 4: raise ValueError( "Input features should be a list of four (4) dimensional inputs of " "shape (batch_size, hidden_size, n_patches_height, n_patches_width)." diff --git a/src/eva/vision/models/wrappers/__init__.py b/src/eva/vision/models/wrappers/__init__.py index 14d63b687..d2f84de45 100644 --- a/src/eva/vision/models/wrappers/__init__.py +++ b/src/eva/vision/models/wrappers/__init__.py @@ -3,4 +3,4 @@ from eva.vision.models.wrappers.from_registry import ModelFromRegistry from eva.vision.models.wrappers.from_timm import TimmModel -__all__ = ["TimmModel", "ModelFromRegistry"] +__all__ = ["ModelFromRegistry", "TimmModel"] diff --git a/src/eva/vision/utils/io/nifti.py b/src/eva/vision/utils/io/nifti.py index e90f919d5..49ca8fdaa 100644 --- a/src/eva/vision/utils/io/nifti.py +++ b/src/eva/vision/utils/io/nifti.py @@ -54,7 +54,6 @@ def save_array_as_nifti( dtype: The data type to save the image. """ nifti_image = nib.Nifti1Image(array, affine=np.eye(4), dtype=dtype) # type: ignore - nifti_image.header.get_xyzt_units() nifti_image.to_filename(filename) diff --git a/tests/eva/core/data/samplers/__init__.py b/tests/eva/core/data/samplers/__init__.py new file mode 100644 index 000000000..39e9f73a3 --- /dev/null +++ b/tests/eva/core/data/samplers/__init__.py @@ -0,0 +1 @@ +"""Tests for data loader samplers.""" diff --git a/tests/eva/core/data/samplers/_utils.py b/tests/eva/core/data/samplers/_utils.py new file mode 100644 index 000000000..7e09996a0 --- /dev/null +++ b/tests/eva/core/data/samplers/_utils.py @@ -0,0 +1,30 @@ +"""Test utilities for dataloader sampler tests.""" + +from typing import List, Tuple + +import torch +from typing_extensions import override + +from eva.core.data import datasets + + +class MockDataset(datasets.MapDataset): + """Mock map-style dataset class for unit testing.""" + + def __init__(self, samples: List[Tuple[None, torch.Tensor, None]]): + self.samples = samples + + @override + def __getitem__(self, idx): + return self.samples[idx] + + @override + def __len__(self): + return len(self.samples) + + +def multiclass_dataset(num_samples: int, num_classes: int) -> datasets.MapDataset: + samples = ( + [(None, torch.tensor([i]), None)] * (num_samples // num_classes) for i in range(num_classes) + ) + return MockDataset([item for sublist in samples for item in sublist]) diff --git a/tests/eva/core/data/samplers/classification/__init__.py b/tests/eva/core/data/samplers/classification/__init__.py new file mode 100644 index 000000000..ae2610a37 --- /dev/null +++ b/tests/eva/core/data/samplers/classification/__init__.py @@ -0,0 +1 @@ +"""Tests for classification data loader samplers.""" diff --git a/tests/eva/core/data/samplers/classification/test_balanced.py b/tests/eva/core/data/samplers/classification/test_balanced.py new file mode 100644 index 000000000..ea30a08a0 --- /dev/null +++ b/tests/eva/core/data/samplers/classification/test_balanced.py @@ -0,0 +1,75 @@ +"""Tests for the balanced sampler.""" + +from collections import Counter + +import pytest +import torch + +from eva.core.data.datasets.typings import DataSample +from eva.core.data.samplers.classification import BalancedSampler +from tests.eva.core.data.samplers import _utils + + +@pytest.mark.parametrize( + "num_class_samples, replacement, num_dataset_samples, num_classes", + [ + (3, False, 15, 2), + (20, True, 15, 2), + (3, False, 33, 5), + ], +) +def test_balanced_sampling( + num_class_samples: int, replacement: bool, num_dataset_samples: int, num_classes: int +): + """Tests if the returned indices are balanced.""" + dataset = _utils.multiclass_dataset(num_dataset_samples, num_classes) + sampler = BalancedSampler(num_samples=num_class_samples, replacement=replacement) + sampler.set_dataset(dataset) + + indices = list(sampler) + class_counts = Counter(DataSample(*dataset[i]).targets.item() for i in indices) # type: ignore + + assert len(sampler) == num_class_samples * num_classes + assert len(class_counts.keys()) == num_classes + for count in class_counts.values(): + assert count == num_class_samples + + +def test_insufficient_samples_without_replacement(): + """Tests if the sampler raises an error when there are insufficient samples.""" + num_dataset_samples, num_classes = 15, 3 + dataset = _utils.multiclass_dataset(num_dataset_samples, num_classes) + sampler = BalancedSampler(num_samples=7, replacement=False) + + with pytest.raises(ValueError, match=f"has only {num_dataset_samples // num_classes} samples"): + sampler.set_dataset(dataset) + + +def test_random_seed(): + """Tests if the sampler is reproducible with the same seed.""" + num_dataset_samples, num_classes = 101, 3 + dataset = _utils.multiclass_dataset(num_dataset_samples, num_classes) + sampler1 = BalancedSampler(num_samples=10, seed=1) + sampler1_duplicate = BalancedSampler(num_samples=10, seed=1) + sampler2 = BalancedSampler(num_samples=10, seed=2) + sampler1.set_dataset(dataset) + sampler1_duplicate.set_dataset(dataset) + sampler2.set_dataset(dataset) + + assert list(sampler1) == list(sampler1_duplicate) + assert list(sampler1) != list(sampler2) + + +def test_invalid_targets(): + """Tests if the sampler raises an error unsupported target formats.""" + sampler = BalancedSampler(num_samples=10) + + # test multi-dimensional target + dataset = _utils.MockDataset([(None, torch.tensor([0, 1]), None)]) + with pytest.raises(ValueError, match="single & scalar target"): + sampler.set_dataset(dataset) + + # test empty target + dataset = _utils.MockDataset([(None, None, None)]) # type: ignore + with pytest.raises(ValueError, match="non-empty targets"): + sampler.set_dataset(dataset) diff --git a/tests/eva/core/models/wrappers/test_from_torchub.py b/tests/eva/core/models/wrappers/test_from_torchub.py new file mode 100644 index 000000000..bf2752347 --- /dev/null +++ b/tests/eva/core/models/wrappers/test_from_torchub.py @@ -0,0 +1,76 @@ +"""TorchHubModel tests.""" + +from typing import Any, Dict, Tuple + +import pytest +import torch + +from eva.core.models import wrappers + + +@pytest.mark.parametrize( + "model_name, repo_or_dir, out_indices, model_kwargs, " + "input_tensor, expected_len, expected_shape", + [ + ( + "dinov2_vits14", + "facebookresearch/dinov2:main", + None, + None, + torch.Tensor(2, 3, 224, 224), + None, + torch.Size([2, 384]), + ), + ( + "dinov2_vits14", + "facebookresearch/dinov2:main", + 1, + None, + torch.Tensor(2, 3, 224, 224), + 1, + torch.Size([2, 384, 16, 16]), + ), + ( + "dinov2_vits14", + "facebookresearch/dinov2:main", + 3, + None, + torch.Tensor(2, 3, 224, 224), + 3, + torch.Size([2, 384, 16, 16]), + ), + ], +) +def test_torchhub_model( + torchhub_model: wrappers.TorchHubModel, + input_tensor: torch.Tensor, + expected_len: int | None, + expected_shape: torch.Size, +) -> None: + """Tests the torch.hub model wrapper.""" + outputs = torchhub_model(input_tensor) + if torchhub_model._out_indices is not None: + assert isinstance(outputs, list) or isinstance(outputs, tuple) + assert len(outputs) == expected_len + assert isinstance(outputs[0], torch.Tensor) + assert outputs[0].shape == expected_shape + else: + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == expected_shape + + +@pytest.fixture(scope="function") +def torchhub_model( + model_name: str, + repo_or_dir: str, + out_indices: int | Tuple[int, ...] | None, + model_kwargs: Dict[str, Any] | None, +) -> wrappers.TorchHubModel: + """TorchHubModel fixture.""" + return wrappers.TorchHubModel( + model_name=model_name, + repo_or_dir=repo_or_dir, + out_indices=out_indices, + model_kwargs=model_kwargs, + pretrained=False, + ) diff --git a/tools/data/leaderboard.csv b/tools/data/leaderboard.csv index f30a3a33d..678d4bca1 100644 --- a/tools/data/leaderboard.csv +++ b/tools/data/leaderboard.csv @@ -1,12 +1,14 @@ bach,crc,mhist,patch_camelyon,camelyon16_small,panda_small,consep,monusac,model -0.783,0.94,0.773,0.901,0.767,0.625,0.63,0.537,dino_vits16_lunit -0.722,0.936,0.799,0.922,0.797,0.64,0.68,0.54,owkin_phikon -0.797,0.947,0.844,0.936,0.834,0.656,0.662,0.554,dino_vitl16_uni -0.758,0.958,0.839,0.942,0.82,0.645,0.69,0.588,bioptimus_h_optimus_0 -0.761,0.952,0.829,0.945,0.814,0.664,0.661,0.558,prov_gigapath -0.816,0.931,0.826,0.951,0.832,0.633,0.69,0.586,histai_hibou_l -0.802,0.938,0.829,0.904,0.789,0.618,0.611,0.549,dino_vits16_kaiko -0.829,0.952,0.814,0.885,0.814,0.654,0.688,0.599,dino_vits8_kaiko -0.835,0.958,0.835,0.907,0.816,0.621,0.636,0.551,dino_vitb16_kaiko -0.858,0.957,0.823,0.918,0.818,0.638,0.703,0.641,dino_vitb8_kaiko -0.864,0.936,0.828,0.908,0.812,0.65,0.679,0.59,dino_vitl14_kaiko +0.88,0.966,0.858,0.936,0.864,0.642,0.723,0.713,paige_virchow2 +0.758,0.958,0.839,0.942,0.82,0.645,0.726,0.725,bioptimus_h_optimus_0 +0.797,0.947,0.844,0.936,0.834,0.656,0.711,0.708,dino_vitl16_uni +0.761,0.952,0.829,0.945,0.814,0.664,0.709,0.724,prov_gigapath +0.816,0.931,0.826,0.951,0.832,0.633,0.725,0.728,histai_hibou_l +0.858,0.957,0.823,0.918,0.818,0.638,0.723,0.736,dino_vitb8_kaiko +0.864,0.936,0.828,0.908,0.812,0.65,0.716,0.727,dino_vitl14_kaiko +0.829,0.952,0.814,0.885,0.814,0.654,0.716,0.712,dino_vits8_kaiko +0.835,0.958,0.835,0.907,0.816,0.621,0.69,0.69,dino_vitb16_kaiko +0.722,0.936,0.799,0.922,0.797,0.64,0.708,0.709,owkin_phikon +0.802,0.938,0.829,0.904,0.789,0.618,0.683,0.694,dino_vits16_kaiko +0.727,0.939,0.775,0.893,0.808,0.635,0.711,0.689,owkin_phikon_v2 +0.783,0.94,0.773,0.901,0.767,0.625,0.68,0.69,dino_vits16_lunit \ No newline at end of file diff --git a/tools/generate_leaderboard_plots.py b/tools/generate_leaderboard_plots.py index 3468fa76b..5e1f02f90 100644 --- a/tools/generate_leaderboard_plots.py +++ b/tools/generate_leaderboard_plots.py @@ -28,8 +28,10 @@ "monusac": "GeneralizedDiceScore", } _fm_name_map = { - "dino_vits16_lunit": "Lunit - ViT-S16 | TCGA", - "owkin_phikon": "Owkin (Phikon) - iBOT ViT-B16 | TCGA", + "paige_virchow2": "Virchow2 - DINOv2 ViT-H14 | 3.1M slides", + "dino_vits16_lunit": "Lunit - DINO ViT-S16 | TCGA", + "owkin_phikon": "Phikon - iBOT ViT-B16 | TCGA", + "owkin_phikon_v2": "Phikon-v2 - DINOv2 ViT-L16 | PANCAN-XL", "dino_vitl16_uni": "UNI - DINOv2 ViT-L16 | Mass-100k", "bioptimus_h_optimus_0": "H-optimus-0 - ViT-G14 | 500k slides", "prov_gigapath": "Prov-GigaPath - DINOv2 ViT-G14 | 181k slides",