diff --git a/backends/xnnpack/serialization/xnnpack_graph_serialize.py b/backends/xnnpack/serialization/xnnpack_graph_serialize.py index 160c926780..0fbd0ddc5e 100644 --- a/backends/xnnpack/serialization/xnnpack_graph_serialize.py +++ b/backends/xnnpack/serialization/xnnpack_graph_serialize.py @@ -5,11 +5,13 @@ # LICENSE file in the root directory of this source tree. import json + +import logging import os import tempfile from dataclasses import dataclass, fields, is_dataclass -from typing import ClassVar, Literal +from typing import ClassVar, Literal, Optional import pkg_resources from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNGraph @@ -17,6 +19,9 @@ from executorch.exir._serialize._flatbuffer import _flatc_compile +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + # Byte order of numbers written to program headers. Always little-endian # regardless of the host system, since all commonly-used modern CPUs are little # endian. @@ -273,19 +278,42 @@ def _pad_to(data: bytes, length: int) -> bytes: return data -def pretty_print_xnngraph(xnnpack_graph_json: str): +def pretty_print_xnngraph(xnnpack_graph_json: str, filename: Optional[str] = None): """ - Pretty print the XNNGraph + Pretty print the XNNGraph, optionally writing to a file if filename is provided """ - from pprint import pprint + from pprint import pformat d = json.loads(xnnpack_graph_json) - pprint(d) + pstr = pformat(d, indent=2, compact=True).replace("'", '"') + if filename: + with open(filename, "w") as f: + if filename.endswith(".json"): + pstr = pstr.replace("None", "null") + f.write(pstr) + else: # dump to stdout + print("XNNGraph:") + print(pstr) + print("End of XNNGraph") + + +# TODO: Replace this with an actual delegate id +_delegate_instance_id = 0 def convert_to_flatbuffer(xnnpack_graph: XNNGraph) -> bytes: + global _delegate_instance_id sanity_check_xnngraph_dataclass(xnnpack_graph) xnnpack_graph_json = json.dumps(xnnpack_graph, cls=_DataclassEncoder) + + # Log the XNNGraph if debugging + if logger.getEffectiveLevel() == logging.DEBUG: + filename: str = f"./xnnpack_delegate_graph_{_delegate_instance_id}.json" + logger.debug(f"Writing XNNGraph to {filename}") + pretty_print_xnngraph(xnnpack_graph_json, filename) + + _delegate_instance_id += 1 + with tempfile.TemporaryDirectory() as d: schema_path = os.path.join(d, "schema.fbs") with open(schema_path, "wb") as schema_file: