diff --git a/framework/applications/utils/transforms.py b/framework/applications/utils/transforms.py index 4212749..d818967 100644 --- a/framework/applications/utils/transforms.py +++ b/framework/applications/utils/transforms.py @@ -90,10 +90,10 @@ def reset_parameters(self): def forward(self, input): torch_version_str = str(torch.__version__).split('.') - if int(torch_version_str[0]) >= 1 and int(torch_version_str[1]) > 7: - return self._conv_forward(input, self.weight_scaling * self.weight, self.bias) - else: + if int(torch_version_str[0]) < 1 or (int(torch_version_str[0]) == 1 and int(torch_version_str[1]) <= 7): return self._conv_forward(input, self.weight_scaling * self.weight) + else: + return self._conv_forward(input, self.weight_scaling * self.weight, self.bias) class ScaledLinear(nn.Linear): def __init__(self, in_features, out_features, *args, **kwargs): @@ -127,39 +127,15 @@ def update_linear(self, m, parent): lsa_update.weight, lsa_update.bias = m[1].weight, m[1].bias setattr(parent, m[0], lsa_update) + def add_lsa_params_recursive(self, module): + for name, child in module.named_children(): + if isinstance(child, nn.Conv2d) and child.weight.requires_grad: + self.update_conv2d((name, child), module) + elif isinstance(child, nn.Linear) and child.weight.requires_grad: + self.update_linear((name, child), module) + elif len(list(child.children())) > 0: + self.add_lsa_params_recursive(child) + def add_lsa_params(self): - ''' - adds LSA scaling parameters to conv and linear layers - - max. nested object depth: 4 - - trainable_true (i.e. does not add LSA params to layers which are not trained, e.g. in classifier only training) - ''' - for m in self.mdl.named_children(): - if isinstance(m[1], nn.Conv2d) and m[1].weight.requires_grad: - self.update_conv2d(m, self.mdl) - elif isinstance(m[1], nn.Linear) and m[1].weight.requires_grad: - self.update_linear(m, self.mdl) - elif len(dict(m[1].named_children())) > 0: - for n in m[1].named_children(): - if isinstance(n[1], nn.Conv2d) and n[1].weight.requires_grad: - self.update_conv2d(n, m[1]) - elif isinstance(n[1], nn.Linear) and n[1].weight.requires_grad: - self.update_linear(n, m[1]) - elif len(dict(n[1].named_children())) > 0: - for o in n[1].named_children(): - if isinstance(o[1], nn.Conv2d) and o[1].weight.requires_grad: - self.update_conv2d(o, n[1]) - elif isinstance(o[1], nn.Linear) and o[1].weight.requires_grad: - self.update_linear(o, n[1]) - elif len(dict(o[1].named_children())) > 0: - for p in o[1].named_children(): - if isinstance(p[1], nn.Conv2d) and p[1].weight.requires_grad: - self.update_conv2d(p, o[1]) - elif isinstance(p[1], nn.Linear) and p[1].weight.requires_grad: - self.update_linear(p, o[1]) - elif len(dict(p[1].named_children())) > 0: - for q in p[1].named_children(): - if isinstance(q[1], nn.Conv2d) and q[1].weight.requires_grad: - self.update_conv2d(q, p[1]) - elif isinstance(q[1], nn.Linear) and q[1].weight.requires_grad: - self.update_linear(q, p[1]) + self.add_lsa_params_recursive(self.mdl) return self.mdl \ No newline at end of file diff --git a/requirements_cu11.txt b/requirements_cu11.txt index 6345123..ead0272 100755 --- a/requirements_cu11.txt +++ b/requirements_cu11.txt @@ -1,4 +1,4 @@ ---extra-index-url https://download.pytorch.org/whl/cu113 +--extra-index-url https://download.pytorch.org/whl/cu118 Click>=7.0 scikit-learn>=0.23.1 tqdm>=4.32.2