Skip to content

Commit

Permalink
Fixed netron visualization (#2567)
Browse files Browse the repository at this point in the history
### Changes

This PR solves #2552 by fixing the Netron visualization. The function
`save_for_netron` now produces an XML file that can be correctly opened
by Netron. To achieve this, a dummy `dtype` conversion has been
introduced, as discussed in #2552. This conversion maps the nncf dtype
`Dtype.FLOAT` to `f32` and `DType.INTEGER` to `i32`.

The `precision` parameter of the class `PortDesc` now is no longer
optional as it's always available and it's necessary to produce a
working XML file.

In addition, I added a docstring for all the functions/classes and
implemented tests for the following methods:

- `get_graph_desc()`
- `PortDesc.as_xml_element()`
- `NodeDesc.as_xml_element()`
- `EdgeDesc.as_xml_element()`
 

### Reason for changes

<!--- Why should the change be applied -->
It was not possible to open XML files produced by the function
`save_for_netron` due to the error: `Error loading OpenVINO model.
Unsupported precision 'undefined'`

### Related tickets

N/A

### Tests

To validate the accuracy of the modifications, I created Netron XML
files from multiple ONNX models. The visualization of these models in
Netron was successful, confirming the effectiveness of the changes.
  • Loading branch information
DaniAffCH authored Mar 13, 2024
1 parent b2a3ccf commit 6f03560
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 7 deletions.
100 changes: 93 additions & 7 deletions nncf/experimental/common/graph/netron.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from nncf.common.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.layer_attributes import Dtype


class Tags:
Expand All @@ -31,17 +32,30 @@ class Tags:


class PortDesc:
def __init__(self, port_id: str, shape: Optional[List[int]] = None, precision: str = None):
"""
Represents a port description in the computational graph.
"""

def __init__(self, port_id: str, precision: str, shape: Optional[List[int]] = None):
"""
:param port_id: The identifier of the port.
:param precision: Precision of the tensor corresponding to the port, either "fp32" or "i32".
:param shape: The shape of the tensor. Defaults to an empty list if not provided.
"""
self.port_id = port_id
if shape is None:
shape = []
self.shape = shape
self.precision = precision

def as_xml_element(self) -> ET.Element:
port = ET.Element(Tags.PORT, id=self.port_id)
if self.precision:
port.set("precision", self.precision)
"""
Converts the PortDesc object into an XML element.
:return: The Element representing the port in XML
"""
port = ET.Element(Tags.PORT, id=self.port_id, precision=self.precision)

for i in self.shape:
dim = ET.Element(Tags.DIM)
dim.text = str(i)
Expand All @@ -50,6 +64,10 @@ def as_xml_element(self) -> ET.Element:


class NodeDesc:
"""
Represents a node description in the computational graph.
"""

def __init__(
self,
node_id: str,
Expand All @@ -59,6 +77,14 @@ def __init__(
inputs: Optional[List[PortDesc]] = None,
outputs: Optional[List[PortDesc]] = None,
):
"""
:param node_id: The identifier of the node.
:param name: The name of the node.
:param type: The type of the node.
:param attrs: Additional attributes of the node. Default empty dictionary.
:param inputs: List of input ports for the node.
:param outputs: List of output ports for the node.
"""
self.node_id = node_id
self.name = name
self.type = node_type
Expand All @@ -69,6 +95,11 @@ def __init__(
self.outputs = outputs

def as_xml_element(self) -> ET.Element:
"""
Converts the NodeDesc object into an XML element.
:return: The Element representing the node in XML
"""
node = ET.Element(Tags.NODE, id=self.node_id, name=self.name, type=self.type)
ET.SubElement(node, Tags.DATA, self.attrs)

Expand All @@ -86,13 +117,28 @@ def as_xml_element(self) -> ET.Element:


class EdgeDesc:
"""
Represents an edge description in the computational graph.
"""

def __init__(self, from_node: str, from_port: str, to_node: str, to_port: str):
"""
:param from_node: The identifier of the source node.
:param from_port: The identifier of the output port of the source node.
:param to_node: The identifier of the target node.
:param to_port: The identifier of the input port of the target node.
"""
self.from_node = from_node
self.from_port = from_port
self.to_node = to_node
self.to_port = to_port

def as_xml_element(self) -> ET.Element:
"""
Converts the EdgeDesc object into an XML element.
:return: The Element representing the edge in XML
"""
attrs = {
"from-layer": self.from_node,
"from-port": self.from_port,
Expand All @@ -106,10 +152,33 @@ def as_xml_element(self) -> ET.Element:
GET_ATTRIBUTES_FN_TYPE = Callable[[NNCFNode], Dict[str, str]]


# TODO(andrey-churkin): Add support for `PortDesc.precision` param.
def convert_nncf_dtype_to_ov_dtype(dtype: Dtype) -> str:
"""
Converts a nncf dtype to an openvino dtype string.
:param dtype: The data type to be converted.
:return: The openvino dtype string corresponding to the given data type.
"""

dummy_precision_map: Dict[Dtype, str] = {Dtype.INTEGER: "i32", Dtype.FLOAT: "f32"}

return dummy_precision_map[dtype]


def get_graph_desc(
graph: NNCFGraph, include_fq_params: bool = False, get_attributes_fn: Optional[GET_ATTRIBUTES_FN_TYPE] = None
) -> Tuple[List[NodeDesc], List[EdgeDesc]]:
"""
Retrieves descriptions of nodes and edges from an NNCFGraph.
:param graph: The NNCFGraph instance to extract descriptions from.
:param include_fq_params: Whether to include FakeQuantize parameters in the description.
:param get_attributes_fn: A function to retrieve additional attributes for nodes.
Defaults to a function returning {"metatype": str(x.metatype.name)}.
:return: A tuple containing lists of NodeDesc and EdgeDesc objects
representing the nodes and edges of the NNCFGraph.
"""

if get_attributes_fn is None:
get_attributes_fn = lambda x: {
"metatype": str(x.metatype.name),
Expand Down Expand Up @@ -139,14 +208,20 @@ def get_graph_desc(
for edge in graph.get_input_edges(node):
if not include_fq_params and node.node_type == "FakeQuantize" and edge.input_port_id != 0:
continue

inputs.append(PortDesc(port_id=str(edge.input_port_id), shape=edge.tensor_shape))
inputs.append(
PortDesc(
port_id=str(edge.input_port_id),
precision=convert_nncf_dtype_to_ov_dtype(edge.dtype),
shape=edge.tensor_shape,
)
)

outputs = []
for edge in graph.get_output_edges(node):
outputs.append(
PortDesc(
port_id=str(edge.output_port_id),
precision=convert_nncf_dtype_to_ov_dtype(edge.dtype),
shape=edge.tensor_shape,
)
)
Expand All @@ -172,6 +247,17 @@ def save_for_netron(
include_fq_params: bool = False,
get_attributes_fn: Optional[GET_ATTRIBUTES_FN_TYPE] = None,
):
"""
Save the NNCFGraph information in an XML file suitable for visualization with Netron.
:param graph: The NNCFGraph instance to visualize.
:param save_path: The path to save the Netron-compatible file.
:param graph_name: The name of the graph. Defaults to "Graph".
:param include_fq_params: Whether to include FakeQuantize parameters in the visualization.
:param get_attributes_fn: A function to retrieve additional attributes for nodes.
Defaults to a function returning {"metatype": str(x.metatype.name)}.
"""

node_descs, edge_descs = get_graph_desc(graph, include_fq_params, get_attributes_fn)

net = ET.Element(Tags.NET, name=graph_name)
Expand Down
125 changes: 125 additions & 0 deletions tests/common/experimental/test_netron.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import xml.etree.ElementTree as ET # nosec
from dataclasses import dataclass
from typing import Optional

import pytest

from nncf.experimental.common.graph.netron import GET_ATTRIBUTES_FN_TYPE
from nncf.experimental.common.graph.netron import EdgeDesc
from nncf.experimental.common.graph.netron import NodeDesc
from nncf.experimental.common.graph.netron import PortDesc
from nncf.experimental.common.graph.netron import Tags
from nncf.experimental.common.graph.netron import convert_nncf_dtype_to_ov_dtype
from nncf.experimental.common.graph.netron import get_graph_desc
from tests.common.quantization.mock_graphs import get_two_branch_mock_model_graph


@dataclass
class GraphDescTestCase:
include_fq_params: Optional[bool]
get_attributes_fn: GET_ATTRIBUTES_FN_TYPE


GRAPH_DESC_TEST_CASES = [
GraphDescTestCase(include_fq_params=False, get_attributes_fn=None),
GraphDescTestCase(include_fq_params=True, get_attributes_fn=None),
GraphDescTestCase(include_fq_params=True, get_attributes_fn=lambda x: {"name": x.node_name, "type": x.node_type}),
]


@pytest.mark.parametrize(
"graph_desc_test_case",
GRAPH_DESC_TEST_CASES,
)
def test_get_graph_desc(graph_desc_test_case: GraphDescTestCase):
include_fq_params = graph_desc_test_case.include_fq_params
get_attributes_fn = graph_desc_test_case.get_attributes_fn

nncf_graph = get_two_branch_mock_model_graph()

edges = list(nncf_graph.get_all_edges())
nodes = list(nncf_graph.get_all_nodes())

node_desc_list, edges_desc_list = get_graph_desc(nncf_graph, include_fq_params, get_attributes_fn)

assert all(isinstance(node_desc, NodeDesc) for node_desc in node_desc_list)
assert all(isinstance(edge_desc, EdgeDesc) for edge_desc in edges_desc_list)

assert len(node_desc_list) == len(nodes)
assert len(edges_desc_list) == len(edges)

if get_attributes_fn is not None:
assert all([node_desc.attrs == get_attributes_fn(node) for node, node_desc in zip(nodes, node_desc_list)])


def test_edge_desc():
nncf_graph = get_two_branch_mock_model_graph()

for edge in nncf_graph.get_all_edges():
edgeDesc = EdgeDesc(
from_node=str(edge.from_node.node_id),
from_port=str(edge.output_port_id),
to_node=str(edge.to_node.node_id),
to_port=str(edge.input_port_id),
)

xmlElement = edgeDesc.as_xml_element()

assert isinstance(xmlElement, ET.Element)
assert xmlElement.tag == Tags.EDGE
assert xmlElement.attrib["from-layer"] == str(edge.from_node.node_id)
assert xmlElement.attrib["from-port"] == str(edge.output_port_id)
assert xmlElement.attrib["to-layer"] == str(edge.to_node.node_id)
assert xmlElement.attrib["to-port"] == str(edge.input_port_id)


def test_node_desc():
nncf_graph = get_two_branch_mock_model_graph()

for node in nncf_graph.get_all_nodes():
nodeDesc = NodeDesc(
node_id=str(node.node_id),
name=node.node_name,
node_type=node.node_type.title(),
)

xmlElement = nodeDesc.as_xml_element()

assert isinstance(xmlElement, ET.Element)
assert xmlElement.tag == Tags.NODE
assert xmlElement.attrib["id"] == str(node.node_id)
assert xmlElement.attrib["name"] == node.node_name
assert xmlElement.attrib["type"] == node.node_type.title()
assert all([child.tag == Tags.DATA for child in xmlElement])


def test_port_desc():
nncf_graph = get_two_branch_mock_model_graph()

for edge in nncf_graph.get_all_edges():
portDesc = PortDesc(
port_id=str(edge.input_port_id),
precision=convert_nncf_dtype_to_ov_dtype(edge.dtype),
shape=edge.tensor_shape,
)

xmlElement = portDesc.as_xml_element()

assert xmlElement.tag == Tags.PORT
assert xmlElement.attrib["id"] == str(edge.input_port_id)
assert xmlElement.attrib["precision"] == convert_nncf_dtype_to_ov_dtype(edge.dtype)
assert all([child.tag == Tags.DIM for child in xmlElement])
assert all(
[str(edge_shape) == port_shape.text for edge_shape, port_shape in zip(edge.tensor_shape, xmlElement)]
)

0 comments on commit 6f03560

Please sign in to comment.