diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 4007bd61c..6a8f6c198 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -103,8 +103,8 @@ def apply(x): def condition_layer_film(input_tensor, control_vector, filters): # Transform control into gamma and beta - gamma = layers.Dense(filters, activation="linear")(control_vector) - beta = layers.Dense(filters, activation="linear")(control_vector) + gamma = layers.Dense(filters*2, activation="linear")(control_vector) + beta = layers.Dense(filters*2, activation="linear")(control_vector) # Reshape gamma and beta to match the spatial dimensions #gamma = tf.reshape(gamma, (-1,) + input_tensor.shape[1:-1] + (filters,))