Skip to content

Commit

Permalink
[devtools/visualization] Add visualize_graph
Browse files Browse the repository at this point in the history
When working with passes, you might have access to
a modified graph_module rather than an exported_program.
visualize_graph allows visualization of this graph_module
by combining  the modified graph_module with an exported_program.
Note that the graph_module can't be set directly, a new
exported_program needs to be constructed.

Additionally, we disable the operator validation for the
newly constructed ExportedProgram. This is ok since
it is only used for visualization.

Signed-off-by: Erik Lundell <[email protected]>
Change-Id: I4fad809bf094a1ec70e25534cc0858f9d8d3d225
  • Loading branch information
Erik-Lundell committed Jan 17, 2025
1 parent fc6b83e commit 1205e3a
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 14 deletions.
1 change: 1 addition & 0 deletions devtools/visualization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
ModelExplorerServer,
SingletonModelExplorerServer,
visualize,
visualize_graph,
)
32 changes: 31 additions & 1 deletion devtools/visualization/visualization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@

import subprocess
import time
from typing import Any, Callable, Type

from executorch.exir import EdgeProgramManager, ExecutorchProgramManager
from executorch.exir.program._program import _update_exported_program_graph_module
from torch._export.verifier import Verifier
from torch.export.exported_program import ExportedProgram
from torch.fx import GraphModule

try:
from model_explorer import config, consts, visualize_from_config # type: ignore
Expand All @@ -27,7 +31,7 @@ class SingletonModelExplorerServer:

server: None | subprocess.Popen = None
num_open: int = 0
wait_after_start = 2.0
wait_after_start = 3.0

def __init__(self, open_in_browser: bool = True, port: int | None = None):
if SingletonModelExplorerServer.server is None:
Expand Down Expand Up @@ -124,3 +128,29 @@ def visualize(
no_open_in_browser=no_open_in_browser,
**kwargs,
)


def visualize_graph(
graph_module: GraphModule,
exported_program: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
reuse_server: bool = True,
no_open_in_browser: bool = False,
**kwargs,
):
"""Overrides the graph_module of the supplied exported_program with 'graph_module' before visualizing.
Also disables validating operators to allow visualizing graphs containing custom ops.
A typical example is after running passes, which returns a graph_module rather than an ExportedProgram.
"""

class _any_op(Verifier):
dialect = "ANY_OP"

def allowed_op_types(self) -> tuple[Type[Any], ...]:
return (Callable,) # type: ignore

exported_program = _get_exported_program(exported_program)
exported_program = _update_exported_program_graph_module(
exported_program, graph_module, override_verifiers=[_any_op]
)
visualize(exported_program, reuse_server, no_open_in_browser, **kwargs)
17 changes: 16 additions & 1 deletion devtools/visualization/visualization_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@

import pytest
import torch
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
from executorch.backends.xnnpack.test.tester import Tester

from executorch.devtools.visualization import (
ModelExplorerServer,
SingletonModelExplorerServer,
visualization_utils,
visualize,
visualize_graph,
)
from executorch.exir import ExportedProgram
from executorch.exir import ExportedProgram, to_edge_transform_and_lower

try:
from model_explorer.config import ModelExplorerConfig # type: ignore
Expand Down Expand Up @@ -145,6 +147,17 @@ def test_visualize_to_executorch(server):
)


def test_visualize_graph(server):
with server():
model = Linear(20, 30)
exported_program = torch.export.export(model, model.get_inputs())
exported_program = to_edge_transform_and_lower(
exported_program
).exported_program()
modified_gm = DecomposeLinearPass()(exported_program.graph_module).graph_module
visualize_graph(modified_gm, exported_program)


if __name__ == "__main__":
"""A test to run locally to make sure that the web browser opens up
automatically as intended.
Expand All @@ -158,3 +171,5 @@ def test_visualize_to_executorch(server):
test_visualize_to_edge(SingletonModelExplorerServer)
test_visualize_partition(SingletonModelExplorerServer)
test_visualize_to_executorch(SingletonModelExplorerServer)
test_visualize_graph(SingletonModelExplorerServer)
time.sleep(3.0)
34 changes: 22 additions & 12 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -9,7 +10,7 @@
import copy
import io
import logging
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Type, Union

import torch
import torch._export
Expand Down Expand Up @@ -60,6 +61,7 @@
get_aten_verifier,
)
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
from torch._export.verifier import Verifier
from torch.export import ExportedProgram
from torch.export._remove_auto_functionalized_pass import (
unsafe_remove_auto_functionalized_pass,
Expand Down Expand Up @@ -207,21 +209,29 @@ def _transform(self, *passes: PassType) -> "ExportedProgram":
if transformed_gm is self.graph_module and not res.modified:
return self

return _update_exported_program_graph_module(self, transformed_gm)


def _update_exported_program_graph_module(
exported_program: ExportedProgram,
gm: torch.fx.GraphModule,
override_verifiers: None | list[Type[Verifier]] = None,
) -> "ExportedProgram":
transformed_ep = ExportedProgram(
root=transformed_gm,
graph=transformed_gm.graph,
root=gm,
graph=gm.graph,
graph_signature=_get_updated_graph_signature(
self.graph_signature, transformed_gm
exported_program.graph_signature, gm
),
state_dict=self.state_dict,
range_constraints=_get_updated_range_constraints(transformed_gm),
module_call_graph=copy.deepcopy(self._module_call_graph),
example_inputs=self.example_inputs,
constants=self.constants,
verifiers=[self.verifier],
state_dict=exported_program.state_dict,
range_constraints=_get_updated_range_constraints(gm),
module_call_graph=copy.deepcopy(exported_program._module_call_graph),
example_inputs=exported_program.example_inputs,
constants=exported_program.constants,
verifiers=override_verifiers or [exported_program.verifier],
)
transformed_ep.graph_module.meta.update(self.graph_module.meta)
transformed_ep.graph_module.meta.update(res.graph_module.meta)
transformed_ep.graph_module.meta.update(exported_program.graph_module.meta)
transformed_ep.graph_module.meta.update(gm.meta)
return transformed_ep


Expand Down

0 comments on commit 1205e3a

Please sign in to comment.