Skip to content

Commit

Permalink
Merge branch 'main' into export-D68109702
Browse files Browse the repository at this point in the history
  • Loading branch information
zonglinpeng authored Jan 16, 2025
2 parents 3a1a9f2 + ee00caa commit 229dd31
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions backends/xnnpack/serialization/xnnpack_graph_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@
# LICENSE file in the root directory of this source tree.

import json

import logging
import os
import tempfile

from dataclasses import dataclass, fields, is_dataclass
from typing import ClassVar, Literal
from typing import ClassVar, Literal, Optional

import pkg_resources
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import XNNGraph
from executorch.exir._serialize._dataclass import _DataclassEncoder

from executorch.exir._serialize._flatbuffer import _flatc_compile

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

# Byte order of numbers written to program headers. Always little-endian
# regardless of the host system, since all commonly-used modern CPUs are little
# endian.
Expand Down Expand Up @@ -273,19 +278,42 @@ def _pad_to(data: bytes, length: int) -> bytes:
return data


def pretty_print_xnngraph(xnnpack_graph_json: str):
def pretty_print_xnngraph(xnnpack_graph_json: str, filename: Optional[str] = None):
"""
Pretty print the XNNGraph
Pretty print the XNNGraph, optionally writing to a file if filename is provided
"""
from pprint import pprint
from pprint import pformat

d = json.loads(xnnpack_graph_json)
pprint(d)
pstr = pformat(d, indent=2, compact=True).replace("'", '"')
if filename:
with open(filename, "w") as f:
if filename.endswith(".json"):
pstr = pstr.replace("None", "null")
f.write(pstr)
else: # dump to stdout
print("XNNGraph:")
print(pstr)
print("End of XNNGraph")


# TODO: Replace this with an actual delegate id
_delegate_instance_id = 0


def convert_to_flatbuffer(xnnpack_graph: XNNGraph) -> bytes:
global _delegate_instance_id
sanity_check_xnngraph_dataclass(xnnpack_graph)
xnnpack_graph_json = json.dumps(xnnpack_graph, cls=_DataclassEncoder)

# Log the XNNGraph if debugging
if logger.getEffectiveLevel() == logging.DEBUG:
filename: str = f"./xnnpack_delegate_graph_{_delegate_instance_id}.json"
logger.debug(f"Writing XNNGraph to {filename}")
pretty_print_xnngraph(xnnpack_graph_json, filename)

_delegate_instance_id += 1

with tempfile.TemporaryDirectory() as d:
schema_path = os.path.join(d, "schema.fbs")
with open(schema_path, "wb") as schema_file:
Expand Down

0 comments on commit 229dd31

Please sign in to comment.