Skip to content

Commit

Permalink
write loss derivative measurement.
Browse files Browse the repository at this point in the history
  • Loading branch information
KonstiNik committed May 21, 2024
1 parent b8807e6 commit a802a6d
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions papyrus/measurements/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,3 +665,64 @@ def apply(self, ntk: np.ndarray) -> np.ndarray:
The Neural Tangent Kernel (NTK) matrix.
"""
return ntk


class LossDerivative(BaseMeasurement):
"""
Measurement class to record the derivative of the loss with respect to the neural
network outputs.
Neural State Keys
-----------------
loss_derivative : np.ndarray
The derivative of the loss with respect to the weights.
"""

def __init__(
self,
apply_fn: Callable,
name: str = "loss_derivative",
rank: int = 1,
public: bool = False,
):
"""
Constructor method of the LossDerivative class.
Parameters
----------
apply_fn : Callable
The function to compute the derivative of the loss with respect to the
neural network outputs.
name : str (default="loss_derivative")
The name of the measurement, defining how the instance in the database
will be identified.
rank : int (default=1)
The rank of the measurement, defining the tensor order of the
measurement.
public : bool (default=False)
Boolean flag to indicate whether the measurement resutls will be
accessible via a public attribute of the recorder.
"""
super().__init__(name, rank, public)
self.apply_fn = apply_fn

def apply(self, predictions: np.ndarray, targets: np.ndarray) -> np.ndarray:
"""
Method to record the derivative of the loss with respect to the neural network
outputs.
Parameters need to be provided as keyword arguments.
Parameters
----------
predictions : np.ndarray
The predictions of the neural network.
targets : np.ndarray
The target values of the neural network.
Returns
-------
np.ndarray
The derivative of the loss with respect to the neural network outputs.
"""
return self.apply_fn(predictions, targets)

0 comments on commit a802a6d

Please sign in to comment.