Skip to content

Commit

Permalink
rope_benchmark (#3550)
Browse files Browse the repository at this point in the history
1. Adding pytest test_rope benchmark. Running rope variations including:
  - llama_2_7b_hf_rope
  - llama_3_8B_rope
  - hf_qwen2_rope
  - hf_phi3_rope
  - hf_mistral_nemo_rope

2. Adding thunder-torchcompile as an executor
3. Adding kwargs with executors, which allow us to pass `nv_enable_matmul = True` to thunder.jit
4. Fixing `set_metrics` constructing Iterable to avoid exception throw in bwd, where no output is returned.
  • Loading branch information
jjsjann123 authored Jan 14, 2025
1 parent 41326a3 commit c8817e0
Show file tree
Hide file tree
Showing 5 changed files with 1,091 additions and 178 deletions.
9 changes: 7 additions & 2 deletions benchmarks/python/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def pytest_addoption(parser):
action="store_true",
help="Benchmarks torch.compile mode.",
)
parser.addoption(
"--benchmark-thunder-torchcompile",
action="store_true",
help="Benchmarks torch.compile mode.",
)

# pytest-benchmark does not have CLI options to set rounds/warmup_rounds for benchmark.pedantic.
# The following two options are used to overwrite the default values through CLI.
Expand Down Expand Up @@ -104,14 +109,14 @@ def pytest_collection_modifyitems(session, config, items):

from nvfuser.pytorch_utils import retry_on_oom_or_skip_test

executors = ["eager", "torchcompile", "thunder"]
executors = ["eager", "torchcompile", "thunder", "thunder-torchcompile"]

def get_test_executor(item) -> str | None:
if hasattr(item, "callspec") and "executor" in item.callspec.params:
test_executor = item.callspec.params["executor"]
assert (
test_executor in executors
), f"Expected executor to be one of 'eager', 'torchcompile', 'thunder', found {test_executor}."
), f"Expected executor to be one of 'eager', 'torchcompile', 'thunder', 'thunder-torchcompile', found {test_executor}."
return test_executor
return None

Expand Down
17 changes: 11 additions & 6 deletions benchmarks/python/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from collections.abc import Iterable
import pytest_benchmark
import torch
from torch.autograd import DeviceType
Expand Down Expand Up @@ -47,14 +48,18 @@ def unary_bwd_torch(inputs: List): # [output, grad_out]
inputs[0].backward(inputs[1], retain_graph=True)


def with_executor(executor: str, fwd_fn: Callable) -> Callable:
assert executor in ["eager", "torchcompile", "thunder"]
def with_executor(executor: str, fwd_fn: Callable, **kwargs) -> Callable:
assert executor in ["eager", "torchcompile", "thunder", "thunder-torchcompile"]
if executor == "eager":
return fwd_fn
if executor == "torchcompile":
return torch.compile(fwd_fn)
return torch.compile(fwd_fn, **kwargs)
if executor == "thunder":
return thunder.jit(fwd_fn, nv_enable_bookend=False, executors=[nvfuserex])
return thunder.jit(
fwd_fn, nv_enable_bookend=False, executors=[nvfuserex], **kwargs
)
if executor == "thunder-torchcompile":
return thunder.jit(fwd_fn, executors=["torchcompile"], **kwargs)


def compute_total_iobytes(
Expand Down Expand Up @@ -221,9 +226,9 @@ def set_metrics(
% Peak Bandwidth (SOL): 100 * Bandwidth /PEAK_BANDWIDTH
"""
if not iobytes:
if isinstance(inputs, torch.Tensor):
if not isinstance(inputs, Iterable):
inputs = [inputs]
if isinstance(outputs, torch.Tensor):
if not isinstance(outputs, Iterable):
outputs = [outputs]

iobytes = 0
Expand Down
1 change: 0 additions & 1 deletion benchmarks/python/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,6 @@ def norm_bwd_baseline_benchmark(

norm_fwd_fn = batchnorm_fwd_fn if norm == "batch_norm" else instancenorm_fwd_fn

# Compile the fwd fn for torchcompile
fwd_fn = with_executor(executor, norm_fwd_fn)
fwd_inputs = [inputs, weight, bias, running_mean, running_var]
outputs = fwd_fn(fwd_inputs)
Expand Down
Loading

0 comments on commit c8817e0

Please sign in to comment.