diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 974fc984a..6698adf9b 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -108,9 +108,12 @@ def condition_layer_film(input_tensor, control_vector, filters): beta = layers.Dense(filters, activation="linear")(control_vector) # Reshape gamma and beta to match the spatial dimensions - gamma = tf.reshape(gamma, (-1, 1, 1, filters)) - beta = tf.reshape(beta, (-1, 1, 1, filters)) - + if 4 == len(input_tensor.shape): + gamma = tf.reshape(gamma, (-1, 1, 1, filters)) + beta = tf.reshape(beta, (-1, 1, 1, filters)) + elif 3 == len(input_tensor.shape): + gamma = tf.reshape(gamma, (-1, 1, filters)) + beta = tf.reshape(beta, (-1, 1, filters)) # Apply FiLM (Feature-wise Linear Modulation) return input_tensor * gamma + beta