Skip to content

Commit

Permalink
condition and supervise
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Jan 8, 2025
1 parent 7827909 commit 3229037
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ml4h/models/diffusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ def train_step(self, batch):
ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

# KID is not measured during the training phase for computational efficiency
return {m.name: m.result() for m in self.metrics[:-1]}
return {m.name: m.result() for m in self.metrics}

# def call(self, inputs):
# # normalize images to have standard deviation of 1, like the noises
Expand Down
4 changes: 2 additions & 2 deletions ml4h/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def plot_metric_history(history, training_steps: int, title: str, prefix="./figu
)

for k in sorted(history.history.keys()):
if not k.startswith("val_"):
if not k.startswith("val_") or k == 'val_supervised_loss':
if isinstance(history.history[k][0], LearningRateSchedule):
history.history[k] = [
history.history[k][0](i * training_steps)
Expand Down Expand Up @@ -470,7 +470,7 @@ def plot_metric_history(history, training_steps: int, title: str, prefix="./figu
if not os.path.exists(os.path.dirname(figure_path)):
os.makedirs(os.path.dirname(figure_path))
plt.savefig(figure_path)
for log_label in ['loss', 'val_loss', 'n_loss', 'val_n_loss']:
for log_label in ['loss', 'val_loss', 'n_loss', 'val_n_loss', 'val_supervised_loss']:
if log_label not in history.history:
continue
logging.info(
Expand Down

0 comments on commit 3229037

Please sign in to comment.