Skip to content

Commit

Permalink
Make the PIDON model loss optional
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorC committed Dec 26, 2022
1 parent a7ddade commit 0894933
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
19 changes: 13 additions & 6 deletions pararealml/operators/ml/pidon/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Loss(NamedTuple):
diff_eq_loss: tf.Tensor
ic_loss: tf.Tensor
bc_losses: Optional[Tuple[tf.Tensor, tf.Tensor]]
model_loss: tf.Tensor
model_loss: Optional[tf.Tensor]
weighted_total_loss: tf.Tensor

def __str__(self):
Expand All @@ -27,7 +27,8 @@ def __str__(self):
f"; Dirichlet BC: {self.bc_losses[0]}; "
+ f"Neumann BC: {self.bc_losses[1]}"
)
string += f"; Model: {self.model_loss}"
if self.model_loss is not None:
string += f"; Model: {self.model_loss}"
return string

@classmethod
Expand All @@ -37,7 +38,7 @@ def construct(
diff_eq_loss: tf.Tensor,
ic_loss: tf.Tensor,
bc_losses: Optional[Tuple[tf.Tensor, tf.Tensor]],
model_loss: tf.Tensor,
model_loss: Optional[tf.Tensor],
diff_eq_loss_weights: Sequence[float],
ic_loss_weights: Sequence[float],
bc_loss_weights: Sequence[float],
Expand Down Expand Up @@ -66,7 +67,8 @@ def construct(
weighted_total_loss += tf.multiply(
tf.constant(bc_loss_weights), bc_losses[0] + bc_losses[1]
)
weighted_total_loss += model_loss
if model_loss is not None:
weighted_total_loss += model_loss
return Loss(
diff_eq_loss, ic_loss, bc_losses, model_loss, weighted_total_loss
)
Expand Down Expand Up @@ -103,7 +105,8 @@ def mean(
if loss.bc_losses:
dirichlet_bc_losses.append(loss.bc_losses[0])
neumann_bc_losses.append(loss.bc_losses[1])
model_losses.append(loss.model_loss)
if loss.model_loss is not None:
model_losses.append(loss.model_loss)

mean_diff_eq_loss = tf.reduce_mean(tf.stack(diff_eq_losses), axis=0)
mean_ic_loss = tf.reduce_mean(tf.stack(ic_losses), axis=0)
Expand All @@ -115,7 +118,11 @@ def mean(
tf.reduce_mean(tf.stack(neumann_bc_losses), axis=0),
)
)
mean_model_loss = tf.reduce_mean(tf.stack(model_losses), axis=0)
mean_model_loss = (
tf.reduce_mean(tf.stack(model_losses), axis=0)
if model_losses
else None
)

return cls.construct(
mean_diff_eq_loss,
Expand Down
16 changes: 5 additions & 11 deletions pararealml/operators/ml/pidon/pi_deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,7 @@ def value_and_gradients_function(
) -> Tuple[tf.Tensor, tf.Tensor]:
self.set_trainable_parameters(parameters)
with AutoDifferentiator() as auto_diff:
loss = self._compute_physics_informed_loss(
full_training_data_batch, True
)
loss = self._compute_batch_loss(full_training_data_batch, True)
value = tf.reduce_sum(loss.weighted_total_loss, keepdims=True)

gradients = auto_diff.gradient(value, self.trainable_variables)
Expand Down Expand Up @@ -365,7 +363,7 @@ def _compute_total_loss(
:return: the mean physics-informed loss
"""
loss_function = (
partial(self._compute_physics_informed_loss, training=False)
partial(self._compute_batch_loss, training=False)
if optimizer is None
else partial(self._train, optimizer=optimizer)
)
Expand Down Expand Up @@ -399,7 +397,7 @@ def _train(
:return: the various losses over the batch
"""
with AutoDifferentiator() as auto_diff:
loss = self._compute_physics_informed_loss(batch, True)
loss = self._compute_batch_loss(batch, True)

optimizer.minimize(
loss.weighted_total_loss, self.trainable_variables, tape=auto_diff
Expand All @@ -408,9 +406,7 @@ def _train(
return loss

@tf.function
def _compute_physics_informed_loss(
self, batch: DataBatch, training: bool
) -> Loss:
def _compute_batch_loss(self, batch: DataBatch, training: bool) -> Loss:
"""
Computes and returns the total physics-informed loss over the batch
consisting of the mean squared differential equation error, the mean
Expand All @@ -433,9 +429,7 @@ def _compute_physics_informed_loss(
else None
)
model_loss = (
tf.reshape(tf.add_n(self.losses), (1,))
if self.losses
else tf.constant([0.0])
tf.reshape(tf.add_n(self.losses), (1,)) if self.losses else None
)

return Loss.construct(
Expand Down

0 comments on commit 0894933

Please sign in to comment.