From 590e068608375095dfb6671b08e73710c9e95e15 Mon Sep 17 00:00:00 2001 From: Daniel Becking <56083075+d-becking@users.noreply.github.com> Date: Wed, 9 Aug 2023 10:08:39 +0200 Subject: [PATCH 1/2] LSA generalization through recursively equipping nn.Conv2d and nn.Linear modules with scaling parameters (independent of the nested object's depth); LSA compatibility with torch versions >= 2.x.x --- framework/applications/utils/transforms.py | 50 ++++++---------------- 1 file changed, 13 insertions(+), 37 deletions(-) diff --git a/framework/applications/utils/transforms.py b/framework/applications/utils/transforms.py index 4212749..d818967 100644 --- a/framework/applications/utils/transforms.py +++ b/framework/applications/utils/transforms.py @@ -90,10 +90,10 @@ def reset_parameters(self): def forward(self, input): torch_version_str = str(torch.__version__).split('.') - if int(torch_version_str[0]) >= 1 and int(torch_version_str[1]) > 7: - return self._conv_forward(input, self.weight_scaling * self.weight, self.bias) - else: + if int(torch_version_str[0]) < 1 or (int(torch_version_str[0]) == 1 and int(torch_version_str[1]) <= 7): return self._conv_forward(input, self.weight_scaling * self.weight) + else: + return self._conv_forward(input, self.weight_scaling * self.weight, self.bias) class ScaledLinear(nn.Linear): def __init__(self, in_features, out_features, *args, **kwargs): @@ -127,39 +127,15 @@ def update_linear(self, m, parent): lsa_update.weight, lsa_update.bias = m[1].weight, m[1].bias setattr(parent, m[0], lsa_update) + def add_lsa_params_recursive(self, module): + for name, child in module.named_children(): + if isinstance(child, nn.Conv2d) and child.weight.requires_grad: + self.update_conv2d((name, child), module) + elif isinstance(child, nn.Linear) and child.weight.requires_grad: + self.update_linear((name, child), module) + elif len(list(child.children())) > 0: + self.add_lsa_params_recursive(child) + def add_lsa_params(self): - ''' - adds LSA scaling parameters to conv and linear layers - - max. nested object depth: 4 - - trainable_true (i.e. does not add LSA params to layers which are not trained, e.g. in classifier only training) - ''' - for m in self.mdl.named_children(): - if isinstance(m[1], nn.Conv2d) and m[1].weight.requires_grad: - self.update_conv2d(m, self.mdl) - elif isinstance(m[1], nn.Linear) and m[1].weight.requires_grad: - self.update_linear(m, self.mdl) - elif len(dict(m[1].named_children())) > 0: - for n in m[1].named_children(): - if isinstance(n[1], nn.Conv2d) and n[1].weight.requires_grad: - self.update_conv2d(n, m[1]) - elif isinstance(n[1], nn.Linear) and n[1].weight.requires_grad: - self.update_linear(n, m[1]) - elif len(dict(n[1].named_children())) > 0: - for o in n[1].named_children(): - if isinstance(o[1], nn.Conv2d) and o[1].weight.requires_grad: - self.update_conv2d(o, n[1]) - elif isinstance(o[1], nn.Linear) and o[1].weight.requires_grad: - self.update_linear(o, n[1]) - elif len(dict(o[1].named_children())) > 0: - for p in o[1].named_children(): - if isinstance(p[1], nn.Conv2d) and p[1].weight.requires_grad: - self.update_conv2d(p, o[1]) - elif isinstance(p[1], nn.Linear) and p[1].weight.requires_grad: - self.update_linear(p, o[1]) - elif len(dict(p[1].named_children())) > 0: - for q in p[1].named_children(): - if isinstance(q[1], nn.Conv2d) and q[1].weight.requires_grad: - self.update_conv2d(q, p[1]) - elif isinstance(q[1], nn.Linear) and q[1].weight.requires_grad: - self.update_linear(q, p[1]) + self.add_lsa_params_recursive(self.mdl) return self.mdl \ No newline at end of file From fdec546f48191fff3154606a7d1342a2a8251f4b Mon Sep 17 00:00:00 2001 From: Paul Haase Date: Tue, 14 Nov 2023 13:28:17 +0100 Subject: [PATCH 2/2] Improvements and bugfixes: - Fixed a bug in PyTorch-model loading - Improved loading and handling of TensorFlow models - Added support for Metal Performance Shaders (MPS) - Added requirement file for virtual python environment on mac os --- README.md | 2 +- create_venv_macos.sh | 8 ++++++ framework/pytorch_model/__init__.py | 14 ++++++---- framework/tensorflow_model/__init__.py | 33 +++++++++++----------- framework/use_case_init/__init__.py | 38 ++++++++++++++------------ requirements_cu11.txt | 2 +- requirements_macos.txt | 12 ++++++++ setup.py | 2 +- 8 files changed, 69 insertions(+), 42 deletions(-) create mode 100755 create_venv_macos.sh create mode 100755 requirements_macos.txt diff --git a/README.md b/README.md index 5dcc643..4db6ee6 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ source env/bin/activate **Note**: For further information on how to set up a virtual python environment (also on **Windows**) refer to https://docs.python.org/3/library/venv.html . -When successfully installed, the software outputs the line : "Successfully installed NNC-0.2.2" +When successfully installed, the software outputs the line : "Successfully installed NNC-0.3.0" ### Importing the main module diff --git a/create_venv_macos.sh b/create_venv_macos.sh new file mode 100755 index 0000000..af8ff6b --- /dev/null +++ b/create_venv_macos.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +python3 -m venv env +source env/bin/activate +pip install --upgrade pip +pip install -r requirements_macos.txt +pip install -e . +deactivate \ No newline at end of file diff --git a/framework/pytorch_model/__init__.py b/framework/pytorch_model/__init__.py index 4f45a59..a696db8 100644 --- a/framework/pytorch_model/__init__.py +++ b/framework/pytorch_model/__init__.py @@ -482,7 +482,14 @@ def __init__(self, lr=1e-4, ): - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + self.device = torch.device(device) torch.manual_seed(451) torch.backends.cudnn.deterministic = True @@ -491,7 +498,6 @@ def __init__(self, self.learning_rate = lr self.epochs = epochs self.max_batches = max_batches - self.handle = handler if test_set: self.test_set = test_set @@ -517,7 +523,6 @@ def test_model(self, verbose=False ): - torch.set_num_threads(1) Model = copy.deepcopy(self.model) base_model_arch = Model.state_dict() @@ -556,7 +561,6 @@ def eval_model(self, verbose=False ): - torch.set_num_threads(1) Model = copy.deepcopy(self.model) @@ -595,7 +599,7 @@ def tune_model( ft_flag=False, verbose=False, ): - torch.set_num_threads(1) + verbose = 1 if (verbose & 1) else 0 base_model_arch = self.model.state_dict() diff --git a/framework/tensorflow_model/__init__.py b/framework/tensorflow_model/__init__.py index dceffa3..363e337 100644 --- a/framework/tensorflow_model/__init__.py +++ b/framework/tensorflow_model/__init__.py @@ -131,6 +131,8 @@ def create_NNC_model_instance_from_file( model_struct = loaded_model_struct if dataset_path and model_struct: + if model_name == None and hasattr(model_struct, 'name'): + model_name=model_struct.name TEFModelExecuter = create_imagenet_model_executer(model_struct=model_struct, dataset_path=dataset_path, lr=lr, @@ -161,6 +163,8 @@ def create_NNC_model_instance_from_object( model_struct = loaded_model_struct if dataset_path and model_struct: + if model_name == None and hasattr(model_struct, 'name'): + model_name=model_struct.name TEFModelExecuter = create_imagenet_model_executer(model_struct=model_struct, dataset_path=dataset_path, lr=lr, @@ -230,6 +234,10 @@ def __init__(self, model_dict=None): def load_model( self, model_path ): + + try: + model_file = tf.keras.models.load_model(model_path) + except: model_file = h5py.File(model_path, 'r') try: @@ -262,26 +270,19 @@ def init_model_from_model_object( self, model_object, ): self.model = model_object - - h5_model_path = './temp.h5' - model_object.save_weights(h5_model_path) - model = h5py.File(h5_model_path, 'r') - os.remove(h5_model_path) - if 'layer_names' in model.attrs: - module_names = [n for n in model.attrs['layer_names']] - + weights = model_object.get_weights() layer_names = [] - for mod_name in module_names: - layer = model[mod_name] - if 'weight_names' in layer.attrs: - weight_names = [mod_name+'/'+n for n in layer.attrs['weight_names']] - if weight_names: - layer_names += weight_names + + for layer in model_object.layers: + mod_name = layer.name + if layer.weights != []: + for weight in layer.weights: + layer_names.append(mod_name+"/"+weight.name) model_parameter_dict = {} - for name in layer_names: - model_parameter_dict[name] = model[name] + for i, name in enumerate(layer_names): + model_parameter_dict[name] = weights[i] return self.init_model_from_dict( model_parameter_dict ), model_object diff --git a/framework/use_case_init/__init__.py b/framework/use_case_init/__init__.py index 628e03e..50986f6 100644 --- a/framework/use_case_init/__init__.py +++ b/framework/use_case_init/__init__.py @@ -148,47 +148,49 @@ def preprocess( ): image_size = 224 - if self.__model_name == 'EfficientNetB1': + if self.__model_name == 'EfficientNetB1' or self.__model_name == 'efficientnetb1': image_size = 240 - elif self.__model_name == 'EfficientNetB2': + elif self.__model_name == 'EfficientNetB2' or self.__model_name == 'efficientnetb2': image_size = 260 - elif self.__model_name == 'EfficientNetB3': + elif self.__model_name == 'EfficientNetB3' or self.__model_name == 'efficientnetb3': image_size = 300 - elif self.__model_name == 'EfficientNetB4': + elif self.__model_name == 'EfficientNetB4' or self.__model_name == 'efficientnetb4': image_size = 380 - elif self.__model_name == 'EfficientNetB5': + elif self.__model_name == 'EfficientNetB5' or self.__model_name == 'efficientnetb5': image_size = 456 - elif self.__model_name == 'EfficientNetB6': + elif self.__model_name == 'EfficientNetB6' or self.__model_name == 'efficientnetb6': image_size = 528 - elif self.__model_name == 'EfficientNetB7': + elif self.__model_name == 'EfficientNetB7' or self.__model_name == 'efficientnetb7': image_size = 600 image, label = self.model_transform(image, label, image_size=image_size) - if 'DenseNet' in self.__model_name: + if 'DenseNet' in self.__model_name or 'densenet' in self.__model_name: return tf.keras.applications.densenet.preprocess_input(image), label - elif 'EfficientNet' in self.__model_name: + elif 'EfficientNet' in self.__model_name or 'efficientnet' in self.__model_name: return tf.keras.applications.efficientnet.preprocess_input(image), label - elif self.__model_name == 'InceptionResNetV2': + elif self.__model_name == 'InceptionResNetV2' or self.__model_name == 'inception_resnet_v2': return tf.keras.applications.inception_resnet_v2.preprocess_input(image), label - elif self.__model_name == 'InceptionV3': + elif self.__model_name == 'InceptionV3' or self.__model_name == "inception_v3": return tf.keras.applications.inception_v3.preprocess_input(image), label - elif self.__model_name == 'MobileNet': + elif self.__model_name == 'MobileNet' or ( 'mobilenet' in self.__model_name and 'v2' not in self.__model_name ): return tf.keras.applications.mobilenet.preprocess_input(image), label - elif self.__model_name == 'MobileNetV2': + elif self.__model_name == 'MobileNetV2' or 'mobilenetv2' in self.__model_name: return tf.keras.applications.mobilenet_v2.preprocess_input(image), label elif 'NASNet' in self.__model_name: return tf.keras.applications.nasnet.preprocess_input(image), label - elif 'ResNet' in self.__model_name and 'V2' not in self.__model_name: + elif ('ResNet' in self.__model_name and 'V2' not in self.__model_name) or ('resnet' in self.__model_name and 'v2' not in self.__model_name): return tf.keras.applications.resnet.preprocess_input(image), label - elif 'ResNet' in self.__model_name and 'V2' in self.__model_name: + elif ('ResNet' in self.__model_name and 'V2' in self.__model_name) or ('resnet' in self.__model_name and 'v2' in self.__model_name): return tf.keras.applications.resnet_v2.preprocess_input(image), label - elif self.__model_name == 'VGG16': + elif self.__model_name == 'VGG16' or self.__model_name == 'vgg16': return tf.keras.applications.vgg16.preprocess_input(image), label - elif self.__model_name == 'VGG19': + elif self.__model_name == 'VGG19' or self.__model_name == 'vgg19': return tf.keras.applications.vgg19.preprocess_input(image), label - elif self.__model_name == 'Xception': + elif self.__model_name == 'Xception' or self.__model_name == 'xception': return tf.keras.applications.xception.preprocess_input(image), label + elif 'RegNet' in self.__model_name or 'regnet' in self.__model_name: + return tf.keras.applications.regnet.preprocess_input(image), label # supported use cases diff --git a/requirements_cu11.txt b/requirements_cu11.txt index 6345123..55742b0 100755 --- a/requirements_cu11.txt +++ b/requirements_cu11.txt @@ -4,7 +4,7 @@ scikit-learn>=0.23.1 tqdm>=4.32.2 h5py>=3.1.0 pybind11>=2.6.2 -tensorflow>=2.6.0 +tensorflow[and-cuda]>=2.6.0 pandas>=1.0.5 opencv-python>=4.4.0.46 torch>=1.8.1 diff --git a/requirements_macos.txt b/requirements_macos.txt new file mode 100755 index 0000000..3886c2a --- /dev/null +++ b/requirements_macos.txt @@ -0,0 +1,12 @@ +urllib3<2.0 +Click>=7.0 +scikit-learn>=0.23.1 +tqdm>=4.32.2 +h5py>=3.1.0 +pybind11>=2.6.2 +tensorflow>=2.13.0 +tensorflow-metal>=1.0.0 +pandas>=1.0.5 +opencv-python>=4.4.0.46 +torch>=1.12.0 +torchvision>=0.13.1 diff --git a/setup.py b/setup.py index 10fda2a..c9683c9 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ from setuptools.command.build_ext import build_ext import setuptools -__version__ = '0.2.2' +__version__ = '0.3.0' class get_pybind_include(object):