Skip to content

Commit

Permalink
Fixed incorrect decoding of 1-dimensional tensors of size 1
Browse files Browse the repository at this point in the history
  • Loading branch information
phaase-hhi committed Jul 7, 2023
1 parent 44fca93 commit a42cb8f
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 7 deletions.
1 change: 1 addition & 0 deletions framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@
assert sys.version_info >= (3, 6)

from . import pytorch_model
from . import tensorflow_model

2 changes: 0 additions & 2 deletions framework/pytorch_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,6 @@ def save_to_pytorch_file( model_data, path ):
model_dict = OrderedDict()
for module_name in model_data:
model_dict[module_name] = torch.tensor(model_data[module_name])
if model_data[module_name].size == 1:
model_dict[module_name] = torch.tensor(np.int64(model_data[module_name][0]))
torch.save(model_dict, path)


Expand Down
6 changes: 1 addition & 5 deletions framework/tensorflow_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import h5py
import os
import tensorflow as tf
from tensorflow import keras

import copy, logging
import numpy as np
Expand All @@ -67,10 +66,7 @@ def save_to_tensorflow_file( model_data, path ):
grp_name = (splits[0] + '/' + splits[1])
if grp_name not in grp_names:
grp_names.append(grp_name)
if model_data[module_name].size != 1:
h5_model.create_dataset(module_name, data=model_data[module_name])
else: #scalar
h5_model.create_dataset(module_name, data=np.int64(model_data[module_name][0]))
h5_model.create_dataset(module_name, data=model_data[module_name])
h5_model.attrs['layer_names'] = grp_names

for grp in h5_model:
Expand Down

0 comments on commit a42cb8f

Please sign in to comment.