diff --git a/pararealml/operators/ml/pidon/pi_deeponet.py b/pararealml/operators/ml/pidon/pi_deeponet.py index ad5b3f8..27b7b16 100644 --- a/pararealml/operators/ml/pidon/pi_deeponet.py +++ b/pararealml/operators/ml/pidon/pi_deeponet.py @@ -156,7 +156,7 @@ def boundary_condition_loss_weights(self) -> Sequence[float]: """ return self._bc_loss_weights - def fit( + def train( self, epochs: int, optimizer: Union[str, Dict[str, Any], tf.optimizers.Optimizer], @@ -165,10 +165,10 @@ def fit( restore_best_weights: bool = True, ) -> Tuple[List[Loss], Optional[List[Loss]]]: """ - Fits the branch and trunk net parameters by minimising the - physics-informed loss function over the provided training data set. It - also evaluates the loss over both the training data and the test data, - if provided, for every epoch. + Fits the model by minimising the physics-informed loss function over + the provided training data set with respect to the parameters of the + branch, trunk, and combiner networks. It also evaluates the loss over + both the training data and the test data, if provided, for every epoch. :param epochs: the number of epochs over the training data :param optimizer: the optimizer to use to minimize the loss function @@ -228,7 +228,7 @@ def fit( return training_loss_history, test_loss_history - def fit_with_lbfgs( + def train_with_lbfgs( self, training_data: DataSetIterator, max_iterations: int, @@ -238,9 +238,10 @@ def fit_with_lbfgs( gradient_tol: float, ): """ - Fits the branch and trunk net parameters by minimising the - physics-informed loss function over the provided training data set - using the L-BFGS optimization method. + Fits the model by minimising the physics-informed loss function over + the provided training data set with respect to the parameters of the + branch, trunk, and combiner networks using the L-BFGS optimization + method. :param training_data: the data set providing the full training batch :param max_iterations: the maximum number of iterations to perform the @@ -360,12 +361,14 @@ def _compute_total_loss( :param data: the data set to compute the loss over :param optimizer: an optional optimizer instance; if one is provided, the model parameters are updated after each batch - :return: the mean physics-informed loss + :return: the mean loss over the data set """ loss_function = ( partial(self._compute_batch_loss, training=False) if optimizer is None - else partial(self._train, optimizer=optimizer) + else partial( + self._compute_and_minimize_batch_loss, optimizer=optimizer + ) ) batch_losses = [] @@ -384,17 +387,17 @@ def _compute_total_loss( ) @tf.function - def _train( + def _compute_and_minimize_batch_loss( self, batch: DataBatch, optimizer: tf.optimizers.Optimizer ) -> Loss: """ - Performs a forward pass on the batch, computes the batch loss, and + Performs a forward pass over the batch, computes the batch loss, and updates the model parameters. :param batch: the batch to compute the losses over :param optimizer: the optimizer to use to update parameters of the model - :return: the various losses over the batch + :return: the mean loss over the batch """ with AutoDifferentiator() as auto_diff: loss = self._compute_batch_loss(batch, True) @@ -408,15 +411,16 @@ def _train( @tf.function def _compute_batch_loss(self, batch: DataBatch, training: bool) -> Loss: """ - Computes and returns the total physics-informed loss over the batch + Computes and returns the physics-informed loss over the batch consisting of the mean squared differential equation error, the mean - squared initial condition error, and in the case of PDEs, the mean - squared Dirichlet and Neumann boundary condition errors. + squared initial condition error, in the case of PDEs, the mean + squared Dirichlet and Neumann boundary condition errors, and if + applicable, the regularization error. :param batch: the batch to compute the losses over :param training: whether to call the underlying DeepONet in training mode - :return: the total physics-informed loss over the batch + :return: the mean loss over the batch """ domain_batch, initial_batch, boundary_batch = batch diff_eq_loss = self._compute_differential_equation_loss( diff --git a/pararealml/operators/ml/pidon/pidon_operator.py b/pararealml/operators/ml/pidon/pidon_operator.py index 8eab76a..21fad5d 100644 --- a/pararealml/operators/ml/pidon/pidon_operator.py +++ b/pararealml/operators/ml/pidon/pidon_operator.py @@ -245,14 +245,14 @@ def train( ) ) - training_loss_history, test_loss_history = model.fit( + training_loss_history, test_loss_history = model.train( training_data=training_data, test_data=test_data, **optimization_args._asdict(), ) if secondary_optimization_args: - model.fit_with_lbfgs( + model.train_with_lbfgs( training_data=training_data, **secondary_optimization_args._asdict(), )