Skip to content

Commit

Permalink
condition and supervise
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Jan 17, 2025
1 parent eef9daa commit e3a179c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
20 changes: 13 additions & 7 deletions ml4h/models/diffusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ class DiffusionController(keras.Model):
def __init__(
self, tensor_map, output_maps, batch_size, widths, block_depth, conv_x, control_size,
attention_start, attention_heads, attention_modulo, diffusion_loss, sigmoid_beta, condition_strategy,
supervisor = None, supervision_scalar = 0.01,
inspect_model, supervisor = None, supervision_scalar = 0.01,
):
super().__init__()

Expand All @@ -653,6 +653,7 @@ def __init__(
self.beta = sigmoid_beta
self.supervisor = supervisor
self.supervision_scalar = supervision_scalar
self.inspect_model = inspect_model


def compile(self, **kwargs):
Expand All @@ -663,13 +664,16 @@ def compile(self, **kwargs):
self.mae_metric = tf.keras.metrics.MeanAbsoluteError(name="mae")
if self.supervisor is not None:
self.supervised_loss_tracker = keras.metrics.Mean(name="supervised_loss")
# self.kid = KID(name = "kid", input_shape = self.tensor_map.shape)
if self.input_map.axes() == 3 and self.inspect_model:
self.kid = KernelInceptionDistance(name = "kid", input_shape = self.input_map.shape, kernel_image_size=299)

@property
def metrics(self):
m = [self.noise_loss_tracker, self.image_loss_tracker, self.mse_metric, self.mae_metric]
if self.supervisor is not None:
m.append(self.supervised_loss_tracker)
if self.input_map.axes() == 3 and self.inspect_model:
m.append(self.kid)
return m

def denormalize(self, images):
Expand Down Expand Up @@ -871,14 +875,16 @@ def test_step(self, batch):

# measure KID between real and generated images
# this is computationally demanding, kid_diffusion_steps has to be small
images = self.denormalize(images)
generated_images = self.generate(
control_embed, num_images=self.batch_size, diffusion_steps=20,
)
# self.kid.update_state(images, generated_images)
if self.tensor_map.axes() == 3 and self.inspect_model:
images = self.denormalize(images)
generated_images = self.generate(
num_images=self.batch_size, diffusion_steps=20
)
self.kid.update_state(images, generated_images)

return {m.name: m.result() for m in self.metrics}


def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None, prefix='./figures/'):
control_batch = {}
for cm in self.output_maps:
Expand Down
14 changes: 9 additions & 5 deletions ml4h/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,13 @@ def train_diffusion_control_model(args, supervised=False):
model = DiffusionController(
args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x,
args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss,
args.sigmoid_beta, args.diffusion_condition_strategy, supervised_model, args.supervision_scalar,
args.sigmoid_beta, args.diffusion_condition_strategy, args.inspect_model, supervised_model, args.supervision_scalar,
)
else:
model = DiffusionController(
args.tensor_maps_in[0], args.tensor_maps_out, args.batch_size, args.dense_blocks, args.block_size, args.conv_x,
args.dense_layers[0], args.attention_window, args.attention_heads, args.attention_modulo, args.diffusion_loss,
args.sigmoid_beta, args.diffusion_condition_strategy,
args.inspect_model, args.sigmoid_beta, args.diffusion_condition_strategy,
)

loss = keras.losses.mean_absolute_error if args.diffusion_loss == 'mean_absolute_error' else keras.losses.mean_squared_error
Expand Down Expand Up @@ -385,16 +385,20 @@ def train_diffusion_control_model(args, supervised=False):
metrics = model.evaluate(generate_test, batch_size=args.batch_size, steps=args.test_steps, return_dict=True)
logging.info(f'Test metrics: {metrics}')

data, labels, paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/reconstructions/')
steps = 1 if args.batch_size > 3 else args.test_steps
data, labels, paths = big_batch_from_minibatch_generator(generate_test, steps)
sides = int(np.sqrt(steps*args.batch_size))
preds = model.plot_reconstructions((data, labels), num_rows=sides, num_cols=sides,
prefix=f'{args.output_folder}/{args.id}/reconstructions/')

image_out = {args.tensor_maps_in[0].output_name(): data[args.tensor_maps_in[0].input_name()]}
predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/reconstructions/')
interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size,
f'{args.output_folder}/{args.id}/')
if model.input_map.axes() == 2:
model.plot_ecgs(num_rows=2, prefix=os.path.dirname(checkpoint_path))
else:
model.plot_images(num_rows=2, prefix=os.path.dirname(checkpoint_path))
model.plot_images(num_cols=sides, num_rows=sides, prefix=os.path.dirname(checkpoint_path))

for tm_out, model_file in zip(args.tensor_maps_out, args.model_files):
args.tensor_maps_out = [tm_out]
Expand Down

0 comments on commit e3a179c

Please sign in to comment.