-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: conditional VAE testcase * feat: overload Utils.vec for upcoming wrapper array changes * feat: CVAE working * fix: handle numerical stability issues * docs: finish * docs: uniform assignment of work * fix: recon * fix: correct image
- Loading branch information
Showing
7 changed files
with
339 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
[deps]
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
# Statistics was missing from [deps] although the script does `using Statistics`.
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
# NOTE(review): Comonicon/Optimisers/StableRNGs/Statistics had no compat bounds;
# the bounds below are conservative major-version pins — confirm against the registry.
Comonicon = "1"
ConcreteStructs = "0.2.3"
DataAugmentation = "0.3.2"
Enzyme = "0.13.20"
ImageShow = "0.3.8"
Images = "0.26.1"
Lux = "1.4.1"
MLDatasets = "0.7.18"
MLUtils = "0.4.4"
OneHotArrays = "0.2.6"
Optimisers = "0.4"
Printf = "1.10"
Random = "1.10"
Reactant = "0.2.9"
StableRNGs = "1"
Statistics = "1.10"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,287 @@ | ||
# # [Conditional VAE for MNIST using Reactant](@id Conditional-VAE-Tutorial) | ||
|
||
# Convolutional variational autoencoder (CVAE) implementation in Lux, trained on MNIST.
# This is based on the [CVAE implementation in MLX](https://github.com/ml-explore/mlx-examples/blob/main/cvae/).
|
||
using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, MLUtils, DataAugmentation, | ||
ConcreteStructs, OneHotArrays, ImageShow, Images, Printf, Optimisers, Comonicon, | ||
StableRNGs | ||
|
||
# Devices: `xdev` places data/parameters on the Reactant (XLA) device — `force=true`
# errors out if no Reactant device is available; `cdev` moves data back to the CPU.
const xdev = reactant_device(; force=true)
const cdev = cpu_device()
|
||
# ## Model Definition | ||
|
||
# First we will define the encoder. It maps the input to a normal distribution in latent
# space and samples a latent vector from that distribution.
|
||
# Build the CVAE encoder: a three-stage conv stack that produces the mean μ and
# log-variance log σ² of the latent distribution, then draws a latent sample via
# the reparameterization trick.
function cvae_encoder(
        rng=Random.default_rng(); num_latent_dims::Int,
        image_shape::Dims{3}, max_num_filters::Int
)
    # Three stride-2 convolutions shrink each spatial dimension by a factor of 8.
    flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters
    return @compact(;
        embed=Chain(
            Chain(
                Conv((3, 3), image_shape[3] => max_num_filters ÷ 4; stride=2, pad=1),
                BatchNorm(max_num_filters ÷ 4, leakyrelu)
            ),
            Chain(
                Conv((3, 3), max_num_filters ÷ 4 => max_num_filters ÷ 2; stride=2, pad=1),
                BatchNorm(max_num_filters ÷ 2, leakyrelu)
            ),
            Chain(
                Conv((3, 3), max_num_filters ÷ 2 => max_num_filters; stride=2, pad=1),
                BatchNorm(max_num_filters, leakyrelu)
            ),
            FlattenLayer()
        ),
        proj_mu=Dense(flattened_dim, num_latent_dims; init_bias=zeros32),
        proj_log_var=Dense(flattened_dim, num_latent_dims; init_bias=zeros32),
        rng) do x
        features = embed(x)

        μ = proj_mu(features)
        logσ² = proj_log_var(features)

        # Clamp the log-variance before exponentiating for numerical stability.
        T = eltype(logσ²)
        logσ² = clamp.(logσ², -T(20.0f0), T(10.0f0))
        σ = exp.(T(0.5) .* logσ²)

        ## Draw standard-normal noise from a replicated RNG (keeps Lux state handling happy)
        rng = Lux.replicate(rng)
        ϵ = randn_like(rng, σ)

        ## Reparameterization trick so gradients can backpropagate through sampling
        z = ϵ .* σ .+ μ

        @return z, μ, logσ²
    end
end
|
||
# Similarly we define the decoder. | ||
|
||
# Build the CVAE decoder: project the latent vector to a small feature map, then
# upsample three times (×8 total) back to the original image resolution.
function cvae_decoder(; num_latent_dims::Int, image_shape::Dims{3}, max_num_filters::Int)
    flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters
    return @compact(;
        linear=Dense(num_latent_dims, flattened_dim),
        upchain=Chain(
            Chain(
                Upsample(2),
                Conv((3, 3), max_num_filters => max_num_filters ÷ 2; stride=1, pad=1),
                BatchNorm(max_num_filters ÷ 2, leakyrelu)
            ),
            Chain(
                Upsample(2),
                Conv((3, 3), max_num_filters ÷ 2 => max_num_filters ÷ 4; stride=1, pad=1),
                BatchNorm(max_num_filters ÷ 4, leakyrelu)
            ),
            Chain(
                Upsample(2),
                # sigmoid keeps the reconstruction in [0, 1] like the input pixels
                Conv((3, 3), max_num_filters ÷ 4 => image_shape[3],
                    sigmoid; stride=1, pad=1)
            )
        ),
        max_num_filters) do z
        projected = linear(z)
        # Reshape the flat projection into the smallest (H/8 × W/8) feature map.
        feature_map = reshape(
            projected, image_shape[1] ÷ 8, image_shape[2] ÷ 8, max_num_filters, :)
        @return upchain(feature_map)
    end
end
|
||
# Container layer bundling the encoder/decoder pair of the conditional VAE.
@concrete struct CVAE <: Lux.AbstractLuxContainerLayer{(:encoder, :decoder)}
    encoder <: Lux.AbstractLuxLayer
    decoder <: Lux.AbstractLuxLayer
end
|
||
# Convenience constructor: builds matching encoder/decoder halves from shared
# hyperparameters. Only the encoder consumes `rng` (for latent sampling).
function CVAE(rng=Random.default_rng(); num_latent_dims::Int,
        image_shape::Dims{3}, max_num_filters::Int)
    enc = cvae_encoder(rng; num_latent_dims, image_shape, max_num_filters)
    dec = cvae_decoder(; num_latent_dims, image_shape, max_num_filters)
    return CVAE(enc, dec)
end
|
||
# Full forward pass: encode `x` to a latent sample, decode it back, and return the
# reconstruction together with the latent statistics needed for the KL term.
function (cvae::CVAE)(x, ps, st)
    (z, μ, logσ²), updated_enc = cvae.encoder(x, ps.encoder, st.encoder)
    reconstruction, updated_dec = cvae.decoder(z, ps.decoder, st.decoder)
    return (reconstruction, μ, logσ²), (; encoder=updated_enc, decoder=updated_dec)
end
|
||
# Run only the encoder and return the sampled latent `z`; decoder state is passed
# through unchanged.
function encode(cvae::CVAE, x, ps, st)
    (z, _, _), updated_enc = cvae.encoder(x, ps.encoder, st.encoder)
    return z, (; encoder=updated_enc, st.decoder)
end
|
||
# Run only the decoder on a latent batch `z`; encoder state is passed through
# unchanged.
function decode(cvae::CVAE, z, ps, st)
    reconstruction, updated_dec = cvae.decoder(z, ps.decoder, st.decoder)
    return reconstruction, (; decoder=updated_dec, st.encoder)
end
|
||
# ## Loading MNIST | ||
|
||
# Lazy dataset wrapper pairing an MLDatasets dataset with a DataAugmentation
# transform; samples are transformed on access in `Base.getindex`.
@concrete struct TensorDataset
    dataset
    transform
end
|
||
Base.length(ds::TensorDataset) = length(ds.dataset) | ||
|
||
# Fetch a batch of samples: convert the raw dataset entries to images, apply the
# stored transform to each, and stack the resulting tensors along a new last dim.
function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange})
    raw = convert2image(ds.dataset, idxs)
    items = Image.(eachslice(raw; dims=3))
    to_tensor = parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform)
    return stack(to_tensor, items)
end
|
||
# Load the MNIST training split and return a shuffled `DataLoader` over it.
# Fix: the original also loaded and truncated the MNIST *test* split into a local
# that was never used (dead work plus an unnecessary dataset download).
function loadmnist(batchsize, image_size::Dims{2})
    ## Restrict to 1500 training images on CI to keep runtime short
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    train_dataset = MNIST(; split=:train)
    N === nothing || (train_dataset = train_dataset[1:N])

    # Resize (keeping aspect ratio) and convert each image to a tensor on access.
    train_transform = ScaleKeepAspect(image_size) |> ImageToTensor()
    trainset = TensorDataset(train_dataset, train_transform)
    # partial=false drops the last incomplete batch so all batches share a shape
    # (required for a single compiled Reactant trace).
    return DataLoader(trainset; batchsize, shuffle=true, partial=false)
end
|
||
# ## Helper Functions | ||
|
||
# Generate an Image Grid from a list of images | ||
|
||
# Convert a WHCN image batch into a vector of 2D color images (transposed to HW
# orientation) and delegate to the vector method for tiling.
function create_image_grid(imgs::AbstractArray, grid_rows::Int, grid_cols::Int)
    needed = grid_rows * grid_cols
    color_images = map(eachslice(imgs[:, :, :, 1:needed]; dims=4)) do img
        if size(img, 3) == 1
            return colorview(Gray, view(img, :, :, 1))'
        else
            return colorview(RGB, img)'
        end
    end
    return create_image_grid(color_images, grid_rows, grid_cols)
end
|
||
# Tile a vector of equally-sized 2D images into a `grid_rows × grid_cols` canvas,
# filling the grid row by row.
function create_image_grid(images::Vector, grid_rows::Int, grid_cols::Int)
    ## The grid must be exactly filled
    total = grid_rows * grid_cols
    @assert length(images) == total

    ## All images are assumed to share the size of the first one
    tile_h, tile_w = size(first(images))

    ## Allocate the output canvas with the same element type as the tiles
    canvas = similar(first(images), tile_h * grid_rows, tile_w * grid_cols)

    ## Copy each tile into its slot (row-major placement)
    for (k, img) in enumerate(images)
        r, c = divrem(k - 1, grid_cols)
        rows = (r * tile_h + 1):((r + 1) * tile_h)
        cols = (c * tile_w + 1):((c + 1) * tile_w)
        canvas[rows, cols] .= img
    end

    return canvas
end
|
||
# VAE training loss: summed reconstruction MSE plus the analytic KL divergence of
# N(μ, σ²) from the standard normal prior. Also returns the stats for logging.
function loss_function(model, ps, st, X)
    (y, μ, logσ²), st_updated = model(X, ps, st)
    reconstruction_loss = MSELoss(; agg=sum)(y, X)
    # Closed-form KL(N(μ, σ²) ‖ N(0, 1)), summed over latent dims and batch.
    kldiv_loss = -sum(1 .+ logσ² .- μ .^ 2 .- exp.(logσ²)) / 2
    total = reconstruction_loss + kldiv_loss
    return total, st_updated, (; y, μ, logσ², reconstruction_loss, kldiv_loss)
end
|
||
# Sample `num_samples` latents from N(0, 1), decode them (optionally via a
# pre-compiled Reactant decode), and tile the generated images 8 per column group.
function generate_images(
        model, ps, st; num_samples::Int=128, num_latent_dims::Int, decode_compiled=nothing)
    dev = get_device((ps, st))
    z = dev(randn(Float32, num_latent_dims, num_samples))
    if decode_compiled === nothing
        images, _ = decode(model, z, ps, Lux.testmode(st))
    else
        images, _ = decode_compiled(model, z, ps, Lux.testmode(st))
        # Compiled path returns device arrays; move to CPU for image conversion.
        images = cpu_device()(images)
    end
    return create_image_grid(images, 8, num_samples ÷ 8)
end
|
||
# Push a batch through the (possibly compiled) CVAE in test mode and tile the
# reconstructions into an image grid, 8 rows of batch÷8 columns.
function reconstruct_images(model, ps, st, X)
    (reconstruction, _, _), _ = model(X, ps, Lux.testmode(st))
    return create_image_grid(cpu_device()(reconstruction), 8, size(X, ndims(X)) ÷ 8)
end
|
||
# ## Training the Model | ||
|
||
# Full training pipeline: build the CVAE, compile decode/forward with Reactant,
# train with AdamW + Enzyme, and periodically render reconstruction/sample grids.
# Returns the final stacked image grid (reconstructions | blank row | samples).
function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_filters=64,
        seed=0, epochs=50, weight_decay=1e-5, learning_rate=1e-3, num_samples=batchsize)
    rng = Xoshiro()
    Random.seed!(rng, seed)

    cvae = CVAE(rng; num_latent_dims, image_shape=(image_size..., 1), max_num_filters)
    ps, st = Lux.setup(rng, cvae) |> xdev

    # Trace/compile the decode path and the full forward pass once, with dummy
    # inputs of the shapes used throughout training.
    z = randn(Float32, num_latent_dims, num_samples) |> xdev
    decode_compiled = @compile decode(cvae, z, ps, Lux.testmode(st))
    x = randn(Float32, image_size..., 1, batchsize) |> xdev
    cvae_compiled = @compile cvae(x, ps, Lux.testmode(st))

    train_dataloader = loadmnist(batchsize, image_size) |> xdev

    opt = AdamW(; eta=learning_rate, lambda=weight_decay)

    train_state = Training.TrainState(cvae, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps)/1e6)

    # Only display intermediate grids when running inside VS Code.
    is_vscode = isdefined(Main, :VSCodeServer)
    empty_row, model_img_full = nothing, nothing

    for epoch in 1:epochs
        epoch_loss = 0.0f0
        seen = 0
        elapsed = 0.0

        for (step, X) in enumerate(train_dataloader)
            t_start = time()
            (_, loss, stats, train_state) = Training.single_train_step!(
                AutoEnzyme(), loss_function, X, train_state)
            t_stop = time()

            epoch_loss += loss
            seen += size(X, ndims(X))
            elapsed += t_stop - t_start

            # Periodic progress logging (every 250 steps and at epoch end).
            if step % 250 == 0 || step == length(train_dataloader)
                throughput = seen / elapsed
                @printf "Epoch %d, Iter %d, Loss: %.7f, Throughput: %.6f im/s\n" epoch step loss throughput
            end
        end

        train_loss = epoch_loss / length(train_dataloader)
        throughput = seen / elapsed
        @printf "Epoch %d, Train Loss: %.7f, Time: %.4fs, Throughput: %.6f im/s\n" epoch train_loss elapsed throughput

        if is_vscode || epoch == epochs
            recon_images = reconstruct_images(
                cvae_compiled, train_state.parameters, train_state.states,
                first(train_dataloader))
            gen_images = generate_images(
                cvae, train_state.parameters, train_state.states;
                num_samples, num_latent_dims, decode_compiled)
            # Lazily allocate a blank separator row matching the grid width.
            if empty_row === nothing
                empty_row = similar(gen_images, image_size[1], size(gen_images, 2))
                fill!(empty_row, 0)
            end
            model_img_full = vcat(recon_images, empty_row, gen_images)
            is_vscode && display(model_img_full)
        end
    end

    return model_img_full
end
|
||
main() |