Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Diffusion bundle support monai service #541

Draft
wants to merge 36 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
aea7312
change to Dataset as in brain segmentation bundle, add support for am…
Can-Zhao Dec 13, 2023
bc2be5a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2023
54bf0d2
typo
Can-Zhao Dec 13, 2023
94ad31e
Merge branch 'diffusion_monai_service' of github.com:Can-Zhao/model-z…
Can-Zhao Dec 13, 2023
9d2d2f9
update train-diffusion.json
Can-Zhao Dec 13, 2023
d51159e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2023
18033f7
update train-diffusion.json
Can-Zhao Dec 13, 2023
466c3c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2023
ad7a3a5
update train_autoencoder.json and inference.json
Can-Zhao Dec 13, 2023
dabb18f
Merge branch 'diffusion_monai_service' of github.com:Can-Zhao/model-z…
Can-Zhao Dec 13, 2023
e8b6a02
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2023
4c8646b
typo
Can-Zhao Dec 13, 2023
ce9bdcd
typo
Can-Zhao Dec 13, 2023
c3bd891
Merge branch 'diffusion_monai_service' of github.com:Can-Zhao/model-z…
Can-Zhao Dec 13, 2023
a6d7d85
typo
Can-Zhao Dec 13, 2023
c94a50a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2023
2d4f8a7
typo
Can-Zhao Dec 13, 2023
fa459d5
Merge branch 'diffusion_monai_service' of github.com:Can-Zhao/model-z…
Can-Zhao Dec 13, 2023
3536210
typo
Can-Zhao Dec 13, 2023
29fc950
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2023
e6af5f8
typo
Can-Zhao Dec 13, 2023
a10e07e
Merge branch 'diffusion_monai_service' of github.com:Can-Zhao/model-z…
Can-Zhao Dec 13, 2023
703e92d
update readme
Can-Zhao Dec 13, 2023
bdbfa10
update loss weights
Can-Zhao Dec 13, 2023
c40bd4f
update readme
Can-Zhao Dec 13, 2023
95e4ddf
flake
Can-Zhao Dec 13, 2023
a24788c
update readme
Can-Zhao Dec 15, 2023
46f9715
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 15, 2023
424cbe2
update readme
Can-Zhao Dec 15, 2023
a52db12
update inference_autoencoder.json
Can-Zhao Dec 15, 2023
ace8223
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 15, 2023
3618cc8
update readme
Can-Zhao Dec 15, 2023
09427c3
add cache
Can-Zhao Dec 15, 2023
026152a
maximize batch size
Can-Zhao Dec 15, 2023
3ddf7d2
reduce epoch num
Can-Zhao Dec 15, 2023
4839174
update readme
Can-Zhao Dec 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions models/brats_mri_generative_diffusion/configs/inference.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
"output_postfix": "$datetime.now().strftime('sample_%Y%m%d_%H%M%S')",
"spatial_dims": 3,
"image_channels": 1,
"latent_channels": 8,
"latent_channels": 4,
"latent_shape": [
8,
36,
44,
28
"@latent_channels",
48,
48,
32
],
"autoencoder_def": {
"_target_": "generative.networks.nets.AutoencoderKL",
Expand All @@ -39,15 +39,17 @@
false
],
"with_encoder_nonlocal_attn": false,
"with_decoder_nonlocal_attn": false
"with_decoder_nonlocal_attn": false,
"use_checkpointing": true,
"use_convtranspose": false
},
"network_def": {
"_target_": "generative.networks.nets.DiffusionModelUNet",
"spatial_dims": "@spatial_dims",
"in_channels": "@latent_channels",
"out_channels": "@latent_channels",
"num_channels": [
256,
128,
256,
512
],
Expand All @@ -58,10 +60,11 @@
],
"num_head_channels": [
0,
64,
64
32,
32
],
"num_res_blocks": 2
"num_res_blocks": 2,
"use_flash_attention": true
},
"load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
"load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path))",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
"imports": [
"$import torch",
"$from datetime import datetime",
"$from pathlib import Path"
"$from pathlib import Path",
"$import generative"
],
"bundle_root": ".",
"model_dir": "$@bundle_root + '/models'",
"dataset_dir": "/workspace/data/medical",
"data_list_file_path": "$@bundle_root + '/configs/datalist.json'",
"dataset_dir": "/datasets/brats18",
"test_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='testing', base_dir=@dataset_dir)",
"output_dir": "$@bundle_root + '/output'",
"create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)",
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
Expand All @@ -20,11 +23,11 @@
],
"spatial_dims": 3,
"image_channels": 1,
"latent_channels": 8,
"latent_channels": 4,
"infer_patch_size": [
144,
176,
112
192,
192,
128
],
"autoencoder_def": {
"_target_": "generative.networks.nets.AutoencoderKL",
Expand All @@ -46,7 +49,9 @@
false
],
"with_encoder_nonlocal_attn": false,
"with_decoder_nonlocal_attn": false
"with_decoder_nonlocal_attn": false,
"use_checkpointing": true,
"use_convtranspose": false
},
"load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
"load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path))",
Expand Down Expand Up @@ -108,13 +113,8 @@
"transforms": "$@preprocessing_transforms + @crop_transforms + @final_transforms"
},
"dataset": {
"_target_": "monai.apps.DecathlonDataset",
"root_dir": "@dataset_dir",
"task": "Task01_BrainTumour",
"section": "validation",
"cache_rate": 0.0,
"num_workers": 8,
"download": false,
"_target_": "Dataset",
"data": "@test_datalist",
"transform": "@preprocessing"
},
"dataloader": {
Expand Down
146 changes: 115 additions & 31 deletions models/brats_mri_generative_diffusion/configs/train_autoencoder.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,35 @@
"imports": [
"$import functools",
"$import glob",
"$import scripts"
"$import scripts",
"$import generative"
],
"bundle_root": ".",
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
"ckpt_dir": "$@bundle_root + '/models'",
"tf_dir": "$@bundle_root + '/eval'",
"dataset_dir": "/workspace/data/medical",
"data_list_file_path": "$@bundle_root + '/configs/datalist.json'",
"dataset_dir": "/datasets/brats18",
"train_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='training', base_dir=@dataset_dir)",
"val_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='validation', base_dir=@dataset_dir)",
"pretrained": false,
"perceptual_loss_model_weights_path": null,
"train_batch_size": 2,
"lr": 1e-05,
"train_batch_size": 4,
"val_batch_size": 3,
"epochs": 3000,
"val_interval": 10,
"lr": 5e-05,
"amp": true,
"train_patch_size": [
112,
128,
112,
80
],
"val_patch_size": [
192,
192,
128
],
"channel": 0,
"spacing": [
1.1,
Expand All @@ -26,7 +39,7 @@
],
"spatial_dims": 3,
"image_channels": 1,
"latent_channels": 8,
"latent_channels": 4,
"discriminator_def": {
"_target_": "generative.networks.nets.PatchDiscriminator",
"spatial_dims": "@spatial_dims",
Expand Down Expand Up @@ -56,7 +69,9 @@
false
],
"with_encoder_nonlocal_attn": false,
"with_decoder_nonlocal_attn": false
"with_decoder_nonlocal_attn": false,
"use_checkpointing": true,
"use_convtranspose": false
},
"perceptual_loss_def": {
"_target_": "generative.losses.PerceptualLoss",
Expand Down Expand Up @@ -114,9 +129,12 @@
"keys": "image",
"pixdim": "@spacing",
"mode": "bilinear"
}
],
"final_transforms": [
},
{
"_target_": "CenterSpatialCropd",
"keys": "image",
"roi_size": "@val_patch_size"
},
{
"_target_": "ScaleIntensityRangePercentilesd",
"keys": "image",
Expand All @@ -137,17 +155,13 @@
],
"preprocessing": {
"_target_": "Compose",
"transforms": "$@preprocessing_transforms + @train#crop_transforms + @final_transforms"
"transforms": "$@preprocessing_transforms + @train#crop_transforms"
},
"dataset": {
"_target_": "monai.apps.DecathlonDataset",
"root_dir": "@dataset_dir",
"task": "Task01_BrainTumour",
"section": "training",
"cache_rate": 1.0,
"num_workers": 8,
"download": false,
"transform": "@train#preprocessing"
"_target_": "CacheDataset",
"data": "@train_datalist",
"transform": "@train#preprocessing",
"cache_rate": 1.0
},
"dataloader": {
"_target_": "DataLoader",
Expand All @@ -158,32 +172,33 @@
},
"handlers": [
{
"_target_": "CheckpointSaver",
"save_dir": "@ckpt_dir",
"save_dict": {
"model": "@gnetwork"
},
"save_interval": 0,
"save_final": true,
"_target_": "ValidationHandler",
"validator": "@validate#evaluator",
"epoch_level": true,
"final_filename": "model_autoencoder.pt"
"interval": "@val_interval"
},
{
"_target_": "StatsHandler",
"tag_name": "train_loss",
"output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]"
"output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]+monai.handlers.from_engine(['d_loss'], first=True)(x)[0]"
},
{
"_target_": "TensorBoardStatsHandler",
"log_dir": "@tf_dir",
"tag_name": "train_loss",
"tag_name": "train_generator_loss",
"output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]"
},
{
"_target_": "TensorBoardStatsHandler",
"log_dir": "@tf_dir",
"tag_name": "train_discriminator_loss",
"output_transform": "$lambda x: monai.handlers.from_engine(['d_loss'], first=True)(x)[0]"
}
],
"trainer": {
"_target_": "scripts.ldm_trainer.VaeGanTrainer",
"device": "@device",
"max_epochs": 1500,
"max_epochs": "@epochs",
"train_data_loader": "@train#dataloader",
"g_network": "@gnetwork",
"g_optimizer": "@goptimizer",
Expand All @@ -195,7 +210,76 @@
"g_update_latents": true,
"latent_shape": "@latent_channels",
"key_train_metric": "$None",
"train_handlers": "@train#handlers"
"train_handlers": "@train#handlers",
"amp": "@amp"
}
},
"validate": {
"preprocessing": {
"_target_": "Compose",
"transforms": "$@preprocessing_transforms"
},
"dataset": {
"_target_": "CacheDataset",
"data": "@val_datalist",
"transform": "@validate#preprocessing",
"cache_rate": 1.0
},
"dataloader": {
"_target_": "DataLoader",
"dataset": "@validate#dataset",
"batch_size": "@val_batch_size",
"shuffle": false,
"num_workers": 4
},
"postprocessing": {
"_target_": "Compose",
"transforms": [
{
"_target_": "Lambdad",
"keys": "pred",
"func": "$lambda x: x[0]"
}
]
},
"handlers": [
{
"_target_": "StatsHandler",
"iteration_log": false
},
{
"_target_": "TensorBoardStatsHandler",
"log_dir": "@tf_dir",
"iteration_log": false
},
{
"_target_": "CheckpointSaver",
"save_dir": "@ckpt_dir",
"save_dict": {
"model": "@gnetwork"
},
"save_interval": 0,
"save_final": true,
"epoch_level": true,
"final_filename": "model_autoencoder.pt"
}
],
"key_metric": {
"val_mean_l2": {
"_target_": "MeanSquaredError",
"output_transform": "$monai.handlers.from_engine(['pred', 'image'])"
}
},
"evaluator": {
"_target_": "SupervisedEvaluator",
"device": "@device",
"val_data_loader": "@validate#dataloader",
"network": "@gnetwork",
"postprocessing": "@validate#postprocessing",
"key_val_metric": "$@validate#key_metric",
"metric_cmp_fn": "$lambda current_metric,prev_best: current_metric < prev_best",
"val_handlers": "@validate#handlers",
"amp": "@amp"
}
},
"initialize": [
Expand Down
Loading
Loading