Skip to content

Commit

Permalink
🔧 Update
Browse files Browse the repository at this point in the history
  • Loading branch information
TezRomacH committed Sep 15, 2020
1 parent 582bb33 commit 0dc6691
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions layer_to_layer_pytorch/l2l.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def zero_grad(self) -> None:

self._reset_activations()

def _zero_layer_grad(self, layer: nn.Module) -> None:
for param in layer.parameters():
param.grad = None

@torch.no_grad()
def forward(self, batch: torch.Tensor, **kwargs) -> torch.Tensor:
layers: nn.ModuleList = self._get_layers()
Expand Down Expand Up @@ -119,9 +123,8 @@ def backward(
total=self.num_layers,
leave=False,
):
self._zero_layer_grad(l)
layer: nn.Module = copy.deepcopy(l).to(self.gpu_device)
for param in layer.parameters():
param.grad = None
f_idx: int = self.num_layers - idx - 1

# TODO: preserve re-calculations
Expand Down

0 comments on commit 0dc6691

Please sign in to comment.