diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 4da267824..fbb807a80 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -376,7 +376,7 @@ def train_diffusion_control_model(args, supervised=False): data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1) preds = model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/reconstructions/') image_out = {args.tensor_maps_in[0].output_name(): data[args.tensor_maps_in[0].input_name()]} - predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/') + predictions_to_pngs(preds, args.tensor_maps_in, args.tensor_maps_in, data, image_out, paths, f'{args.output_folder}/{args.id}/reconstructions/') interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size, f'{args.output_folder}/{args.id}/') if model.input_map.axes() == 2: