Skip to content

Commit

Permalink
LSA generalization through recursively equipping nn.Conv2d and nn.Lin…
Browse files Browse the repository at this point in the history
…ear modules with scaling parameters (independent of the nested object's depth); LSA compatibility with torch versions >= 2.x.x
  • Loading branch information
d-becking authored and phaase-hhi committed Jan 15, 2024
1 parent d870292 commit f2f5305
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 38 deletions.
50 changes: 13 additions & 37 deletions framework/applications/utils/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ def reset_parameters(self):

def forward(self, input):
torch_version_str = str(torch.__version__).split('.')
if int(torch_version_str[0]) >= 1 and int(torch_version_str[1]) > 7:
return self._conv_forward(input, self.weight_scaling * self.weight, self.bias)
else:
if int(torch_version_str[0]) < 1 or (int(torch_version_str[0]) == 1 and int(torch_version_str[1]) <= 7):
return self._conv_forward(input, self.weight_scaling * self.weight)
else:
return self._conv_forward(input, self.weight_scaling * self.weight, self.bias)

class ScaledLinear(nn.Linear):
def __init__(self, in_features, out_features, *args, **kwargs):
Expand Down Expand Up @@ -127,39 +127,15 @@ def update_linear(self, m, parent):
lsa_update.weight, lsa_update.bias = m[1].weight, m[1].bias
setattr(parent, m[0], lsa_update)

def add_lsa_params_recursive(self, module):
for name, child in module.named_children():
if isinstance(child, nn.Conv2d) and child.weight.requires_grad:
self.update_conv2d((name, child), module)
elif isinstance(child, nn.Linear) and child.weight.requires_grad:
self.update_linear((name, child), module)
elif len(list(child.children())) > 0:
self.add_lsa_params_recursive(child)

def add_lsa_params(self):
'''
adds LSA scaling parameters to conv and linear layers
- max. nested object depth: 4
- trainable_true (i.e. does not add LSA params to layers which are not trained, e.g. in classifier only training)
'''
for m in self.mdl.named_children():
if isinstance(m[1], nn.Conv2d) and m[1].weight.requires_grad:
self.update_conv2d(m, self.mdl)
elif isinstance(m[1], nn.Linear) and m[1].weight.requires_grad:
self.update_linear(m, self.mdl)
elif len(dict(m[1].named_children())) > 0:
for n in m[1].named_children():
if isinstance(n[1], nn.Conv2d) and n[1].weight.requires_grad:
self.update_conv2d(n, m[1])
elif isinstance(n[1], nn.Linear) and n[1].weight.requires_grad:
self.update_linear(n, m[1])
elif len(dict(n[1].named_children())) > 0:
for o in n[1].named_children():
if isinstance(o[1], nn.Conv2d) and o[1].weight.requires_grad:
self.update_conv2d(o, n[1])
elif isinstance(o[1], nn.Linear) and o[1].weight.requires_grad:
self.update_linear(o, n[1])
elif len(dict(o[1].named_children())) > 0:
for p in o[1].named_children():
if isinstance(p[1], nn.Conv2d) and p[1].weight.requires_grad:
self.update_conv2d(p, o[1])
elif isinstance(p[1], nn.Linear) and p[1].weight.requires_grad:
self.update_linear(p, o[1])
elif len(dict(p[1].named_children())) > 0:
for q in p[1].named_children():
if isinstance(q[1], nn.Conv2d) and q[1].weight.requires_grad:
self.update_conv2d(q, p[1])
elif isinstance(q[1], nn.Linear) and q[1].weight.requires_grad:
self.update_linear(q, p[1])
self.add_lsa_params_recursive(self.mdl)
return self.mdl
2 changes: 1 addition & 1 deletion requirements_cu11.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
--extra-index-url https://download.pytorch.org/whl/cu113
--extra-index-url https://download.pytorch.org/whl/cu118
Click>=7.0
scikit-learn>=0.23.1
tqdm>=4.32.2
Expand Down

0 comments on commit f2f5305

Please sign in to comment.