Skip to content

Commit

Permalink
feat: conditional VAE (#1157)
Browse files Browse the repository at this point in the history
* feat: conditional VAE testcase

* feat: overload Utils.vec for upcoming wrapper array changes

* feat: CVAE working

* fix: handle numerical stability issues

* docs: finish

* docs: uniform assignment of work

* fix: recon

* fix: correct image
  • Loading branch information
avik-pal authored Jan 3, 2025
1 parent 222e173 commit 6341b3d
Show file tree
Hide file tree
Showing 7 changed files with 339 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pages = [
"tutorials/intermediate/2_BayesianNN.md",
"tutorials/intermediate/3_HyperNet.md",
"tutorials/intermediate/4_PINN2DPDE.md",
"tutorials/intermediate/5_ConditionalVAE.md",
],
"Advanced" => [
"tutorials/advanced/1_GravitationalWaveForm.md"
Expand Down
4 changes: 4 additions & 0 deletions docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ export default defineConfig({
text: "Training a PINN on 2D PDE",
link: "/tutorials/intermediate/4_PINN2DPDE",
},
{
text: "Conditional VAE for MNIST using Reactant",
link: "/tutorials/intermediate/5_ConditionalVAE",
}
],
},
{
Expand Down
Binary file added docs/src/public/conditional_vae.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 6 additions & 0 deletions docs/src/tutorials/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ const intermediate = [
src: "../pinn_nested_ad.gif",
caption: "Training a PINN",
desc: "Train a PINN to solve 2D PDEs (using Nested AD)."
},
{
href: "intermediate/5_ConditionalVAE",
src: "../conditional_vae.png",
caption: "Conditional VAE for MNIST using Reactant",
desc: "Train a Conditional VAE to generate images from a latent space."
}
];
Expand Down
14 changes: 11 additions & 3 deletions docs/tutorials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const INTERMEDIATE_TUTORIALS = [
"BayesianNN/main.jl" => "CPU",
"HyperNet/main.jl" => "CUDA",
"PINN2DPDE/main.jl" => "CUDA",
"ConditionalVAE/main.jl" => "CUDA",
]
const ADVANCED_TUTORIALS = [
"GravitationalWaveForm/main.jl" => "CPU",
Expand Down Expand Up @@ -41,9 +42,16 @@ end

# Decide which tutorials this process builds. Under Buildkite parallelism the
# tutorials are distributed round-robin across jobs (uniform load); otherwise
# every tutorial is built in this single job.
const TUTORIALS_BUILDING = if BUILDKITE_PARALLEL_JOB_COUNT > 0
    id = parse(Int, ENV["BUILDKITE_PARALLEL_JOB"]) + 1 # Index starts from 0
    # One bucket per parallel job; entries start unassigned and are only
    # materialized when a tutorial lands in them.
    splits = Vector{Vector{eltype(TUTORIALS_WITH_BACKEND)}}(
        undef, BUILDKITE_PARALLEL_JOB_COUNT)
    for i in eachindex(TUTORIALS_WITH_BACKEND)
        # Round-robin: tutorial i goes to bucket mod1(i, job count).
        idx = mod1(i, BUILDKITE_PARALLEL_JOB_COUNT)
        if !isassigned(splits, idx)
            splits[idx] = Vector{eltype(TUTORIALS_WITH_BACKEND)}()
        end
        push!(splits[idx], TUTORIALS_WITH_BACKEND[i])
    end
    # Jobs beyond the number of (non-empty) buckets build nothing.
    (id > length(splits) || !isassigned(splits, id)) ? [] : splits[id]
else
    TUTORIALS_WITH_BACKEND
end
Expand Down
30 changes: 30 additions & 0 deletions examples/ConditionalVAE/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
[deps]
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
ConcreteStructs = "0.2.3"
DataAugmentation = "0.3.2"
Enzyme = "0.13.20"
ImageShow = "0.3.8"
Images = "0.26.1"
Lux = "1.4.1"
MLDatasets = "0.7.18"
MLUtils = "0.4.4"
OneHotArrays = "0.2.6"
Printf = "1.10"
Random = "1.10"
Reactant = "0.2.9"
287 changes: 287 additions & 0 deletions examples/ConditionalVAE/main.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
# # [Conditional VAE for MNIST using Reactant](@id Conditional-VAE-Tutorial)

# Convolutional variational autoencoder (CVAE) implementation in Lux using MNIST. This is
# based on the [CVAE implementation in MLX](https://github.com/ml-explore/mlx-examples/blob/main/cvae/).

using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, MLUtils, DataAugmentation,
ConcreteStructs, OneHotArrays, ImageShow, Images, Printf, Optimisers, Comonicon,
StableRNGs

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

# ## Model Definition

# First we will define the encoder. It maps the input to a normal distribution in latent
# space and samples a latent vector from that distribution.

function cvae_encoder(
        rng=Random.default_rng(); num_latent_dims::Int,
        image_shape::Dims{3}, max_num_filters::Int
)
    ## Three stride-2 convolutions downsample height and width by a factor of 8,
    ## hence the elementwise `.÷ 8` when computing the flattened feature size.
    flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters
    return @compact(;
        embed=Chain(
            Chain(
                Conv((3, 3), image_shape[3] => max_num_filters ÷ 4; stride=2, pad=1),
                BatchNorm(max_num_filters ÷ 4, leakyrelu)
            ),
            Chain(
                Conv((3, 3), max_num_filters ÷ 4 => max_num_filters ÷ 2; stride=2, pad=1),
                BatchNorm(max_num_filters ÷ 2, leakyrelu)
            ),
            Chain(
                Conv((3, 3), max_num_filters ÷ 2 => max_num_filters; stride=2, pad=1),
                BatchNorm(max_num_filters, leakyrelu)
            ),
            FlattenLayer()
        ),
        proj_mu=Dense(flattened_dim, num_latent_dims; init_bias=zeros32),
        proj_log_var=Dense(flattened_dim, num_latent_dims; init_bias=zeros32),
        rng) do x
        y = embed(x)

        μ = proj_mu(y)
        logσ² = proj_log_var(y)

        ## Clamp the log-variance before exponentiating for numerical stability.
        T = eltype(logσ²)
        logσ² = clamp.(logσ², -T(20.0f0), T(10.0f0))
        σ = exp.(logσ² .* T(0.5))

        ## Generate a tensor of random values from a normal distribution
        rng = Lux.replicate(rng)
        ϵ = randn_like(rng, σ)

        ## Reparameterization trick to backpropagate through sampling
        z = ϵ .* σ .+ μ

        @return z, μ, logσ²
    end
end

# Similarly we define the decoder.

function cvae_decoder(; num_latent_dims::Int, image_shape::Dims{3}, max_num_filters::Int)
    ## The decoder mirrors the encoder: it starts from an (H ÷ 8) × (W ÷ 8)
    ## feature map and upsamples by 2 three times, so the latent projection must
    ## produce prod(image_shape[1:2] .÷ 8) * max_num_filters entries.
    flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters
    return @compact(;
        linear=Dense(num_latent_dims, flattened_dim),
        upchain=Chain(
            Chain(
                Upsample(2),
                Conv((3, 3), max_num_filters => max_num_filters ÷ 2; stride=1, pad=1),
                BatchNorm(max_num_filters ÷ 2, leakyrelu)
            ),
            Chain(
                Upsample(2),
                Conv((3, 3), max_num_filters ÷ 2 => max_num_filters ÷ 4; stride=1, pad=1),
                BatchNorm(max_num_filters ÷ 4, leakyrelu)
            ),
            Chain(
                Upsample(2),
                ## Final sigmoid keeps pixel intensities in [0, 1].
                Conv((3, 3), max_num_filters ÷ 4 => image_shape[3],
                    sigmoid; stride=1, pad=1)
            )
        ),
        max_num_filters) do x
        ## Project the latent vector, then reshape into the coarsest feature map.
        y = linear(x)
        img = reshape(y, image_shape[1] ÷ 8, image_shape[2] ÷ 8, max_num_filters, :)
        @return upchain(img)
    end
end

# Container layer holding the encoder/decoder pair. `@concrete` parameterizes the
# fields so both sub-layers are stored with concrete types.
@concrete struct CVAE <: Lux.AbstractLuxContainerLayer{(:encoder, :decoder)}
    encoder <: Lux.AbstractLuxLayer  # maps image -> (z, μ, logσ²)
    decoder <: Lux.AbstractLuxLayer  # maps latent z -> reconstructed image
end

# Construct a CVAE whose encoder and decoder share one latent/shape configuration.
function CVAE(rng=Random.default_rng(); num_latent_dims::Int,
        image_shape::Dims{3}, max_num_filters::Int)
    enc = cvae_encoder(rng; num_latent_dims, image_shape, max_num_filters)
    dec = cvae_decoder(; num_latent_dims, image_shape, max_num_filters)
    return CVAE(enc, dec)
end

# Full forward pass: encode the input, decode the sampled latent, and thread the
# per-sublayer states back out.
function (cvae::CVAE)(x, ps, st)
    (z, μ, logσ²), new_enc_st = cvae.encoder(x, ps.encoder, st.encoder)
    reconstruction, new_dec_st = cvae.decoder(z, ps.decoder, st.decoder)
    return (reconstruction, μ, logσ²), (; encoder=new_enc_st, decoder=new_dec_st)
end

# Run only the encoder, returning the sampled latent `z`; the decoder state is
# passed through untouched.
function encode(cvae::CVAE, x, ps, st)
    (latent, _, _), enc_st = cvae.encoder(x, ps.encoder, st.encoder)
    return latent, (; encoder=enc_st, decoder=st.decoder)
end

# Run only the decoder on a latent vector; the encoder state is passed through
# untouched.
function decode(cvae::CVAE, z, ps, st)
    reconstruction, dec_st = cvae.decoder(z, ps.decoder, st.decoder)
    return reconstruction, (; decoder=dec_st, encoder=st.encoder)
end

# ## Loading MNIST

# Lazy dataset wrapper: applies `transform` to slices of `dataset` on indexing
# (used with MLUtils.DataLoader below).
@concrete struct TensorDataset
    dataset    # underlying dataset (e.g. MLDatasets.MNIST)
    transform  # DataAugmentation transform applied per image on access
end

Base.length(ds::TensorDataset) = length(ds.dataset)

# Batched indexing: convert the selected samples to images, apply the stored
# transform to each, and stack the results into a single array.
function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange})
    img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3))
    ## Per image: apply the transform, extract the item data, and strip the
    ## wrapper via `parent` — composed right-to-left with `∘`.
    return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img)
end

# Build a training DataLoader over MNIST, resized to `image_size`.
function loadmnist(batchsize, image_size::Dims{2})
    ## Load MNIST: Only 1500 for demonstration purposes
    limit = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    train_dataset = MNIST(; split=:train)
    test_dataset = MNIST(; split=:test)
    if limit !== nothing
        train_dataset = train_dataset[1:limit]
        test_dataset = test_dataset[1:limit]
    end

    ## Resize (keeping aspect ratio) and convert to a float tensor.
    train_transform = ScaleKeepAspect(image_size) |> ImageToTensor()
    trainset = TensorDataset(train_dataset, train_transform)

    return DataLoader(trainset; batchsize, shuffle=true, partial=false)
end

# ## Helper Functions

# Generate an Image Grid from a list of images

# Convert a (H, W, C, N) array into per-image Colorant matrices and delegate to
# the Vector method for tiling.
function create_image_grid(imgs::AbstractArray, grid_rows::Int, grid_cols::Int)
    total_images = grid_rows * grid_cols
    slices = eachslice(imgs[:, :, :, 1:total_images]; dims=4)
    converted = map(slices) do img
        ## Single channel -> grayscale, otherwise treat channels as RGB.
        if size(img, 3) == 1
            return colorview(Gray, view(img, :, :, 1))'
        else
            return colorview(RGB, img)'
        end
    end
    return create_image_grid(converted, grid_rows, grid_cols)
end

# Tile a vector of equally-sized images into a single `grid_rows` × `grid_cols`
# canvas, filled row-major (index 1 is the top-left cell).
function create_image_grid(images::Vector, grid_rows::Int, grid_cols::Int)
    ## The grid must be exactly filled by the supplied images.
    total = grid_rows * grid_cols
    @assert length(images) == total

    ## All images are assumed to share the dimensions of the first one.
    img_height, img_width = size(images[1])

    ## Canvas with the same element type as the tiles.
    canvas = similar(images[1], img_height * grid_rows, img_width * grid_cols)

    for (k, tile) in enumerate(images)
        ## Zero-based (row, col) of cell k in row-major order.
        r, c = divrem(k - 1, grid_cols)
        row_range = (r * img_height + 1):((r + 1) * img_height)
        col_range = (c * img_width + 1):((c + 1) * img_width)
        canvas[row_range, col_range] .= tile
    end

    return canvas
end

# VAE training loss: summed squared reconstruction error plus the analytic KL
# divergence between the Gaussian posterior and a unit-normal prior.
function loss_function(model, ps, st, X)
    (y, μ, logσ²), st = model(X, ps, st)
    ## KL(q(z|x) || N(0, I)) for a diagonal Gaussian, in closed form.
    kldiv_loss = -sum(1 .+ logσ² .- μ .^ 2 .- exp.(logσ²)) / 2
    ## Summed (not mean) squared error so the two terms are on the same scale.
    reconstruction_loss = MSELoss(; agg=sum)(y, X)
    return (reconstruction_loss + kldiv_loss, st,
        (; y, μ, logσ², reconstruction_loss, kldiv_loss))
end

# Sample `num_samples` latents from N(0, I), decode them (optionally with a
# Reactant-compiled decode), and tile the results into an 8-row image grid.
function generate_images(
        model, ps, st; num_samples::Int=128, num_latent_dims::Int, decode_compiled=nothing)
    ## Place the latents on the same device as the parameters/states.
    dev = get_device((ps, st))
    z = dev(randn(Float32, num_latent_dims, num_samples))
    if decode_compiled === nothing
        images, _ = decode(model, z, ps, Lux.testmode(st))
    else
        images, _ = decode_compiled(model, z, ps, Lux.testmode(st))
        ## Compiled outputs live on the accelerator; pull them back for plotting.
        images = cpu_device()(images)
    end
    return create_image_grid(images, 8, num_samples ÷ 8)
end

# Run a batch through the model in test mode and tile the reconstructions into
# an 8-row image grid.
function reconstruct_images(model, ps, st, X)
    (recon, _, _), _ = model(X, ps, Lux.testmode(st))
    return create_image_grid(cpu_device()(recon), 8, size(X, ndims(X)) ÷ 8)
end

# ## Training the Model

## Train the CVAE on MNIST and return the final visualization (reconstructions
## stacked above freshly generated samples).
function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_filters=64,
        seed=0, epochs=50, weight_decay=1e-5, learning_rate=1e-3, num_samples=batchsize)
    rng = Xoshiro()
    Random.seed!(rng, seed)

    cvae = CVAE(rng; num_latent_dims, image_shape=(image_size..., 1), max_num_filters)
    ps, st = Lux.setup(rng, cvae) |> xdev

    ## Pre-compile the decode path and the full forward pass with Reactant using
    ## dummy inputs of the shapes used during visualization.
    z = randn(Float32, num_latent_dims, num_samples) |> xdev
    decode_compiled = @compile decode(cvae, z, ps, Lux.testmode(st))
    x = randn(Float32, image_size..., 1, batchsize) |> xdev
    cvae_compiled = @compile cvae(x, ps, Lux.testmode(st))

    train_dataloader = loadmnist(batchsize, image_size) |> xdev

    opt = AdamW(; eta=learning_rate, lambda=weight_decay)

    train_state = Training.TrainState(cvae, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps)/1e6)

    ## Display intermediate grids inline when running under VS Code.
    is_vscode = isdefined(Main, :VSCodeServer)
    empty_row, model_img_full = nothing, nothing

    for epoch in 1:epochs
        loss_total = 0.0f0
        total_samples = 0
        total_time = 0.0

        for (i, X) in enumerate(train_dataloader)
            ## Time each step to report images/second throughput.
            throughput_tic = time()
            (_, loss, stats, train_state) = Training.single_train_step!(
                AutoEnzyme(), loss_function, X, train_state)
            throughput_toc = time()

            loss_total += loss
            total_samples += size(X, ndims(X))
            total_time += throughput_toc - throughput_tic

            ## Log every 250 iterations and at the end of the epoch.
            if i % 250 == 0 || i == length(train_dataloader)
                throughput = total_samples / total_time
                @printf "Epoch %d, Iter %d, Loss: %.7f, Throughput: %.6f im/s\n" epoch i loss throughput
            end
        end

        train_loss = loss_total / length(train_dataloader)
        throughput = total_samples / total_time
        @printf "Epoch %d, Train Loss: %.7f, Time: %.4fs, Throughput: %.6f im/s\n" epoch train_loss total_time throughput

        ## Visualize reconstructions (top) and generated samples (bottom),
        ## separated by a blank row; always produced on the final epoch.
        if is_vscode || epoch == epochs
            recon_images = reconstruct_images(
                cvae_compiled, train_state.parameters, train_state.states,
                first(train_dataloader))
            gen_images = generate_images(
                cvae, train_state.parameters, train_state.states;
                num_samples, num_latent_dims, decode_compiled)
            if empty_row === nothing
                empty_row = similar(gen_images, image_size[1], size(gen_images, 2))
                fill!(empty_row, 0)
            end
            model_img_full = vcat(recon_images, empty_row, gen_images)
            is_vscode && display(model_img_full)
        end
    end

    return model_img_full
end

main()

0 comments on commit 6341b3d

Please sign in to comment.