Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to support garbage collection after torch compilation #2559

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 74 additions & 18 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@
_push_on_torch_function_stack,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.metrics_context import MetricsContext
from torch._guards import Source, TracingContext
from torch._dynamo.metrics_context import MetricsContext, RuntimeMetricsContext
from torch._guards import CompileId, Source, TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import (
log_chromium_event_internal,
Expand Down Expand Up @@ -288,12 +288,17 @@ def print_time_report() -> None:
# ...
#
_METRICS_CONTEXT: MetricsContext
_RUNTIME_METRICS_CONTEXT: RuntimeMetricsContext


def get_metrics_context() -> MetricsContext:
    """Return the module-level singleton ``MetricsContext``.

    This is the context used to accumulate compile-time metrics; it is
    created once at module scope (as ``_METRICS_CONTEXT``) with
    ``record_compilation_metrics`` as its exit handler.
    """
    return _METRICS_CONTEXT


def get_runtime_metrics_context() -> RuntimeMetricsContext:
    """Return the module-level singleton ``RuntimeMetricsContext``.

    Unlike :func:`get_metrics_context`, this context is used for metrics
    captured outside of a compile context (e.g., runtime autotuning); it is
    created once at module scope (as ``_RUNTIME_METRICS_CONTEXT``) with the
    same ``record_compilation_metrics`` exit handler.
    """
    return _RUNTIME_METRICS_CONTEXT


@contextmanager
def dynamo_timed(
key: str,
Expand All @@ -302,16 +307,20 @@ def dynamo_timed(
log_pt2_compile_event: bool = False,
metadata: Optional[Dict[str, object]] = None,
dynamo_compile_column_us: Optional[str] = None,
dynamo_compile_runtime_column_us: Optional[str] = None,
compile_id: Optional[CompileId] = None,
is_forward: Optional[bool] = None,
log_waitcounter: bool = False,
) -> Generator[Any, None, None]:
"""
dynamo_timed is a context manager
By wrapping a function in dynamo_timed, we can get a few things:

1) Log timings to pt2_compile_events.
2) Log timings to CompilationMetrics (dynamo_compile).
3) Chromium events.
4) Storing a record in compilation_time_metrics
1) Optionally log timings to pt2_compile_events.
2) Optionally log timings to CompilationMetrics (dynamo_compile).
3) Optionally log chromium events.
4) Optionally increment a WaitCounter.
5) Store a record in compilation_time_metrics
For example:

def _foo(...):
Expand All @@ -336,12 +345,23 @@ def _foo(...):
- dynamo_compile_column_us: If provided, updates the specified CompilationMetrics
field to be logged to the dynamo_compile column. We expect all columns to be _us;
therefore, the field name must end with "_us".
- dynamo_compile_runtime_column_us: Like 'dynamo_compile_column_us', but should
be used for those columns captured outside of a compile context, e.g.,
runtime autotuning.
- compile_id: In the typical case, this parameter should not be needed. Use to
supply the compile_id for those cases where we want to log a compile_id where
it's not naturally available, e.g., for runtime autotuning.
- is_forward: Optionally set an is_forward field for those logging destinations
that support it.
- log_waitcounter: If set, we'll log a waitcounter of the form "pytorch.dynamo_timed.{key}"
"""
# We're standardizing on microseconds for dynamo_compile timings.
if dynamo_compile_column_us is not None:
assert dynamo_compile_column_us.endswith("_us")

# Only one of these should be set.
assert dynamo_compile_column_us is None or dynamo_compile_runtime_column_us is None

if phase_name:
event_name = phase_name
fn_name = key
Expand All @@ -357,11 +377,13 @@ def _foo(...):
event_metadata.update(metadata)
if fn_name:
event_metadata.update({"fn_name": fn_name})
if is_forward is not None:
event_metadata.update({"is_backward": not is_forward})

chromium_log: ChromiumEventLogger = get_chromium_event_logger()
start_ns = time.time_ns()
chromium_log.log_event_start(
event_name, start_ns, event_metadata, log_pt2_compile_event
event_name, start_ns, event_metadata, log_pt2_compile_event, compile_id
)

try:
Expand All @@ -376,7 +398,7 @@ def _foo(...):
time_spent_ns = end_ns - start_ns
compilation_time_metrics[key].append(time_spent_ns / 1e9)
chromium_log.log_event_end(
event_name, end_ns, {}, start_ns, log_pt2_compile_event
event_name, end_ns, {}, start_ns, log_pt2_compile_event, compile_id
)
if dynamo_compile_column_us:
metrics_context = get_metrics_context()
Expand All @@ -391,6 +413,18 @@ def _foo(...):
# this way?
cumulative_time_spent_ns[event_name] += time_spent_ns

if dynamo_compile_runtime_column_us:
get_runtime_metrics_context().increment(
dynamo_compile_runtime_column_us,
time_spent_ns // 1000,
extra={
"compile_id": compile_id,
"is_runtime": True,
"is_forward": is_forward,
},
)
cumulative_time_spent_ns[event_name] += time_spent_ns


@overload
def compile_times(repr: Literal["str"], aggregate: bool = False) -> str:
Expand Down Expand Up @@ -858,7 +892,7 @@ class CompilationMetrics:
inductor_code_gen_cumulative_compile_time_us: Optional[int] = None
triton_compile_time_us: Optional[int] = None
runtime_cudagraphify_time_us: Optional[int] = None # TODO: instrument
runtime_triton_autotune_time_us: Optional[int] = None # TODO: instrument
runtime_triton_autotune_time_us: Optional[int] = None
dynamo_compile_time_before_restart_us: Optional[int] = None
cuda_synchronize_time_us: Optional[int] = None # TODO: instrument
distributed_ephemeral_timeout_us: Optional[int] = None
Expand All @@ -882,6 +916,8 @@ class CompilationMetrics:
triton_version: Optional[str] = None
feature_usage: Optional[dict[str, bool]] = None
compile_time_autotune_time_us: Optional[int] = None
is_runtime: Optional[bool] = False
gc_time_us: Optional[int] = None


DEFAULT_COMPILATION_METRICS_LIMIT = 64
Expand Down Expand Up @@ -1022,8 +1058,14 @@ def safe_str(item: Any) -> str:
inductor_fx_remote_cache_backend_type = None
remote_cache_version = None

# Populate the compile_id from the metrics context if it's set. Otherwise
# look for it in the compile context.
compile_id = metrics.get("compile_id")
if not compile_id:
compile_id = torch._guards.CompileContext.current_compile_id()

common_metrics = {
"compile_id": str(torch._guards.CompileContext.current_compile_id()),
"compile_id": str(compile_id) if compile_id else None,
"start_time_us": start_time_ns // 1000,
"end_time_us": end_time_ns // 1000,
"duration_us": (end_time_ns - start_time_ns) // 1000,
Expand Down Expand Up @@ -1066,10 +1108,12 @@ def safe_str(item: Any) -> str:
)
_compilation_metrics.append(compilation_metrics)

if compilation_metrics.is_forward:
name = "compilation_metrics"
else:
name = "bwd_compilation_metrics"
name = "compilation_metrics"
if compilation_metrics.is_forward is False:
name = "bwd_" + name
if compilation_metrics.is_runtime is True:
name = name + "_runtime"

torch._logging.trace_structured(
name,
lambda: {
Expand All @@ -1081,6 +1125,10 @@ def safe_str(item: Any) -> str:
# without making it inconsistent with compilation metrics itself, so
# we ignore the (hopefully small) time spent logging compilation metrics
record_logging_overhead=False,
# These may be runtime logs, e.g., runtime autotuning, so we provide
# the CompileId from the compilation metrics in case it's not available
# in the current trace.
compile_id=compile_id,
)

# If there's a chromium event in flight, add the CompilationMetrics to it.
Expand All @@ -1093,6 +1141,7 @@ def safe_str(item: Any) -> str:

# record_compilation_metrics is called by the singleton MetricsContext exit handler.
_METRICS_CONTEXT = MetricsContext(on_exit=record_compilation_metrics)
_RUNTIME_METRICS_CONTEXT = RuntimeMetricsContext(on_exit=record_compilation_metrics)


def set_compilation_metrics_limit(new_size: int) -> None:
Expand Down Expand Up @@ -1196,15 +1245,18 @@ def log_event_start(
time_ns: int,
metadata: Dict[str, Any],
log_pt2_compile_event: bool = False,
compile_id: Optional[CompileId] = None,
) -> None:
"""
Logs the start of a single event.
:param str event_name Name of event to appear in trace
:param time_ns Timestamp in nanoseconds
:param metadata: Any extra metadata associated with this event
:param log_pt2_compile_event: If True, log to pt2_compile_events
:param compile_id: Explicit compile_id (rather than using the current context)
"""
compile_id = str(torch._guards.CompileContext.current_compile_id())
metadata["compile_id"] = compile_id
compile_id = compile_id or torch._guards.CompileContext.current_compile_id()
metadata["compile_id"] = str(compile_id)
self._log_timed_event(
event_name,
time_ns,
Expand Down Expand Up @@ -1234,16 +1286,20 @@ def log_event_end(
metadata: Dict[str, Any],
start_time_ns: int,
log_pt2_compile_event: bool,
compile_id: Optional[CompileId] = None,
) -> None:
"""
Logs the end of a single event. This function should only be
called after log_event_start with the same event_name.
:param event_name: Name of event to appear in trace
:param time_ns: Timestamp in nanoseconds
:param metadata: Any extra metadata associated with this event
:param start_time_ns: The start time timestamp in nanoseconds
:param log_pt2_compile_event: If True, log to pt2_compile_events
:param compile_id: Explicit compile_id (rather than using the current context)
"""
compile_id = str(torch._guards.CompileContext.current_compile_id())
metadata["compile_id"] = compile_id
compile_id = compile_id or torch._guards.CompileContext.current_compile_id()
metadata["compile_id"] = str(compile_id)

# Grab metadata collected during event span
all_event_data = self.get_event_data()
Expand Down
Loading