diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index 5ee920b56..366c9746e 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -877,7 +877,7 @@ def test_step(self, batch): # this is computationally demanding, kid_diffusion_steps has to be small if self.input_map.axes() == 3 and self.inspect_model: images = self.denormalize(images) - generated_images = self.generate( + generated_images = self.generate(control_embed, num_images=self.batch_size, diffusion_steps=20 ) self.kid.update_state(images, generated_images)