Skip to content

Commit

Permalink
Rename some instance methods of PIDeepONet
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorC committed Jan 1, 2023
1 parent 0894933 commit d95fce6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 20 deletions.
40 changes: 22 additions & 18 deletions pararealml/operators/ml/pidon/pi_deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def boundary_condition_loss_weights(self) -> Sequence[float]:
"""
return self._bc_loss_weights

def fit(
def train(
self,
epochs: int,
optimizer: Union[str, Dict[str, Any], tf.optimizers.Optimizer],
Expand All @@ -165,10 +165,10 @@ def fit(
restore_best_weights: bool = True,
) -> Tuple[List[Loss], Optional[List[Loss]]]:
"""
Fits the branch and trunk net parameters by minimising the
physics-informed loss function over the provided training data set. It
also evaluates the loss over both the training data and the test data,
if provided, for every epoch.
Fits the model by minimising the physics-informed loss function over
the provided training data set with respect to the parameters of the
branch, trunk, and combiner networks. It also evaluates the loss over
both the training data and the test data, if provided, for every epoch.
:param epochs: the number of epochs over the training data
:param optimizer: the optimizer to use to minimize the loss function
Expand Down Expand Up @@ -228,7 +228,7 @@ def fit(

return training_loss_history, test_loss_history

def fit_with_lbfgs(
def train_with_lbfgs(
self,
training_data: DataSetIterator,
max_iterations: int,
Expand All @@ -238,9 +238,10 @@ def fit_with_lbfgs(
gradient_tol: float,
):
"""
Fits the branch and trunk net parameters by minimising the
physics-informed loss function over the provided training data set
using the L-BFGS optimization method.
Fits the model by minimising the physics-informed loss function over
the provided training data set with respect to the parameters of the
branch, trunk, and combiner networks using the L-BFGS optimization
method.
:param training_data: the data set providing the full training batch
:param max_iterations: the maximum number of iterations to perform the
Expand Down Expand Up @@ -360,12 +361,14 @@ def _compute_total_loss(
:param data: the data set to compute the loss over
:param optimizer: an optional optimizer instance; if one is provided,
the model parameters are updated after each batch
:return: the mean physics-informed loss
:return: the mean loss over the data set
"""
loss_function = (
partial(self._compute_batch_loss, training=False)
if optimizer is None
else partial(self._train, optimizer=optimizer)
else partial(
self._compute_and_minimize_batch_loss, optimizer=optimizer
)
)

batch_losses = []
Expand All @@ -384,17 +387,17 @@ def _compute_total_loss(
)

@tf.function
def _train(
def _compute_and_minimize_batch_loss(
self, batch: DataBatch, optimizer: tf.optimizers.Optimizer
) -> Loss:
"""
Performs a forward pass on the batch, computes the batch loss, and
Performs a forward pass over the batch, computes the batch loss, and
updates the model parameters.
:param batch: the batch to compute the losses over
:param optimizer: the optimizer to use to update parameters of the
model
:return: the various losses over the batch
:return: the mean loss over the batch
"""
with AutoDifferentiator() as auto_diff:
loss = self._compute_batch_loss(batch, True)
Expand All @@ -408,15 +411,16 @@ def _train(
@tf.function
def _compute_batch_loss(self, batch: DataBatch, training: bool) -> Loss:
"""
Computes and returns the total physics-informed loss over the batch
Computes and returns the physics-informed loss over the batch
consisting of the mean squared differential equation error, the mean
squared initial condition error, and in the case of PDEs, the mean
squared Dirichlet and Neumann boundary condition errors.
squared initial condition error, in the case of PDEs, the mean
squared Dirichlet and Neumann boundary condition errors, and if
applicable, the regularization error.
:param batch: the batch to compute the losses over
:param training: whether to call the underlying DeepONet in training
mode
:return: the total physics-informed loss over the batch
:return: the mean loss over the batch
"""
domain_batch, initial_batch, boundary_batch = batch
diff_eq_loss = self._compute_differential_equation_loss(
Expand Down
4 changes: 2 additions & 2 deletions pararealml/operators/ml/pidon/pidon_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,14 +245,14 @@ def train(
)
)

training_loss_history, test_loss_history = model.fit(
training_loss_history, test_loss_history = model.train(
training_data=training_data,
test_data=test_data,
**optimization_args._asdict(),
)

if secondary_optimization_args:
model.fit_with_lbfgs(
model.train_with_lbfgs(
training_data=training_data,
**secondary_optimization_args._asdict(),
)
Expand Down

0 comments on commit d95fce6

Please sign in to comment.