From 409d280b3f5387b517ca04ba0ee5be9dbdb7ff7e Mon Sep 17 00:00:00 2001
From: "Sam Larsen (Meta Employee)"
Date: Sun, 15 Dec 2024 17:57:39 -0800
Subject: [PATCH] Log runtime autotuning timing to scuba (#141919)

Summary: See test plan in internal diff [D66679369](https://our.internmc.facebook.com/intern/diff/D66679369)

X-link: https://github.com/pytorch/pytorch/pull/141919
Approved by: https://github.com/jamesjwu, https://github.com/ezyang

Differential Revision: D67218561

Pulled By: masnesral
---
 .../dynamo/dynamobench/_dynamo/utils.py | 91 +++++++++++++++----
 1 file changed, 73 insertions(+), 18 deletions(-)

diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
index 138bd3535..677acf0b6 100644
--- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
+++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
@@ -73,8 +73,8 @@
     _push_on_torch_function_stack,
 )
 from torch._dispatch.python import enable_python_dispatcher
-from torch._dynamo.metrics_context import MetricsContext
-from torch._guards import Source, TracingContext
+from torch._dynamo.metrics_context import MetricsContext, RuntimeMetricsContext
+from torch._guards import CompileId, Source, TracingContext
 from torch._subclasses.meta_utils import is_sparse_compressed
 from torch._utils_internal import (
     log_chromium_event_internal,
@@ -288,12 +288,17 @@ def print_time_report() -> None:
 # ...
 #
 _METRICS_CONTEXT: MetricsContext
+_RUNTIME_METRICS_CONTEXT: RuntimeMetricsContext


 def get_metrics_context() -> MetricsContext:
     return _METRICS_CONTEXT


+def get_runtime_metrics_context() -> RuntimeMetricsContext:
+    return _RUNTIME_METRICS_CONTEXT
+
+
 @contextmanager
 def dynamo_timed(
     key: str,
@@ -302,16 +307,20 @@
     log_pt2_compile_event: bool = False,
     metadata: Optional[Dict[str, object]] = None,
     dynamo_compile_column_us: Optional[str] = None,
+    dynamo_compile_runtime_column_us: Optional[str] = None,
+    compile_id: Optional[CompileId] = None,
+    is_forward: Optional[bool] = None,
     log_waitcounter: bool = False,
 ) -> Generator[Any, None, None]:
     """
     dynamo_timed is a context manager
     By wrapping a function in dynamo_timed, we can get a few things:

-    1) Log timings to pt2_compile_events.
-    2) Log timings to CompilationMetrics (dynamo_compile).
-    3) Chromium events.
-    4) Storing a record in compilation_time_metrics
+    1) Optionally log timings to pt2_compile_events.
+    2) Optionally log timings to CompilationMetrics (dynamo_compile).
+    3) Optionally log chromium events.
+    4) Optionally increment a WaitCounter.
+    5) Store a record in compilation_time_metrics

     For example:
       def _foo(...):
@@ -336,12 +345,23 @@
     - dynamo_compile_column_us: If provided, updates the specified CompilationMetrics
       field to be logged to dyname_compile column. We expect all columns to be _us;
       therefore, the field name must end with "_us".
+    - dynamo_compile_runtime_column_us: Like 'dynamo_compile_column_us', but should
+      be used for those columns captured outside of a compile context, e.g.,
+      runtime autotuning.
+    - compile_id: In the typical case, this parameter should not be needed. Use to
+      supply the compile_id for those cases where we want to log a compile_id where
+      it's not naturally available, e.g., for runtime autotuning.
+    - is_forward: Optionally set an is_forward field for those logging destinations
+      that support it.
     - log_waitcounter: If set, we'll log a waitcounter of the form
       "pytorch.dynamo_timed.{key}"
     """
     # We're standardizing on microseconds for dynamo_compile timings.
     if dynamo_compile_column_us is not None:
         assert dynamo_compile_column_us.endswith("_us")
+    # Only one of these should be set.
+    assert dynamo_compile_column_us is None or dynamo_compile_runtime_column_us is None
+
     if phase_name:
         event_name = phase_name
         fn_name = key
@@ -357,11 +377,13 @@
         event_metadata.update(metadata)
     if fn_name:
         event_metadata.update({"fn_name": fn_name})
+    if is_forward is not None:
+        event_metadata.update({"is_backward": not is_forward})

     chromium_log: ChromiumEventLogger = get_chromium_event_logger()
     start_ns = time.time_ns()
     chromium_log.log_event_start(
-        event_name, start_ns, event_metadata, log_pt2_compile_event
+        event_name, start_ns, event_metadata, log_pt2_compile_event, compile_id
     )

     try:
@@ -376,7 +398,7 @@
         time_spent_ns = end_ns - start_ns
         compilation_time_metrics[key].append(time_spent_ns / 1e9)
         chromium_log.log_event_end(
-            event_name, end_ns, {}, start_ns, log_pt2_compile_event
+            event_name, end_ns, {}, start_ns, log_pt2_compile_event, compile_id
         )
         if dynamo_compile_column_us:
             metrics_context = get_metrics_context()
@@ -391,6 +413,18 @@
             # this way?
             cumulative_time_spent_ns[event_name] += time_spent_ns

+        if dynamo_compile_runtime_column_us:
+            get_runtime_metrics_context().increment(
+                dynamo_compile_runtime_column_us,
+                time_spent_ns // 1000,
+                extra={
+                    "compile_id": compile_id,
+                    "is_runtime": True,
+                    "is_forward": is_forward,
+                },
+            )
+            cumulative_time_spent_ns[event_name] += time_spent_ns
+

 @overload
 def compile_times(repr: Literal["str"], aggregate: bool = False) -> str:
@@ -858,7 +892,7 @@ class CompilationMetrics:
     inductor_code_gen_cumulative_compile_time_us: Optional[int] = None
     triton_compile_time_us: Optional[int] = None
     runtime_cudagraphify_time_us: Optional[int] = None  # TODO: instrument
-    runtime_triton_autotune_time_us: Optional[int] = None  # TODO: instrument
+    runtime_triton_autotune_time_us: Optional[int] = None
     dynamo_compile_time_before_restart_us: Optional[int] = None
     cuda_synchronize_time_us: Optional[int] = None  # TODO: instrument
     distributed_ephemeral_timeout_us: Optional[int] = None
@@ -882,6 +916,7 @@ class CompilationMetrics:
     triton_version: Optional[str] = None
     feature_usage: Optional[dict[str, bool]] = None
     compile_time_autotune_time_us: Optional[int] = None
+    is_runtime: Optional[bool] = False


 DEFAULT_COMPILATION_METRICS_LIMIT = 64
@@ -1022,8 +1057,14 @@ def safe_str(item: Any) -> str:
     inductor_fx_remote_cache_backend_type = None
     remote_cache_version = None

+    # Populate the compile_id from the metrics context if it's set. Otherwise
+    # look for it in the compile context.
+    compile_id = metrics.get("compile_id")
+    if not compile_id:
+        compile_id = torch._guards.CompileContext.current_compile_id()
+
     common_metrics = {
-        "compile_id": str(torch._guards.CompileContext.current_compile_id()),
+        "compile_id": str(compile_id) if compile_id else None,
         "start_time_us": start_time_ns // 1000,
         "end_time_us": end_time_ns // 1000,
         "duration_us": (end_time_ns - start_time_ns) // 1000,
@@ -1066,10 +1107,12 @@ def safe_str(item: Any) -> str:
     )
     _compilation_metrics.append(compilation_metrics)

-    if compilation_metrics.is_forward:
-        name = "compilation_metrics"
-    else:
-        name = "bwd_compilation_metrics"
+    name = "compilation_metrics"
+    if compilation_metrics.is_forward is False:
+        name = "bwd_" + name
+    if compilation_metrics.is_runtime is True:
+        name = name + "_runtime"
+
     torch._logging.trace_structured(
         name,
         lambda: {
@@ -1081,6 +1124,10 @@ def safe_str(item: Any) -> str:
         # without making it inconsistent with compilation metrics itself, so
         # we ignore the (hopefully small) time spent logging compilation metrics
         record_logging_overhead=False,
+        # These may be runtime logs, e.g., runtime autotuning, so we provide
+        # the CompileId from the compilation metrics in case it's not available
+        # in the current trace.
+        compile_id=compile_id,
     )

     # If there's a chromium event in flight, add the CompilationMetrics to it.
@@ -1093,6 +1140,7 @@ def safe_str(item: Any) -> str:

 # record_compilation_metrics is called by the singleton MetricsContext exit handler.
 _METRICS_CONTEXT = MetricsContext(on_exit=record_compilation_metrics)
+_RUNTIME_METRICS_CONTEXT = RuntimeMetricsContext(on_exit=record_compilation_metrics)


 def set_compilation_metrics_limit(new_size: int) -> None:
@@ -1196,15 +1244,18 @@ def log_event_start(
         time_ns: int,
         metadata: Dict[str, Any],
         log_pt2_compile_event: bool = False,
+        compile_id: Optional[CompileId] = None,
     ) -> None:
         """
         Logs the start of a single event.
         :param str event_name Name of event to appear in trace
         :param time_ns Timestamp in nanoseconds
         :param metadata: Any extra metadata associated with this event
+        :param log_pt2_compile_event: If True, log to pt2_compile_events
+        :param compile_id: Explicit compile_id (rather than using the current context)
         """
-        compile_id = str(torch._guards.CompileContext.current_compile_id())
-        metadata["compile_id"] = compile_id
+        compile_id = compile_id or torch._guards.CompileContext.current_compile_id()
+        metadata["compile_id"] = str(compile_id)
         self._log_timed_event(
             event_name,
             time_ns,
@@ -1234,6 +1285,7 @@ def log_event_end(
         metadata: Dict[str, Any],
         start_time_ns: int,
         log_pt2_compile_event: bool,
+        compile_id: Optional[CompileId] = None,
     ) -> None:
         """
         Logs the end of a single event. This function should only be
         :param event_name: Name of event to appear in trace
         :param time_ns: Timestamp in nanoseconds
         :param metadata: Any extra metadata associated with this event
+        :param start_time_ns: The start time timestamp in nanoseconds
+        :param log_pt2_compile_event: If True, log to pt2_compile_events
+        :param compile_id: Explicit compile_id (rather than using the current context)
         """
-        compile_id = str(torch._guards.CompileContext.current_compile_id())
-        metadata["compile_id"] = compile_id
+        compile_id = compile_id or torch._guards.CompileContext.current_compile_id()
+        metadata["compile_id"] = str(compile_id)

         # Grab metadata collected during event span
         all_event_data = self.get_event_data()
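
Note: the hunks above only extend dynamo_timed itself; the call sites that pass the new arguments are not part of this file. The sketch below is a hypothetical illustration of how a runtime-autotuning call site might use the new parameters. The wrapper function and _run_autotune() helper are made up for the example; only dynamo_timed's new signature and the runtime_triton_autotune_time_us column come from the changes above.

from typing import Optional

from torch._dynamo.utils import dynamo_timed
from torch._guards import CompileId


def benchmark_all_configs_timed(compile_id: Optional[CompileId], is_forward: bool) -> None:
    # Runtime autotuning happens outside a compile context, so there is no
    # MetricsContext in progress and no current CompileId to read. The timing
    # therefore goes to a runtime column, and the compile_id is passed in
    # explicitly so the logged row can still be joined with the original compile.
    with dynamo_timed(
        "benchmark_all_configs",
        log_pt2_compile_event=True,
        dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
        compile_id=compile_id,
        is_forward=is_forward,
    ):
        _run_autotune()  # hypothetical placeholder for the actual benchmarking work


def _run_autotune() -> None:
    ...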