Modify the compile parameter in baseline benchmarks to executor (#3350)

This PR is the first step toward adding `thunder.jit` benchmarks.
The main change is renaming the `compile` parameter to `executor`, which takes
the values `eager`, `torchcompile`, and `thunder`. This PR does not introduce
any new thunder benchmarks (those will be added in the next PR).

CC: @xwang233 for dashboard changes.

Issue: #2718
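
For illustration, a minimal sketch of the parametrization pattern this PR moves the baseline benchmarks to (the op, sizes, and test name below are placeholders, not code from this repository):

```python
# Illustrative sketch only: shows the new `executor` parametrization pattern.
# The op, sizes, and test name are placeholders, not code from this PR.
import pytest
import torch


@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_example_baseline_benchmark(benchmark, dtype: torch.dtype, executor: str):
    x = torch.randn(1024, 1024, device="cuda", dtype=dtype)

    def fwd_fn(inp):
        return torch.nn.functional.gelu(inp, approximate="tanh")

    # Select the callable for the requested executor: the plain function for
    # "eager", the torch.compile-wrapped function for "torchcompile".
    benchmark_fn = {"eager": fwd_fn, "torchcompile": torch.compile(fwd_fn)}

    benchmark(benchmark_fn[executor], x)
```

With the conftest changes below, each `executor` value is skipped unless its matching `--benchmark-{executor}` flag is passed, e.g. `pytest benchmarks/python --benchmark-eager --benchmark-torchcompile`.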
Priya2698 authored Nov 8, 2024
1 parent 6abf310 commit 96d64b6
Showing 34 changed files with 287 additions and 210 deletions.
64 changes: 29 additions & 35 deletions benchmarks/python/conftest.py
@@ -96,45 +96,39 @@ def pytest_configure(config):

def pytest_collection_modifyitems(session, config, items):
"""
The baseline benchmarks use `compile` parameter:
compile = false: Eager mode benchmark
compile = true: torch.compile benchmark
The baseline benchmarks use `executor` parameter with
values ["eager", "torchcompile", "thunder"] that are optionally
run using `--benchmark-{executor}` flag. They are skipped by
default.
"""
run_eager = config.getoption("--benchmark-eager")
run_thunder = config.getoption("--benchmark-thunder")
run_torchcompile = config.getoption("--benchmark-torchcompile")

from nvfuser.pytorch_utils import retry_on_oom_or_skip_test

executors = ["eager", "torchcompile", "thunder"]

def get_test_executor(item) -> str | None:
if hasattr(item, "callspec") and "executor" in item.callspec.params:
test_executor = item.callspec.params["executor"]
assert (
test_executor in executors
), f"Expected executor to be one of 'eager', 'torchcompile', 'thunder', found {test_executor}."
return test_executor
return None

executors_to_skip = []

for executor in executors:
if not config.getoption(f"--benchmark-{executor}"):
executors_to_skip.append(executor)

for item in items:
item.obj = retry_on_oom_or_skip_test(item.obj)

if not run_eager:
skip_eager = pytest.mark.skip(reason="need --benchmark-eager option to run")
for item in items:
# If the benchmark has compile=False parameter (eager mode), skip it.
if (
hasattr(item, "callspec")
and "compile" in item.callspec.params
and not item.callspec.params["compile"]
):
item.add_marker(skip_eager)

if not run_torchcompile:
skip_torchcompile = pytest.mark.skip(
reason="need --benchmark-torchcompile option to run"
)
for item in items:
# If the benchmark has compile=True parameter (torch.compile mode), skip it.
if (
hasattr(item, "callspec")
and "compile" in item.callspec.params
and item.callspec.params["compile"]
):
item.add_marker(skip_torchcompile)

if not run_thunder:
skip_thunder = pytest.mark.skip(reason="need --benchmark-thunder option to run")
for item in items:
if "thunder" in item.nodeid:
item.add_marker(skip_thunder)
test_executor = get_test_executor(item)

if test_executor is not None and test_executor in executors_to_skip:
item.add_marker(
pytest.mark.skip(
reason=f"need --benchmark-{test_executor} option to run."
)
)
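
The hook above only reads the three flags; their registration is not part of this diff. A minimal sketch of what the corresponding `pytest_addoption` hook could look like, assuming default-off boolean options (the help text is invented):

```python
# Sketch only: the real option registration lives elsewhere in conftest.py and
# is not shown in this diff; the help strings are invented for illustration.
def pytest_addoption(parser):
    for executor in ("eager", "torchcompile", "thunder"):
        parser.addoption(
            f"--benchmark-{executor}",
            action="store_true",
            default=False,
            help=f"Run baseline benchmarks with the '{executor}' executor.",
        )
```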
18 changes: 10 additions & 8 deletions benchmarks/python/normalization.py
@@ -433,10 +433,10 @@ def norm_fwd_baseline_benchmark(
size: tuple,
dtype: torch.dtype,
channels_last: bool,
compile: bool,
executor: str,
norm: str,
):
if compile:
if executor == "torchcompile":
clear_dynamo_cache()

assert norm in ["batch_norm", "instance_norm"], NotImplementedError
@@ -453,10 +453,12 @@

norm_fwd_fn = batchnorm_fwd_fn if norm == "batch_norm" else instancenorm_fwd_fn

benchmark_fn = {"eager": norm_fwd_fn, "torchcompile": torch.compile(norm_fwd_fn)}

# Manually compute IOBytes: See PR #1725
run_benchmark(
benchmark,
torch.compile(norm_fwd_fn) if compile else norm_fwd_fn,
benchmark_fn[executor],
[inputs, weight, bias, running_mean, running_var],
iobytes=norm_fwd_iobytes(size, dtype, norm),
)
@@ -467,10 +469,10 @@ def norm_bwd_baseline_benchmark(
size: tuple,
dtype: torch.dtype,
channels_last: bool,
compile: bool,
executor: str,
norm: str,
):
if compile:
if executor == "torchcompile":
clear_dynamo_cache()

assert norm in ["batch_norm", "instance_norm"], NotImplementedError
@@ -491,13 +493,13 @@
norm_fwd_fn = batchnorm_fwd_fn if norm == "batch_norm" else instancenorm_fwd_fn

# Compile the fwd fn for torchcompile
norm_fwd_fn = torch.compile(norm_fwd_fn) if compile else norm_fwd_fn
output = norm_fwd_fn([inputs, weight, bias, running_mean, running_var])
fwd_fn = {"eager": norm_fwd_fn, "torchcompile": torch.compile(norm_fwd_fn)}
outputs = fwd_fn[executor]([inputs, weight, bias, running_mean, running_var])

# Manually compute IOBytes: See PR #1725
run_benchmark(
benchmark,
unary_bwd_torch,
[output, grads],
[outputs, grads],
iobytes=norm_bwd_iobytes(size, dtype, norm),
)
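
The backward baselines above share a pattern: run the (optionally `torch.compile`-wrapped) forward once outside the timed region, then benchmark only the backward pass on the saved outputs. A self-contained sketch of that pattern, using a generic op and plain autograd calls instead of the repo's `unary_bwd_torch`/`run_benchmark` helpers:

```python
# Sketch of the fwd-once / benchmark-backward pattern used by the bwd baselines.
# The op and sizes are placeholders; executor selection mirrors the dict above.
import torch


def make_bwd_fn(executor: str, size=(1024, 1024), dtype=torch.float32):
    x = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True)
    grads = torch.randn(size, device="cuda", dtype=dtype)

    def fwd():
        return torch.nn.functional.silu(x)

    fwd_fn = {"eager": fwd, "torchcompile": torch.compile(fwd)}
    outputs = fwd_fn[executor]()  # forward runs once, outside the benchmark loop

    def bwd():
        # retain_graph=True lets backward be called repeatedly on the same
        # graph; gradients accumulate into x.grad across iterations.
        outputs.backward(grads, retain_graph=True)

    return bwd  # hand this callable to the benchmark fixture
```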
6 changes: 3 additions & 3 deletions benchmarks/python/test_batchnorm_bwd.py
@@ -31,13 +31,13 @@ def test_batchnorm_bwd_nvf_benchmark(
)


@pytest.mark.parametrize("compile", [False, True], ids=["eager", "compile"])
@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("size", generate_input_sizes(dims=4))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("channels_last", [True, False])
def test_batchnorm_bwd_baseline_benchmark(
benchmark, size: tuple, dtype: torch.dtype, channels_last: bool, compile: bool
benchmark, size: tuple, dtype: torch.dtype, channels_last: bool, executor: str
):
norm_bwd_baseline_benchmark(
benchmark, size, dtype, channels_last, compile, "batch_norm"
benchmark, size, dtype, channels_last, executor, "batch_norm"
)
6 changes: 3 additions & 3 deletions benchmarks/python/test_batchnorm_fwd.py
@@ -31,13 +31,13 @@ def test_batchnorm_fwd_nvf_benchmark(
)


@pytest.mark.parametrize("compile", [False, True], ids=["eager", "compile"])
@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("size", generate_input_sizes(dims=4))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("channels_last", [True, False])
def test_batchnorm_fwd_baseline_benchmark(
benchmark, size: tuple, dtype: torch.dtype, channels_last: bool, compile: bool
benchmark, size: tuple, dtype: torch.dtype, channels_last: bool, executor: str
):
norm_fwd_baseline_benchmark(
benchmark, size, dtype, channels_last, compile, "batch_norm"
benchmark, size, dtype, channels_last, executor, "batch_norm"
)
13 changes: 9 additions & 4 deletions benchmarks/python/test_broadcast_add_fwd.py
@@ -88,7 +88,7 @@ def test_bcast_add_nvf_benchmark(
run_benchmark(benchmark, fd.execute, [bias, x])


@pytest.mark.parametrize("compile", [False, True], ids=["eager", "compile"])
@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("bcast_axis", [0, 1], ids=["outer", "inner"])
@@ -101,9 +101,9 @@ def test_bcast_add_baseline_benchmark(
dtype: torch.dtype,
bcast_axis: int,
contiguous: bool,
compile: bool,
executor: str,
):
if compile:
if executor == "torchcompile":
clear_dynamo_cache()
bias = torch.randn(size[1 - bcast_axis], dtype=dtype, device="cuda")
input_shape = size if contiguous else (size[1], size[0])
@@ -112,9 +112,14 @@ def test_bcast_add_baseline_benchmark(
x = x.t()
assert x.is_contiguous() == contiguous

benchmark_fn = {
"eager": bcast_add_fwd_fn,
"torchcompile": torch.compile(bcast_add_fwd_fn),
}

# Inputs and outputs are same as nvFuser, no need for manual IOByte computation
run_benchmark(
benchmark,
torch.compile(bcast_add_fwd_fn) if compile else bcast_add_fwd_fn,
benchmark_fn[executor],
[bias, x, bcast_axis],
)
15 changes: 9 additions & 6 deletions benchmarks/python/test_dropout_layernorm_bwd.py
@@ -189,16 +189,16 @@ def test_dropout_layernorm_bwd_nvf_benchmark(
)


@pytest.mark.parametrize("compile", [False, True], ids=["eager", "compile"])
@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_dropout_layernorm_bwd_baseline_benchmark(
benchmark,
size: tuple,
dtype: torch.dtype,
compile: bool,
executor: str,
):
if compile:
if executor == "torchcompile":
clear_dynamo_cache()

dropout_p = 0.2
@@ -217,13 +217,16 @@ def dropout_layernorm_fwd():
)

# Compile the fwd fn for torchcompile
fwd_fn = torch.compile(dropout_layernorm_fwd) if compile else dropout_layernorm_fwd
output = fwd_fn()
fwd_fn = {
"eager": dropout_layernorm_fwd,
"torchcompile": torch.compile(dropout_layernorm_fwd),
}
outputs = fwd_fn[executor]()

# Manually compute IOBytes: See PR #1725
run_benchmark(
benchmark,
unary_bwd_torch,
[output, grads],
[outputs, grads],
iobytes=dropout_layernorm_bwd_iobytes(size, dtype),
)
13 changes: 9 additions & 4 deletions benchmarks/python/test_dropout_layernorm_fwd.py
@@ -160,16 +160,16 @@ def test_dropout_layernorm_fwd_nvf_benchmark(
run_benchmark(benchmark, fd.execute, inputs)


@pytest.mark.parametrize("compile", [False, True], ids=["eager", "compile"])
@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_dropout_layernorm_fwd_baseline_benchmark(
benchmark,
size: tuple,
dtype: torch.dtype,
compile: bool,
executor: str,
):
if compile:
if executor == "torchcompile":
clear_dynamo_cache()

dropout_p = 0.2
@@ -181,10 +181,15 @@ def test_dropout_layernorm_fwd_baseline_benchmark(
dropout_p,
]

benchmark_fn = {
"eager": dropout_layernorm_fwd,
"torchcompile": torch.compile(dropout_layernorm_fwd),
}

# Manually compute IOBytes: See PR #1725
run_benchmark(
benchmark,
torch.compile(dropout_layernorm_fwd) if compile else dropout_layernorm_fwd,
benchmark_fn[executor],
inputs,
iobytes=dropout_layernorm_fwd_iobytes(size, dtype),
)
15 changes: 9 additions & 6 deletions benchmarks/python/test_dropout_rmsnorm_bwd.py
@@ -169,16 +169,16 @@ def test_dropout_rmsnorm_bwd_nvf_benchmark(
)


@pytest.mark.parametrize("compile", [False, True], ids=["eager", "compile"])
@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_dropout_rmsnorm_bwd_baseline_benchmark(
benchmark,
size: tuple,
dtype: torch.dtype,
compile: bool,
executor: str,
):
if compile:
if executor == "torchcompile":
clear_dynamo_cache()
dropout_p = 0.2
input1 = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True)
@@ -191,12 +191,15 @@ def dropout_rmsnorm_fwd():
output = weights * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)
return output

fwd_fn = torch.compile(dropout_rmsnorm_fwd) if compile else dropout_rmsnorm_fwd
output = fwd_fn()
fwd_fn = {
"eager": dropout_rmsnorm_fwd,
"torchcompile": torch.compile(dropout_rmsnorm_fwd),
}
outputs = fwd_fn[executor]()

run_benchmark(
benchmark,
unary_bwd_torch,
[output, grads],
[outputs, grads],
iobytes=dropout_rmsnorm_bwd_iobytes(size, dtype),
)
13 changes: 9 additions & 4 deletions benchmarks/python/test_dropout_rmsnorm_fwd.py
@@ -145,16 +145,16 @@ def test_dropout_rmsnorm_fwd_nvf_benchmark(
run_benchmark(benchmark, fd.execute, [input1, input2, weights])


@pytest.mark.parametrize("compile", [False, True], ids=["eager", "compile"])
@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_dropout_rmsnorm_fwd_baseline_benchmark(
benchmark,
size: tuple,
dtype: torch.dtype,
compile: bool,
executor: str,
):
if compile:
if executor == "torchcompile":
clear_dynamo_cache()
dropout_p = 0.2

@@ -165,10 +165,15 @@ def test_dropout_rmsnorm_fwd_baseline_benchmark(
dropout_p,
]

benchmark_fn = {
"eager": dropout_rmsnorm_fwd,
"torchcompile": torch.compile(dropout_rmsnorm_fwd),
}

# Manually compute IOBytes: See PR #1725
run_benchmark(
benchmark,
torch.compile(dropout_rmsnorm_fwd) if compile else dropout_rmsnorm_fwd,
benchmark_fn[executor],
inputs,
iobytes=dropout_rmsnorm_fwd_iobytes(size, dtype),
)
15 changes: 9 additions & 6 deletions benchmarks/python/test_gelu_bwd.py
@@ -88,16 +88,16 @@ def test_gelu_bwd_nvf_benchmark(
run_benchmark(benchmark, fd.execute, [inputs, grads, bias])


@pytest.mark.parametrize("compile", [False, True], ids=["eager", "compile"])
@pytest.mark.parametrize("executor", ["eager", "torchcompile"])
@pytest.mark.parametrize("size", generate_input_sizes(dims=2))
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_gelu_bwd_baseline_benchmark(
benchmark,
size: tuple,
dtype: torch.dtype,
compile: bool,
executor: str,
):
if compile:
if executor == "torchcompile":
clear_dynamo_cache()
inputs = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True)
bias = torch.ones(size[-1], device="cuda", dtype=dtype)
@@ -106,12 +106,15 @@ def test_gelu_bwd_baseline_benchmark(
def gelu_fwd():
return torch.nn.functional.gelu(inputs + bias, approximate="tanh")

fwd_fn = torch.compile(gelu_fwd) if compile else gelu_fwd
eager_output = fwd_fn()
fwd_fn = {
"eager": gelu_fwd,
"torchcompile": torch.compile(gelu_fwd),
}
outputs = fwd_fn[executor]()

run_benchmark(
benchmark,
unary_bwd_torch,
[eager_output, grads],
[outputs, grads],
iobytes=gelu_bwd_iobytes(size, dtype),
)