Skip to content

Commit

Permalink
Log runtime autotuning timing to scuba (#141919)
Browse files Browse the repository at this point in the history
Summary:
See test plan in internal diff [D66679369](https://our.internmc.facebook.com/intern/diff/D66679369)

X-link: pytorch/pytorch#141919
Approved by: https://github.com/jamesjwu, https://github.com/ezyang

Differential Revision: D67218561

Pulled By: masnesral
  • Loading branch information
masnesral authored and facebook-github-bot committed Dec 16, 2024
1 parent 285fb28 commit 409d280
Showing 1 changed file with 73 additions and 18 deletions.
91 changes: 73 additions & 18 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@
_push_on_torch_function_stack,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.metrics_context import MetricsContext
from torch._guards import Source, TracingContext
from torch._dynamo.metrics_context import MetricsContext, RuntimeMetricsContext
from torch._guards import CompileId, Source, TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import (
log_chromium_event_internal,
Expand Down Expand Up @@ -288,12 +288,17 @@ def print_time_report() -> None:
# ...
#
_METRICS_CONTEXT: MetricsContext
_RUNTIME_METRICS_CONTEXT: RuntimeMetricsContext


def get_metrics_context() -> MetricsContext:
return _METRICS_CONTEXT


def get_runtime_metrics_context() -> RuntimeMetricsContext:
return _RUNTIME_METRICS_CONTEXT


@contextmanager
def dynamo_timed(
key: str,
Expand All @@ -302,16 +307,20 @@ def dynamo_timed(
log_pt2_compile_event: bool = False,
metadata: Optional[Dict[str, object]] = None,
dynamo_compile_column_us: Optional[str] = None,
dynamo_compile_runtime_column_us: Optional[str] = None,
compile_id: Optional[CompileId] = None,
is_forward: Optional[bool] = None,
log_waitcounter: bool = False,
) -> Generator[Any, None, None]:
"""
dynamo_timed is a context manager
By wrapping a function in dynamo_timed, we can get a few things:
1) Log timings to pt2_compile_events.
2) Log timings to CompilationMetrics (dynamo_compile).
3) Chromium events.
4) Storing a record in compilation_time_metrics
1) Optionally log timings to pt2_compile_events.
2) Optionally log timings to CompilationMetrics (dynamo_compile).
3) Optionally log chromium events.
4) Optionally increment a WaitCounter.
5) Store a record in compilation_time_metrics
For example:
def _foo(...):
Expand All @@ -336,12 +345,23 @@ def _foo(...):
- dynamo_compile_column_us: If provided, updates the specified CompilationMetrics
field to be logged to dynamo_compile column. We expect all columns to be _us;
therefore, the field name must end with "_us".
- dynamo_compile_runtime_column_us: Like 'dynamo_compile_column_us', but should
be used for those columns captured outside of a compile context, e.g.,
runtime autotuning.
- compile_id: In the typical case, this parameter should not be needed. Use to
supply the compile_id for those cases where we want to log a compile_id where
it's not naturally available, e.g., for runtime autotuning.
- is_forward: Optionally set an is_forward field for those logging destinations
that support it.
- log_waitcounter: If set, we'll log a waitcounter of the form "pytorch.dynamo_timed.{key}"
"""
# We're standardizing on microseconds for dynamo_compile timings.
if dynamo_compile_column_us is not None:
assert dynamo_compile_column_us.endswith("_us")

# Only one of these should be set.
assert dynamo_compile_column_us is None or dynamo_compile_runtime_column_us is None

if phase_name:
event_name = phase_name
fn_name = key
Expand All @@ -357,11 +377,13 @@ def _foo(...):
event_metadata.update(metadata)
if fn_name:
event_metadata.update({"fn_name": fn_name})
if is_forward is not None:
event_metadata.update({"is_backward": not is_forward})

chromium_log: ChromiumEventLogger = get_chromium_event_logger()
start_ns = time.time_ns()
chromium_log.log_event_start(
event_name, start_ns, event_metadata, log_pt2_compile_event
event_name, start_ns, event_metadata, log_pt2_compile_event, compile_id
)

try:
Expand All @@ -376,7 +398,7 @@ def _foo(...):
time_spent_ns = end_ns - start_ns
compilation_time_metrics[key].append(time_spent_ns / 1e9)
chromium_log.log_event_end(
event_name, end_ns, {}, start_ns, log_pt2_compile_event
event_name, end_ns, {}, start_ns, log_pt2_compile_event, compile_id
)
if dynamo_compile_column_us:
metrics_context = get_metrics_context()
Expand All @@ -391,6 +413,18 @@ def _foo(...):
# this way?
cumulative_time_spent_ns[event_name] += time_spent_ns

if dynamo_compile_runtime_column_us:
get_runtime_metrics_context().increment(
dynamo_compile_runtime_column_us,
time_spent_ns // 1000,
extra={
"compile_id": compile_id,
"is_runtime": True,
"is_forward": is_forward,
},
)
cumulative_time_spent_ns[event_name] += time_spent_ns


@overload
def compile_times(repr: Literal["str"], aggregate: bool = False) -> str:
Expand Down Expand Up @@ -858,7 +892,7 @@ class CompilationMetrics:
inductor_code_gen_cumulative_compile_time_us: Optional[int] = None
triton_compile_time_us: Optional[int] = None
runtime_cudagraphify_time_us: Optional[int] = None # TODO: instrument
runtime_triton_autotune_time_us: Optional[int] = None # TODO: instrument
runtime_triton_autotune_time_us: Optional[int] = None
dynamo_compile_time_before_restart_us: Optional[int] = None
cuda_synchronize_time_us: Optional[int] = None # TODO: instrument
distributed_ephemeral_timeout_us: Optional[int] = None
Expand All @@ -882,6 +916,7 @@ class CompilationMetrics:
triton_version: Optional[str] = None
feature_usage: Optional[dict[str, bool]] = None
compile_time_autotune_time_us: Optional[int] = None
is_runtime: Optional[bool] = False


DEFAULT_COMPILATION_METRICS_LIMIT = 64
Expand Down Expand Up @@ -1022,8 +1057,14 @@ def safe_str(item: Any) -> str:
inductor_fx_remote_cache_backend_type = None
remote_cache_version = None

# Populate the compile_id from the metrics context if it's set. Otherwise
# look for it in the compile context.
compile_id = metrics.get("compile_id")
if not compile_id:
compile_id = torch._guards.CompileContext.current_compile_id()

common_metrics = {
"compile_id": str(torch._guards.CompileContext.current_compile_id()),
"compile_id": str(compile_id) if compile_id else None,
"start_time_us": start_time_ns // 1000,
"end_time_us": end_time_ns // 1000,
"duration_us": (end_time_ns - start_time_ns) // 1000,
Expand Down Expand Up @@ -1066,10 +1107,12 @@ def safe_str(item: Any) -> str:
)
_compilation_metrics.append(compilation_metrics)

if compilation_metrics.is_forward:
name = "compilation_metrics"
else:
name = "bwd_compilation_metrics"
name = "compilation_metrics"
if compilation_metrics.is_forward is False:
name = "bwd_" + name
if compilation_metrics.is_runtime is True:
name = name + "_runtime"

torch._logging.trace_structured(
name,
lambda: {
Expand All @@ -1081,6 +1124,10 @@ def safe_str(item: Any) -> str:
# without making it inconsistent with compilation metrics itself, so
# we ignore the (hopefully small) time spent logging compilation metrics
record_logging_overhead=False,
# These may be runtime logs, e.g., runtime autotuning, so we provide
# the CompileId from the compilation metrics in case it's not available
# in the current trace.
compile_id=compile_id,
)

# If there's a chromium event in flight, add the CompilationMetrics to it.
Expand All @@ -1093,6 +1140,7 @@ def safe_str(item: Any) -> str:

# record_compilation_metrics is called by the singleton MetricsContext exit handler.
_METRICS_CONTEXT = MetricsContext(on_exit=record_compilation_metrics)
_RUNTIME_METRICS_CONTEXT = RuntimeMetricsContext(on_exit=record_compilation_metrics)


def set_compilation_metrics_limit(new_size: int) -> None:
Expand Down Expand Up @@ -1196,15 +1244,18 @@ def log_event_start(
time_ns: int,
metadata: Dict[str, Any],
log_pt2_compile_event: bool = False,
compile_id: Optional[CompileId] = None,
) -> None:
"""
Logs the start of a single event.
:param str event_name Name of event to appear in trace
:param time_ns Timestamp in nanoseconds
:param metadata: Any extra metadata associated with this event
:param log_pt2_compile_event: If True, log to pt2_compile_events
:param compile_id: Explicit compile_id (rather than using the current context)
"""
compile_id = str(torch._guards.CompileContext.current_compile_id())
metadata["compile_id"] = compile_id
compile_id = compile_id or torch._guards.CompileContext.current_compile_id()
metadata["compile_id"] = str(compile_id)
self._log_timed_event(
event_name,
time_ns,
Expand Down Expand Up @@ -1234,16 +1285,20 @@ def log_event_end(
metadata: Dict[str, Any],
start_time_ns: int,
log_pt2_compile_event: bool,
compile_id: Optional[CompileId] = None,
) -> None:
"""
Logs the end of a single event. This function should only be
called after log_event_start with the same event_name.
:param event_name: Name of event to appear in trace
:param time_ns: Timestamp in nanoseconds
:param metadata: Any extra metadata associated with this event
:param start_time_ns: The start time timestamp in nanoseconds
:param log_pt2_compile_event: If True, log to pt2_compile_events
:param compile_id: Explicit compile_id (rather than using the current context)
"""
compile_id = str(torch._guards.CompileContext.current_compile_id())
metadata["compile_id"] = compile_id
compile_id = compile_id or torch._guards.CompileContext.current_compile_id()
metadata["compile_id"] = str(compile_id)

# Grab metadata collected during event span
all_event_data = self.get_event_data()
Expand Down

0 comments on commit 409d280

Please sign in to comment.