From f2f5305008ffb15b3e981fedc7eae5f05208b4b3 Mon Sep 17 00:00:00 2001
From: Daniel Becking <56083075+d-becking@users.noreply.github.com>
Date: Wed, 9 Aug 2023 10:08:39 +0200
Subject: [PATCH] LSA generalization through recursively equipping nn.Conv2d and nn.Linear modules with scaling parameters (independent of the nested object's depth); LSA compatibility with torch versions >= 2.x.x

---
 framework/applications/utils/transforms.py | 50 ++++++----------------
 requirements_cu11.txt                      |  2 +-
 2 files changed, 14 insertions(+), 38 deletions(-)

diff --git a/framework/applications/utils/transforms.py b/framework/applications/utils/transforms.py
index 4212749..d818967 100644
--- a/framework/applications/utils/transforms.py
+++ b/framework/applications/utils/transforms.py
@@ -90,10 +90,10 @@ def reset_parameters(self):
 
     def forward(self, input):
         torch_version_str = str(torch.__version__).split('.')
-        if int(torch_version_str[0]) >= 1 and int(torch_version_str[1]) > 7:
-            return self._conv_forward(input, self.weight_scaling * self.weight, self.bias)
-        else:
+        if int(torch_version_str[0]) < 1 or (int(torch_version_str[0]) == 1 and int(torch_version_str[1]) <= 7):
             return self._conv_forward(input, self.weight_scaling * self.weight)
+        else:
+            return self._conv_forward(input, self.weight_scaling * self.weight, self.bias)
 
 class ScaledLinear(nn.Linear):
     def __init__(self, in_features, out_features, *args, **kwargs):
@@ -127,39 +127,15 @@ def update_linear(self, m, parent):
         lsa_update.weight, lsa_update.bias = m[1].weight, m[1].bias
         setattr(parent, m[0], lsa_update)
 
+    def add_lsa_params_recursive(self, module):
+        for name, child in module.named_children():
+            if isinstance(child, nn.Conv2d) and child.weight.requires_grad:
+                self.update_conv2d((name, child), module)
+            elif isinstance(child, nn.Linear) and child.weight.requires_grad:
+                self.update_linear((name, child), module)
+            elif len(list(child.children())) > 0:
+                self.add_lsa_params_recursive(child)
+
     def add_lsa_params(self):
-        '''
-        adds LSA scaling parameters to conv and linear layers
-        - max. nested object depth: 4
-        - trainable_true (i.e. does not add LSA params to layers which are not trained, e.g. in classifier only training)
-        '''
-        for m in self.mdl.named_children():
-            if isinstance(m[1], nn.Conv2d) and m[1].weight.requires_grad:
-                self.update_conv2d(m, self.mdl)
-            elif isinstance(m[1], nn.Linear) and m[1].weight.requires_grad:
-                self.update_linear(m, self.mdl)
-            elif len(dict(m[1].named_children())) > 0:
-                for n in m[1].named_children():
-                    if isinstance(n[1], nn.Conv2d) and n[1].weight.requires_grad:
-                        self.update_conv2d(n, m[1])
-                    elif isinstance(n[1], nn.Linear) and n[1].weight.requires_grad:
-                        self.update_linear(n, m[1])
-                    elif len(dict(n[1].named_children())) > 0:
-                        for o in n[1].named_children():
-                            if isinstance(o[1], nn.Conv2d) and o[1].weight.requires_grad:
-                                self.update_conv2d(o, n[1])
-                            elif isinstance(o[1], nn.Linear) and o[1].weight.requires_grad:
-                                self.update_linear(o, n[1])
-                            elif len(dict(o[1].named_children())) > 0:
-                                for p in o[1].named_children():
-                                    if isinstance(p[1], nn.Conv2d) and p[1].weight.requires_grad:
-                                        self.update_conv2d(p, o[1])
-                                    elif isinstance(p[1], nn.Linear) and p[1].weight.requires_grad:
-                                        self.update_linear(p, o[1])
-                                    elif len(dict(p[1].named_children())) > 0:
-                                        for q in p[1].named_children():
-                                            if isinstance(q[1], nn.Conv2d) and q[1].weight.requires_grad:
-                                                self.update_conv2d(q, p[1])
-                                            elif isinstance(q[1], nn.Linear) and q[1].weight.requires_grad:
-                                                self.update_linear(q, p[1])
+        self.add_lsa_params_recursive(self.mdl)
         return self.mdl
\ No newline at end of file
diff --git a/requirements_cu11.txt b/requirements_cu11.txt
index 6345123..ead0272 100755
--- a/requirements_cu11.txt
+++ b/requirements_cu11.txt
@@ -1,4 +1,4 @@
---extra-index-url https://download.pytorch.org/whl/cu113
+--extra-index-url https://download.pytorch.org/whl/cu118
 Click>=7.0
 scikit-learn>=0.23.1
 tqdm>=4.32.2