diff --git a/implementations/began/began.py b/implementations/began/began.py
index d9d420d8..d6243217 100644
--- a/implementations/began/began.py
+++ b/implementations/began/began.py
@@ -2,6 +2,7 @@
 import os
 import numpy as np
 import math
+import time
 
 import torchvision.transforms as transforms
 from torchvision.utils import save_image
@@ -27,12 +28,17 @@
 parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
 parser.add_argument("--channels", type=int, default=1, help="number of image channels")
 parser.add_argument("--sample_interval", type=int, default=400, help="number of image channels")
+parser.add_argument('--inference', action='store_true', default=False)
+parser.add_argument('--precision', default='float32', help='Precision, "float32" or "bfloat16"')
+parser.add_argument('--channels_last', type=int, default=1, help='use channels last format')
+parser.add_argument('--num-iterations', default=100, type=int)
 opt = parser.parse_args()
 print(opt)
 
 img_shape = (opt.channels, opt.img_size, opt.img_size)
 
 cuda = True if torch.cuda.is_available() else False
+Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
 
 
 def weights_init_normal(m):
@@ -68,6 +74,8 @@ def __init__(self):
     def forward(self, noise):
         out = self.l1(noise)
         out = out.view(out.shape[0], 128, self.init_size, self.init_size)
+        if opt.channels_last:
+            out = out.to(memory_format=torch.channels_last)
         img = self.conv_blocks(out)
         return img
 
@@ -94,116 +102,167 @@ def __init__(self):
     def forward(self, img):
         out = self.down(img)
+        if opt.channels_last:
+            out = out.contiguous()
         out = self.fc(out.view(out.size(0), -1))
-        out = self.up(out.view(out.size(0), 64, self.down_size, self.down_size))
+        out = out.view(out.size(0), 64, self.down_size, self.down_size)
+        if opt.channels_last:
+            out = out.to(memory_format=torch.channels_last)
+        out = self.up(out)
         return out
 
 
+def main():
+    # Initialize generator and discriminator
+    generator = Generator()
+    discriminator = Discriminator()
+
+    if cuda:
+        generator.cuda()
+        discriminator.cuda()
+    else:
+        generator.cpu()
+        discriminator.cpu()
+
+    # Initialize weights
+    generator.apply(weights_init_normal)
+    discriminator.apply(weights_init_normal)
+    device = torch.device('cuda') if cuda else torch.device('cpu')
+    if opt.inference:
+        print("----------------Generation---------------")
+        if opt.precision == "bfloat16":
+            cm = torch.cuda.amp.autocast if cuda else torch.cpu.amp.autocast
+            with cm(dtype=torch.bfloat16):  # CUDA autocast defaults to float16, so request bf16 explicitly
+                generate(generator, device=device)
+        else:
+            generate(generator, device=device)
+    else:
+        print("-------------------Train-----------------")
+        train(generator, discriminator)
+
+
+def generate(netG, device):
+    fixed_noise = Variable(Tensor(np.random.normal(0, 1, (10 ** 2, opt.latent_dim))))
+    if opt.channels_last:
+        netG_oob = netG
+        try:
+            netG_oob = netG_oob.to(memory_format=torch.channels_last)
+            print("[INFO] Use NHWC model")
+        except Exception:
+            print("[WARN] Input NHWC failed! Use normal model")
Use normal model") + netG = netG_oob + else: + fixed_noise = fixed_noise.to(device=device) + netG.eval() + + total_iters = opt.num_iterations + with torch.no_grad(): + tic = time.time() + for i in range(total_iters): + fake = netG(fixed_noise) + toc = time.time() - tic + print("Throughput: %.2f image/sec, batchsize: %d, latency = %.2f ms"%((opt.num_iterations*opt.batch_size)/toc, opt.batch_size, 1000*toc/opt.num_iterations)) -# Initialize generator and discriminator -generator = Generator() -discriminator = Discriminator() - -if cuda: - generator.cuda() - discriminator.cuda() - -# Initialize weights -generator.apply(weights_init_normal) -discriminator.apply(weights_init_normal) - -# Configure data loader -os.makedirs("../../data/mnist", exist_ok=True) -dataloader = torch.utils.data.DataLoader( - datasets.MNIST( - "../../data/mnist", - train=True, - download=True, - transform=transforms.Compose( - [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] - ), - ), - batch_size=opt.batch_size, - shuffle=True, -) - -# Optimizers -optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) -optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) - -Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor # ---------- # Training # ---------- -# BEGAN hyper parameters -gamma = 0.75 -lambda_k = 0.001 -k = 0.0 - -for epoch in range(opt.n_epochs): - for i, (imgs, _) in enumerate(dataloader): - - # Configure input - real_imgs = Variable(imgs.type(Tensor)) - - # ----------------- - # Train Generator - # ----------------- - - optimizer_G.zero_grad() - - # Sample noise as generator input - z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) - - # Generate a batch of images - gen_imgs = generator(z) - - # Loss measures generator's ability to fool the discriminator - g_loss = torch.mean(torch.abs(discriminator(gen_imgs) - gen_imgs)) - - g_loss.backward() - optimizer_G.step() - - # --------------------- - # Train Discriminator - # --------------------- - - optimizer_D.zero_grad() - - # Measure discriminator's ability to classify real from generated samples - d_real = discriminator(real_imgs) - d_fake = discriminator(gen_imgs.detach()) - - d_loss_real = torch.mean(torch.abs(d_real - real_imgs)) - d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach())) - d_loss = d_loss_real - k * d_loss_fake - - d_loss.backward() - optimizer_D.step() - - # ---------------- - # Update weights - # ---------------- - - diff = torch.mean(gamma * d_loss_real - d_loss_fake) - - # Update weight term for fake samples - k = k + lambda_k * diff.item() - k = min(max(k, 0), 1) # Constraint to interval [0, 1] - - # Update convergence metric - M = (d_loss_real + torch.abs(diff)).data[0] - - # -------------- - # Log Progress - # -------------- - - print( - "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] -- M: %f, k: %f" - % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), M, k) - ) - - batches_done = epoch * len(dataloader) + i - if batches_done % opt.sample_interval == 0: - save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) +def train(netG, netD): + # BEGAN hyper parameters + gamma = 0.75 + lambda_k = 0.001 + k = 0.0 + + # Configure data loader + os.makedirs("../../data/mnist", exist_ok=True) + dataloader = torch.utils.data.DataLoader( + datasets.MNIST( + "../../data/mnist", + train=True, + download=True, + 
+            transform=transforms.Compose(
+                [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
+            ),
+        ),
+        batch_size=opt.batch_size,
+        shuffle=True,
+    )
+    # Optimizers
+    optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
+    optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
+
+    for epoch in range(opt.n_epochs):
+        for i, (imgs, _) in enumerate(dataloader):
+            if opt.channels_last:
+                imgs_oob = imgs
+                try:
+                    imgs_oob = imgs_oob.to(memory_format=torch.channels_last)
+                    print("[INFO] Use NHWC input")
+                except Exception:
+                    print("[WARN] Input NHWC failed! Use normal input")
+                imgs = imgs_oob
+            # Configure input
+            real_imgs = Variable(imgs.type(Tensor))
+
+            # -----------------
+            #  Train Generator
+            # -----------------
+
+            optimizer_G.zero_grad()
+
+            # Sample noise as generator input
+            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
+
+            # Generate a batch of images
+            gen_imgs = netG(z)
+
+            # Loss measures generator's ability to fool the discriminator
+            g_loss = torch.mean(torch.abs(netD(gen_imgs) - gen_imgs))
+
+            g_loss.backward()
+            optimizer_G.step()
+
+            # ---------------------
+            #  Train Discriminator
+            # ---------------------
+
+            optimizer_D.zero_grad()
+
+            # Measure discriminator's ability to classify real from generated samples
+            d_real = netD(real_imgs)
+            d_fake = netD(gen_imgs.detach())
+
+            d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
+            d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
+            d_loss = d_loss_real - k * d_loss_fake
+
+            d_loss.backward()
+            optimizer_D.step()
+
+            # ----------------
+            # Update weights
+            # ----------------
+
+            diff = torch.mean(gamma * d_loss_real - d_loss_fake)
+
+            # Update weight term for fake samples
+            k = k + lambda_k * diff.item()
+            k = min(max(k, 0), 1)  # Constraint to interval [0, 1]
+
+            # Update convergence metric
+            M = (d_loss_real + torch.abs(diff)).data.item()
+
+            # --------------
+            # Log Progress
+            # --------------
+
+            print(
+                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] -- M: %f, k: %f"
+                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), M, k)
+            )
+
+            batches_done = epoch * len(dataloader) + i
+            if batches_done % opt.sample_interval == 0:
+                save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
+
+
+if __name__ == '__main__':
+    main()
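For reviewers unfamiliar with the two optimizations this patch wires together, the snippet below is a minimal, self-contained sketch of the same channels_last + bfloat16-autocast inference pattern; it is not part of the patch, and the `model`/`x` names are purely illustrative. It should run on CPU with any reasonably recent PyTorch:

    import torch

    # any small conv net stands in for the generator here
    model = torch.nn.Conv2d(3, 8, kernel_size=3).eval()
    x = torch.randn(16, 3, 32, 32)

    # channels_last keeps the logical NCHW shape but lays memory out as NHWC,
    # which oneDNN/cuDNN convolution kernels can execute faster
    model = model.to(memory_format=torch.channels_last)
    x = x.to(memory_format=torch.channels_last)

    # autocast runs the convolution in bfloat16; passing dtype explicitly makes
    # the same code mean bf16 on CUDA too, where autocast defaults to float16
    with torch.no_grad(), torch.cpu.amp.autocast(dtype=torch.bfloat16):
        y = model(x)

    print(y.dtype)  # torch.bfloat16
    print(y.is_contiguous(memory_format=torch.channels_last))  # True

The patch applies exactly this pairing to `netG` in `generate()`: the model and its 4-D activations are converted to channels_last, and `--precision bfloat16` wraps the generation loop in autocast, e.g. `python began.py --inference --precision bfloat16 --channels_last 1 --num-iterations 100`.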