Skip to content

Commit

Permalink
timeline_visualizer: Resolve flake8 errors
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed Feb 6, 2024
1 parent 2f9f48b commit f3769de
Showing 1 changed file with 32 additions and 73 deletions.
105 changes: 32 additions & 73 deletions timeline_visualizer/timeline_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class TID(IntEnum):

def get_logger(log_filename: str) -> logging.Logger:
formatter = logging.Formatter(
"%(levelname)s [%(asctime)s] %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p")
"%(levelname)s [%(asctime)s] %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p")

file_handler = FileHandler(log_filename, mode="w")
file_handler.setLevel(logging.DEBUG)
Expand All @@ -36,21 +36,17 @@ def get_logger(log_filename: str) -> logging.Logger:
return logger

def is_local_mem_node(node_name: str) -> bool:
if ("MEM_LOAD_NODE" in node_name)\
and ("LOCAL_MEMORY" in node_name):
if ("MEM_LOAD_NODE" in node_name) and ("LOCAL_MEMORY" in node_name):
return True
elif ("MEM_STORE_NODE" in node_name)\
and ("LOCAL_MEMORY" in node_name):
elif ("MEM_STORE_NODE" in node_name) and ("LOCAL_MEMORY" in node_name):
return True
else:
return False

def is_remote_mem_node(node_name: str) -> bool:
if ("MEM_LOAD_NODE" in node_name)\
and ("REMOTE_MEMORY" in node_name):
if ("MEM_LOAD_NODE" in node_name) and ("REMOTE_MEMORY" in node_name):
return True
elif ("MEM_STORE_NODE" in node_name)\
and ("REMOTE_MEMORY" in node_name):
elif ("MEM_STORE_NODE" in node_name) and ("REMOTE_MEMORY" in node_name):
return True
else:
return False
Expand Down Expand Up @@ -92,8 +88,8 @@ def parse_event(
node_id = int(cols[3].split("=")[1])
node_name = cols[4].split("=")[1]
return (trace_type, npu_id, curr_cycle, node_id, node_name)
except:
raise ValueError(f"Cannot parse the following event -- \"{line}\"")
except Exception as e:
raise ValueError(f"Cannot parse the following event -- \"{line}\": {e}")

def get_trace_events(
input_filename: str,
Expand All @@ -106,12 +102,10 @@ def get_trace_events(
with open(input_filename, "r") as f:
for line in f:
if ("issue" in line) or ("callback" in line):
(trace_type, npu_id, curr_cycle, node_id, node_name) =\
parse_event(line)
(trace_type, npu_id, curr_cycle, node_id, node_name) = parse_event(line)

if trace_type == "issue":
trace_dict[npu_id].update(
{node_id: [node_name, curr_cycle]})
trace_dict[npu_id].update({node_id: [node_name, curr_cycle]})
elif trace_type == "callback":
node_name = trace_dict[npu_id][node_id][0]
tid = get_tid(node_name)
Expand All @@ -120,24 +114,22 @@ def get_trace_events(
duration_in_cycles = curr_cycle - issued_cycle
duration_in_ms = duration_in_cycles / (npu_frequency * 1_000)

trace_events.append(
{
"pid": npu_id,
"tid": tid,
"ts": issued_ms,
"dur": duration_in_ms,
"ph": "X",
"name": node_name,
"args": {"ms": duration_in_ms}
})
trace_events.append({
"pid": npu_id,
"tid": tid,
"ts": issued_ms,
"dur": duration_in_ms,
"ph": "X",
"name": node_name,
"args": {"ms": duration_in_ms}
})

del trace_dict[npu_id][node_id]
else:
raise ValueError(f"Unsupported trace_type, {trace_type}")

return trace_events


def write_trace_events(
output_filename: str,
num_npus: int,
Expand All @@ -146,64 +138,31 @@ def write_trace_events(
output_dict = {
"meta_user": "aras",
"traceEvents": trace_events,
"meta_user": "aras",
"meta_cpu_count": num_npus
}
with open(output_filename, "w") as f:
json.dump(output_dict, f)

def main() -> None:
parser = argparse.ArgumentParser(
description="Timeline Visualizer"
)
parser.add_argument(
"--input_filename",
type=str,
default=None,
required=True,
help="Input timeline filename"
)
parser.add_argument(
"--output_filename",
type=str,
default=None,
required=True,
help="Output trace filename"
)
parser.add_argument(
"--num_npus",
type=int,
default=None,
required=True,
help="Number of NPUs in a system"
)
parser.add_argument(
"--npu_frequency",
type=int,
default=None,
required=True,
help="NPU frequency in MHz"
)
parser.add_argument(
"--log_filename",
type=str,
default="debug.log",
help="Log filename"
)
parser = argparse.ArgumentParser(description="Timeline Visualizer")
parser.add_argument("--input_filename", type=str, default=None, required=True,
help="Input timeline filename")
parser.add_argument("--output_filename", type=str, default=None, required=True,
help="Output trace filename")
parser.add_argument("--num_npus", type=int, default=None, required=True,
help="Number of NPUs in a system")
parser.add_argument("--npu_frequency", type=int, default=None, required=True,
help="NPU frequency in MHz")
parser.add_argument("--log_filename", type=str, default="debug.log",
help="Log filename")
args = parser.parse_args()

logger = get_logger(args.log_filename)
logger.debug(" ".join(sys.argv))

try:
trace_events = get_trace_events(
args.input_filename,
args.num_npus,
args.npu_frequency)
write_trace_events(
args.output_filename,
args.num_npus,
trace_events)
trace_events = get_trace_events(args.input_filename, args.num_npus, args.npu_frequency)
write_trace_events(args.output_filename, args.num_npus, trace_events)
except Exception as e:
logger.error(str(e))
sys.exit(1)
Expand Down

0 comments on commit f3769de

Please sign in to comment.