Skip to content

Commit

Permalink
Merge pull request #34 from microsoft/dev/setup-refactor
Browse files Browse the repository at this point in the history
Prepare for v1.1 release
  • Loading branch information
Lynazhang authored Oct 23, 2021
2 parents a478371 + 23bd901 commit bb4cb6c
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/integration-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ name: Integrated Test
on: [push]

jobs:
build-linux:
model-test:
runs-on: ubuntu-latest
strategy:
max-parallel: 5
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/integration-test_nni_based_torch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ name: Integrated Test for Torch Model Based on NNI
on: [push]

jobs:
build-linux:
torch-model-test:
runs-on: ubuntu-latest
strategy:
max-parallel: 5
Expand Down Expand Up @@ -42,7 +42,7 @@ jobs:
- name: Install nni
run: pip install nni==2.4

- name: Install nn-Meter
run: pip install -U .

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/integration-test_onnx_based_torch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ name: Integrated Test for Torch Model Based on ONNX
on: [push]

jobs:
build-linux:
torch-model-test:
runs-on: ubuntu-latest
strategy:
max-parallel: 5
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ nn-Meter is a latency predictor of models with type of Tensorflow, PyTorch, Onnx
| Testing Model Type | Requirements |
| :----------------: | :-----------------------------------------------------------------------------------------------------------------------: |
| Tensorflow | `tensorflow==1.15.0` |
| Torch | `torch==1.7.1`, `torchvision==0.8.2`, (alternative)[`onnx==1.9.0`, `onnx-simplifier==0.3.6`] or [`nni==2.4`][1] |
| Torch | `torch==1.9.0`, `torchvision==0.10.0`, (alternative)[`onnx==1.9.0`, `onnx-simplifier==0.3.6`] or [`nni>=2.4`][1] |
| Onnx | `onnx==1.9.0` |
| nn-Meter IR graph | --- |
| NNI IR graph | `nni==2.4` |
| NNI IR graph | `nni>=2.4` |

[1] Please refer to [nn-Meter Usage](#torch-model-converters) for more information.

Expand All @@ -73,7 +73,7 @@ Here is a summary of supported inputs of the two methods.
| Tensorflow | Checkpoint file dumped by `tf.saved_model()` and end with `.pb` | Checkpoint file dumped by `tf.saved_model` and end with `.pb` |
| Torch | Models in `torchvision.models` | Object of `torch.nn.Module` |
| Onnx | Checkpoint file dumped by `torch.onnx.export()` or `onnx.save()` and end with `.onnx` | Checkpoint file dumped by `onnx.save()` or model loaded by `onnx.load()` |
| nn-Meter IR graph | Json file in the format of[nn-Meter IR Graph](./docs/input_models.md#nnmeter-ir-graph) | `dict` object following the format of [nn-Meter IR Graph](./docs/input_models.md#nnmeter-ir-graph) |
| nn-Meter IR graph | Json file in the format of [nn-Meter IR Graph](./docs/input_models.md#nnmeter-ir-graph) | `dict` object following the format of [nn-Meter IR Graph](./docs/input_models.md#nnmeter-ir-graph) |
| NNI IR graph | - | NNI IR graph object |

In both methods, users could appoint predictor name and version to target a specific hardware platform (device). Currently, nn-Meter supports prediction on the following four configs:
Expand Down
4 changes: 2 additions & 2 deletions docs/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ nn-Meter is a latency predictor of models with type of Tensorflow, PyTorch, Onnx
| Testing Model Type | Requirements |
| :-------------------: | :------------------------------------------------: |
| Tensorflow | `tensorflow==1.15.0` |
| Torch | `torch==1.7.1`, `torchvision==0.8.2`, (alternative)[`onnx==1.9.0`, `onnx-simplifier==0.3.6`] or [`nni==2.4`][1] |
| Torch | `torch==1.9.0`, `torchvision==0.10.0`, (alternative)[`onnx==1.9.0`, `onnx-simplifier==0.3.6`] or [`nni>=2.4`][1] |
| Onnx | `onnx==1.9.0` |
| nn-Meter IR graph | --- |
| NNI IR graph | `nni==2.4` |
| NNI IR graph | `nni>=2.4` |

[1] Please refer to [nn-Meter Usage](usage.md#torch-model-converters) for more information.

Expand Down
4 changes: 3 additions & 1 deletion nn_meter/ir_converters/torch_converter/converter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nn_meter.utils.utils import try_import_onnx, try_import_torch, try_import_onnxsim
from nn_meter.utils.utils import try_import_onnx, try_import_torch, try_import_onnxsim, try_import_nni
import tempfile
from nn_meter.ir_converters.onnx_converter import OnnxConverter

Expand All @@ -19,6 +19,7 @@ def _nchw_to_nhwc(shapes):

class NNIIRConverter:
def __init__(self, ir_model):
try_import_nni()
try:
from nni.retiarii.converter.utils import flatten_model_graph
self.ir_model = flatten_model_graph(ir_model)
Expand Down Expand Up @@ -100,6 +101,7 @@ def _remove_unshaped_nodes(self, graph):
class NNIBasedTorchConverter(NNIIRConverter):
def __init__(self, model, example_inputs):
torch = try_import_torch()
try_import_nni()
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.converter.graph_gen import GraphConverterWithShape

Expand Down
42 changes: 30 additions & 12 deletions nn_meter/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,34 +36,52 @@ def download_from_url(urladdr, ppath):
progress_bar.close()
os.remove(file_name)

def try_import_onnx(require_version = "1.9.0"):
def try_import_onnx(require_version = ["1.9.0"]):
if isinstance(require_version, str):
require_version = [require_version]
try:
import onnx
if version.parse(onnx.__version__) != version.parse(require_version):
logging.warning(f'onnx=={onnx.__version__} is not well tested now, well tested version: onnx=={require_version}' )
if version.parse(onnx.__version__).release not in [version.parse(v).release for v in require_version]:
logging.warning(f'onnx=={onnx.__version__} is not well tested now, well tested version: onnx=={", ".join(require_version)}' )
return onnx
except ImportError:
logging.error(f'You have not install the onnx package, please install onnx=={require_version} and try again.')
logging.error(f'You have not install the onnx package, please install onnx=={require_version[0]} and try again.')
exit()

def try_import_torch(require_version = "1.7.1"):
def try_import_torch(require_version = ["1.9.0", "1.7.1"]):
if isinstance(require_version, str):
require_version = [require_version]
try:
import torch
if version.parse(torch.__version__) != version.parse(require_version):
logging.warning(f'torch=={torch.__version__} is not well tested now, well tested version: torch=={require_version}' )
if version.parse(torch.__version__).release not in [version.parse(v).release for v in require_version]:
logging.warning(f'torch=={torch.__version__} is not well tested now, well tested version: torch=={", ".join(require_version)}' )
return torch
except ImportError:
logging.error(f'You have not install the torch package, please install torch=={require_version} and try again.')
logging.error(f'You have not install the torch package, please install torch=={require_version[0]} and try again.')
exit()

def try_import_tensorflow(require_version = "1.15.0"):
def try_import_tensorflow(require_version = ["1.15.0"]):
if isinstance(require_version, str):
require_version = [require_version]
try:
import tensorflow
if version.parse(tensorflow.__version__) != version.parse(require_version):
logging.warning(f'tensorflow=={tensorflow.__version__} is not well tested now, well tested version: tensorflow=={require_version}' )
if version.parse(tensorflow.__version__).release not in [version.parse(v).release for v in require_version]:
logging.warning(f'tensorflow=={tensorflow.__version__} is not well tested now, well tested version: tensorflow=={", ".join(require_version)}' )
return tensorflow
except ImportError:
logging.error(f'You have not install the tensorflow package, please install tensorflow=={require_version} and try again.')
logging.error(f'You have not install the tensorflow package, please install tensorflow=={require_version[0]} and try again.')
exit()

def try_import_nni(require_version = ["2.4", "2.5"]):
if isinstance(require_version, str):
require_version = [require_version]
try:
import nni
if version.parse(nni.__version__).release not in [version.parse(v).release for v in require_version]:
logging.warning(f'nni=={nni.__version__} is not well tested now, well tested version: nni=={", ".join(require_version)}' )
return nni
except ImportError:
logging.error(f'You have not install the tensorflow package, please install tensorflow=={require_version[0]} and try again.')
exit()

def try_import_torchvision_models():
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def integration_test_onnx_based_torch(model_type, model_list, output_name = "tes
result = subprocess.check_output(['nn-meter', 'lat_pred', f'--torchvision'] + model_list + ['--predictor', f'{pred_name}', '--predictor-version', f'{pred_version}'])
runtime = time.time() - since
except NotImplementedError:
logging.error("Meets ERROR when checking --torchvision {model_string} --predictor {pred_name} --predictor-version {pred_version}")
logging.error("Meets ERROR when checking --torchvision {model_list} --predictor {pred_name} --predictor-version {pred_version}")

latency_list = parse_latency_info(result.decode('utf-8'))
for model, latency in latency_list:
Expand Down

0 comments on commit bb4cb6c

Please sign in to comment.