diff --git a/README.md b/README.md index 93fb016..13314b7 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ TODO LIST: - [ ] Add synthetic sentences based on other source of information - [ ] Maybe use LLM to augment the reports - [ ] Add warmup time for the diffusion model +- [ ] Include images from ChestX-ray14 https://nihcc.app.box.com/v/ChestXray-NIHCC/folder/36938765345 ## C1 diff --git a/configs/stage1/aekl_v0.yaml b/configs/stage1/aekl_v0.yaml index 2b86abc..1ca280c 100644 --- a/configs/stage1/aekl_v0.yaml +++ b/configs/stage1/aekl_v0.yaml @@ -8,7 +8,7 @@ stage1: spatial_dims: 2 in_channels: 1 out_channels: 1 - num_channels: [64, 128, 128, 128] + num_channels: [64, 128, 128, 256] latent_channels: 3 num_res_blocks: 2 attention_levels: [False, False, False, False] @@ -22,11 +22,6 @@ discriminator: num_layers_d: 3 in_channels: 1 out_channels: 1 - kernel_size: 4 - activation: "LEAKYRELU" - norm: "BATCH" - bias: False - padding: 1 perceptual_network: params: diff --git a/src/python/testing/generate_sample_local.py b/src/python/testing/generate_sample_local.py index f1a4fb7..9f7f03b 100644 --- a/src/python/testing/generate_sample_local.py +++ b/src/python/testing/generate_sample_local.py @@ -103,3 +103,13 @@ plt.imshow(sample.cpu()[0, 0, :, :], cmap="gray", vmin=0, vmax=1) plt.show() + + +torch.save( + diffusion.state_dict(), + "/media/walter/Storage/Projects/GenerativeModels/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/models/diffusion_model.pth", +) +torch.save( + stage1.state_dict(), + "/media/walter/Storage/Projects/GenerativeModels/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/models/autoencoder.pth", +) diff --git a/src/python/training/training_functions_old_disc.py b/src/python/training/training_functions_old_disc.py index d3b9997..b4c6ed0 100644 --- a/src/python/training/training_functions_old_disc.py +++ b/src/python/training/training_functions_old_disc.py @@ -4,28 +4,13 @@ import torch import torch.nn as nn import torch.nn.functional as F -from pynvml.smi import nvidia_smi from tensorboardX import SummaryWriter from torch.cuda.amp import GradScaler, autocast from tqdm import tqdm +from training_functions import get_lr, print_gpu_memory_report from util import log_reconstructions -def get_lr(optimizer): - for param_group in optimizer.param_groups: - return param_group["lr"] - - -def print_gpu_memory_report(): - if torch.cuda.is_available(): - nvsmi = nvidia_smi.getInstance() - data = nvsmi.DeviceQuery("memory.used, memory.total, utilization.gpu")["gpu"] - print("Memory report") - for i, data_by_rank in enumerate(data): - mem_report = data_by_rank["fb_memory_usage"] - print(f"gpu:{i} mem(%) {int(mem_report['used'] * 100.0 / mem_report['total'])}") - - # ---------------------------------------------------------------------------------------------------------------------- # AUTOENCODER KL # ---------------------------------------------------------------------------------------------------------------------- diff --git a/src/python/training/training_functions_original_disc.py b/src/python/training/training_functions_original_disc.py index e2b4205..87abdf8 100644 --- a/src/python/training/training_functions_original_disc.py +++ b/src/python/training/training_functions_original_disc.py @@ -4,28 +4,13 @@ import torch import torch.nn as nn import torch.nn.functional as F -from pynvml.smi import nvidia_smi from tensorboardX import SummaryWriter from torch.cuda.amp import GradScaler, autocast from tqdm import tqdm +from training_functions import get_lr, print_gpu_memory_report from util import log_reconstructions -def get_lr(optimizer): - for param_group in optimizer.param_groups: - return param_group["lr"] - - -def print_gpu_memory_report(): - if torch.cuda.is_available(): - nvsmi = nvidia_smi.getInstance() - data = nvsmi.DeviceQuery("memory.used, memory.total, utilization.gpu")["gpu"] - print("Memory report") - for i, data_by_rank in enumerate(data): - mem_report = data_by_rank["fb_memory_usage"] - print(f"gpu:{i} mem(%) {int(mem_report['used'] * 100.0 / mem_report['total'])}") - - def hinge_d_loss(logits_real, logits_fake): loss_real = torch.mean(F.relu(1.0 - logits_real)) loss_fake = torch.mean(F.relu(1.0 + logits_fake))