diff --git a/layer_to_layer_pytorch/l2l.py b/layer_to_layer_pytorch/l2l.py index 36f7d76..495420b 100644 --- a/layer_to_layer_pytorch/l2l.py +++ b/layer_to_layer_pytorch/l2l.py @@ -57,6 +57,10 @@ def zero_grad(self) -> None: self._reset_activations() + def _zero_layer_grad(self, layer: nn.Module) -> None: + for param in layer.parameters(): + param.grad = None + @torch.no_grad() def forward(self, batch: torch.Tensor, **kwargs) -> torch.Tensor: layers: nn.ModuleList = self._get_layers() @@ -119,9 +123,8 @@ def backward( total=self.num_layers, leave=False, ): + self._zero_layer_grad(l) layer: nn.Module = copy.deepcopy(l).to(self.gpu_device) - for param in layer.parameters(): - param.grad = None f_idx: int = self.num_layers - idx - 1 # TODO: preserve re-calculations