
rope_benchmark #3550

Merged: 48 commits into main on Jan 14, 2025

Conversation

@jjsjann123 (Collaborator) commented Dec 10, 2024:

RoPE benchmarks extracted from a lightning-thunder trace.

TODO:

  • add iobytes measurement for benchmarks.
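
Roughly, the added benchmarks take the following parametrized shape (a sketch with placeholder helpers, not the exact code in this PR; `rope_setup` is a stand-in name):

    import pytest

    @pytest.mark.parametrize(
        "executor", ["eager", "torchcompile", "thunder", "thunder-torchcompile"]
    )
    @pytest.mark.parametrize(
        "rope_variation",
        [
            "llama_2_7b_hf_rope",
            "llama_3_8B_rope",
            "hf_qwen2_rope",
            "hf_phi3_rope",
            "hf_mistral_nemo_rope",
        ],
    )
    def test_rope_variations_fwd_benchmark(benchmark, executor, rope_variation):
        fwd_fn, inputs = rope_setup(rope_variation)  # hypothetical setup helper
        run_fn = with_executor(executor, fwd_fn)     # wrap with the chosen executor
        benchmark(run_fn, *inputs)                   # pytest-benchmark times the call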

Inline comment on the diff:

    }


    @pytest.mark.parametrize(

@jjsjann123 (Collaborator Author):

This is the only part that's worth reviewing.

The code above was dumped directly from Kevin's rope example script. (Note that I had to update the script to pass nv_enable_matmul to thunder.jit; otherwise we see segmentation at the nvFuser definition level.)
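
For reference, the flag ends up being passed to thunder.jit roughly like this (a sketch mirroring the diff reviewed below; `fwd_fn` and `nvfuserex` come from the surrounding benchmark code):

    import thunder

    # Enable matmul support in the nvFuser executor so the whole RoPE region
    # stays in a single fusion instead of being segmented.
    jitted_fn = thunder.jit(
        fwd_fn, nv_enable_bookend=False, nv_enable_matmul=True, executors=[nvfuserex]
    )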

@jjsjann123 (Collaborator Author):

I also want to add another toy example where we'll sweep on the batch size. But I'll do that in a separate PR.

@naoyam (Collaborator) commented Dec 10, 2024:

@Priya2698 is adding the Thunder backend #3394. Does it mean we can just have the forward functions?

@Priya2698 (Collaborator):

> @Priya2698 is adding the Thunder backend #3394. Does it mean we can just have the forward functions?

We will also benchmark backward pass with Thunder backend.

@naoyam (Collaborator) commented Dec 10, 2024:

> > @Priya2698 is adding the Thunder backend #3394. Does it mean we can just have the forward functions?

> We will also benchmark backward pass with Thunder backend.

Yes, so we don't need to have the backward implementations explicitly, right?

@jjsjann123 jjsjann123 marked this pull request as draft December 10, 2024 21:28
@jjsjann123 (Collaborator Author):

Looking at the thunder-nvfuser timing.

Strangely, the benchmark numbers don't match the numbers from Kevin's example.
This is the measurement from pytest:

Name (time in us)                                                                                       Min                   Max                  Mean            StdDev                Median               IQR            Outliers         OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rope_variations_fwd_benchmark[executor='thunder'-rope_variation='hf_qwen2_rope']              204.8290 (1.0)        212.5130 (1.0)        207.1972 (1.0)      2.5573 (2.49)       206.0485 (1.0)      4.0260 (4.17)          2;0  4,826.3200 (1.0)          10           1
test_rope_variations_fwd_benchmark[executor='thunder'-rope_variation='hf_mistral_nemo_rope']       320.3510 (1.56)       324.3850 (1.53)       322.8819 (1.56)     1.3519 (1.32)       322.8555 (1.57)     1.8470 (1.91)          3;0  3,097.1076 (0.64)         10           1
test_rope_variations_bwd_benchmark[executor='thunder'-rope_variation='hf_qwen2_rope']              356.9320 (1.74)       360.3840 (1.70)       357.8536 (1.73)     1.0271 (1.0)        357.7265 (1.74)     0.9920 (1.03)          1;1  2,794.4388 (0.58)         10           1
test_rope_variations_bwd_benchmark[executor='thunder'-rope_variation='hf_mistral_nemo_rope']       428.8940 (2.09)       432.8350 (2.04)       430.9671 (2.08)     1.1889 (1.16)       431.0560 (2.09)     1.8540 (1.92)          3;0  2,320.3627 (0.48)         10           1
test_rope_variations_fwd_benchmark[executor='thunder'-rope_variation='hf_phi3_rope']               548.0630 (2.68)       554.1090 (2.61)       552.0020 (2.66)     1.6203 (1.58)       552.3545 (2.68)     0.9650 (1.0)           2;2  1,811.5876 (0.38)         10           1
test_rope_variations_fwd_benchmark[executor='thunder'-rope_variation='llama_2_7b_hf_rope']         621.6160 (3.03)       626.1340 (2.95)       623.5093 (3.01)     1.6043 (1.56)       623.0065 (3.02)     2.3690 (2.45)          4;0  1,603.8253 (0.33)         10           1
test_rope_variations_bwd_benchmark[executor='thunder'-rope_variation='hf_phi3_rope']             1,022.0870 (4.99)     1,028.2720 (4.84)     1,024.4110 (4.94)     2.0313 (1.98)     1,024.3360 (4.97)     3.5130 (3.64)          2;0    976.1707 (0.20)         10           1
test_rope_variations_bwd_benchmark[executor='thunder'-rope_variation='llama_2_7b_hf_rope']       1,308.1660 (6.39)     1,313.6600 (6.18)     1,310.4751 (6.32)     2.0083 (1.96)     1,310.5750 (6.36)     3.5940 (3.72)          5;0    763.0820 (0.16)         10           1
test_rope_variations_fwd_benchmark[executor='thunder'-rope_variation='llama_3_8B_rope']          1,373.1600 (6.70)     1,382.4350 (6.51)     1,377.5739 (6.65)     2.3928 (2.33)     1,377.8270 (6.69)     2.2130 (2.29)          2;1    725.9139 (0.15)         10           1
test_rope_variations_bwd_benchmark[executor='thunder'-rope_variation='llama_3_8B_rope']          1,925.9490 (9.40)     1,936.4170 (9.11)     1,931.5364 (9.32)     2.8123 (2.74)     1,931.2535 (9.37)     2.3720 (2.46)          3;1    517.7226 (0.11)         10           1

But if I run the manual rope examples, I'm getting these:

root@a9fb56dcd91f:/volume/rope/rope_examples# python hf_phi3.py --execs Thunder-nvFuser
                             Model  Batch-Size  ...  Forward-Time(ms) Backward-Time(ms)
0  microsoft/Phi-3.5-mini-instruct           1  ...             0.597             0.739
root@a9fb56dcd91f:/volume/rope/rope_examples# python hf_qwen2.py --execs Thunder-nvFuser
                      Model  Batch-Size  ...  Forward-Time(ms) Backward-Time(ms)
0  Qwen/Qwen2.5-7B-Instruct           1  ...             0.397             0.507
root@a9fb56dcd91f:/volume/rope/rope_examples# python hf_mistral_nemo.py --execs Thunder-nvFuser
                              Model  Batch-Size  ...  Forward-Time(ms) Backward-Time(ms)
0  mistralai/Mistral-Nemo-Base-2407           1  ...             0.593             0.322
root@a9fb56dcd91f:/volume/rope/rope_examples# python lit_gpt_models.py --execs Thunder-nvFuser
           Model  Batch-Size  Sequence-Length         Executor  Forward-Time(ms)  Backward-Time(ms)
0  Llama-2-7b-hf           2             4096  Thunder-nvFuser             0.629              0.960
        Model  Batch-Size  Sequence-Length         Executor  Forward-Time(ms)  Backward-Time(ms)
0  Llama-3-8B           2             8192  Thunder-nvFuser             1.383              1.567

I'll double-check the measurement script, as well as the compile options (i.e. thunder trace options).

We need to do the same sanity check for torchcompile later.

@jjsjann123 (Collaborator Author):

> > Is thunder-torch.compile what we should be using in our benchmark as well? I'm asking since we do not have that executor in the pytest benchmark yet.

> That's what I heard from @kevinstephano.

Noted. I'll add another executor.

@jjsjann123 (Collaborator Author):

I realized that Kevin's benchmark script had been updated to measure profiler time as well, and I was two commits behind that. The previous discrepancy came from the different measurement.

@jjsjann123 (Collaborator Author):

With the updated manual benchmark, we are now making an apples-to-apples comparison.

On H100:

Manual benchmark:

Model                             Executor               Forward-Kernels  Forward-Time(ms)  Backward-Kernels  Backward-Time(ms)
Llama-2-7b-hf                     Thunder-nvFuser        4                0.346             5                 0.504
Llama-2-7b-hf                     Thunder-torch.compile  1                0.093             2                 0.277
Llama-3-8B                        Thunder-nvFuser        5                0.733             5                 0.814
Llama-3-8B                        Thunder-torch.compile  2                0.162             3                 0.616
microsoft/Phi-3.5-mini-instruct   Thunder-nvFuser        7                0.298             6                 0.383
microsoft/Phi-3.5-mini-instruct   Thunder-torch.compile  6                0.084             2                 0.228
mistralai/Mistral-Nemo-Base-2407  Thunder-nvFuser        8                0.163             6                 0.136
mistralai/Mistral-Nemo-Base-2407  Thunder-torch.compile  9                0.074             4                 0.103
Qwen/Qwen2.5-7B-Instruct          Thunder-nvFuser        5                0.109             8                 0.318
Qwen/Qwen2.5-7B-Instruct          Thunder-torch.compile  4                0.049             5                 0.157

Pytest benchmark:

Name (time in us)                                                                                                Median
----------------------------------------------------------------------------------------------------------------------------------
test_rope_variations_fwd_benchmark[executor='thunder'-rope_variation='llama_2_7b_hf_rope']                     347.5185 (6.90)
test_rope_variations_fwd_benchmark[executor='thunder-torchcompile'-rope_variation='llama_2_7b_hf_rope']         93.2800 (1.85)
test_rope_variations_bwd_benchmark[executor='thunder'-rope_variation='llama_2_7b_hf_rope']                     707.8535 (14.06)
test_rope_variations_bwd_benchmark[executor='thunder-torchcompile'-rope_variation='llama_2_7b_hf_rope']        471.2945 (9.36)
 
test_rope_variations_fwd_benchmark[executor='thunder'-rope_variation='llama_3_8B_rope']                        735.9030 (14.62)
test_rope_variations_fwd_benchmark[executor='thunder-torchcompile'-rope_variation='llama_3_8B_rope']           161.2320 (3.20)
test_rope_variations_bwd_benchmark[executor='thunder'-rope_variation='llama_3_8B_rope']                      1,017.9635 (20.22)
test_rope_variations_bwd_benchmark[executor='thunder-torchcompile'-rope_variation='llama_3_8B_rope']           809.3740 (16.08)
 
test_rope_variations_fwd_benchmark[executor='thunder'-rope_variation='hf_phi3_rope']                           363.3590 (7.22)    
test_rope_variations_fwd_benchmark[executor='thunder-torchcompile'-rope_variation='hf_phi3_rope']               83.9865 (1.67)    
test_rope_variations_bwd_benchmark[executor='thunder'-rope_variation='hf_phi3_rope']                           471.2105 (9.36)    
test_rope_variations_bwd_benchmark[executor='thunder-torchcompile'-rope_variation='hf_phi3_rope']              374.9610 (7.45)    
 
test_rope_variations_fwd_benchmark[executor='thunder'-rope_variation='hf_mistral_nemo_rope']                   175.7105 (3.49)    
test_rope_variations_fwd_benchmark[executor='thunder-torchcompile'-rope_variation='hf_mistral_nemo_rope']       74.0490 (1.47)
test_rope_variations_bwd_benchmark[executor='thunder'-rope_variation='hf_mistral_nemo_rope']                   178.3185 (3.54)
test_rope_variations_bwd_benchmark[executor='thunder-torchcompile'-rope_variation='hf_mistral_nemo_rope']      162.8475 (3.24)
 
test_rope_variations_fwd_benchmark[executor='thunder'-rope_variation='hf_qwen2_rope']                          120.9745 (2.40)
test_rope_variations_fwd_benchmark[executor='thunder-torchcompile'-rope_variation='hf_qwen2_rope']              50.3365 (1.0)
test_rope_variations_bwd_benchmark[executor='thunder'-rope_variation='hf_qwen2_rope']                          381.1360 (7.57)
test_rope_variations_bwd_benchmark[executor='thunder-torchcompile'-rope_variation='hf_qwen2_rope']             211.7920 (4.21)
----------------------------------------------------------------------------------------------------------------------------------

Forward time mostly matches (except for hf_phi3; that's because we enable matmul in the pytest benchmark, which is not enabled in the manual benchmark).
Backward time looks about right; the pytest benchmark has higher times coming from the grad accumulation kernel.
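
(For context, the bwd benchmarks time only the backward pass. The general pattern is roughly the sketch below, assuming `fwd_fn` returns a tuple of tensors and the inputs require grad; this is not the exact harness code.)

    import torch

    def make_backward_fn(fwd_fn, inputs):
        # Run forward once outside the timed region; keep the graph so the
        # backward call below can be invoked repeatedly by the benchmark.
        outputs = fwd_fn(*inputs)
        grads = [torch.ones_like(o) for o in outputs]

        def backward_only():
            torch.autograd.grad(
                outputs, inputs, grad_outputs=grads, retain_graph=True
            )

        return backward_only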

Inline comment on the diff:

    ):
        kwargs = {}
        if executor == "thunder":
            kwargs["nv_enable_matmul"] = True

@jjsjann123 (Collaborator Author):

BTW, this is where I'm enabling matmul in nvFuser.

This gives us a single fusion region, which I believe is something we would like. cc'ing @naoyam

Inline comment on the diff:

        return thunder.jit(
            fwd_fn, nv_enable_bookend=False, executors=[nvfuserex], **kwargs
        )
    if executor == "thunder-torchcompile":

@jjsjann123 (Collaborator Author):

thunder-torchcompile is the config we wanted for the rope comparison. Not sure if this is something we would also like to enable for other benchmarks. cc'ing @Priya2698

@Priya2698 (Collaborator) commented Jan 13, 2025:

We should not enable it for other benchmarks just yet. That would further increase the weekly CI timings, so let's keep it to RoPE for now. We can revisit which executors to run in nightly/weekly.

Inline comment on the diff:

    -def with_executor(executor: str, fwd_fn: Callable) -> Callable:
    -    assert executor in ["eager", "torchcompile", "thunder"]
    +def with_executor(executor: str, fwd_fn: Callable, **kwargs) -> Callable:
    +    assert executor in ["eager", "torchcompile", "thunder", "thunder-torchcompile"]

A reviewer (Collaborator):

Suggested change:

    -    assert executor in ["eager", "torchcompile", "thunder", "thunder-torchcompile"]
    +    assert executor in ["eager", "torchcompile", "thunder-nvfuser", "thunder-torchcompile"]

@jjsjann123 (Collaborator Author):

I take it you are suggesting a renaming.

I can do that in a separate PR to make reviewing easier.

Inline comment on the diff (@@ -221,9 +226,9 @@ def set_metrics):

    % Peak Bandwidth (SOL): 100 * Bandwidth /PEAK_BANDWIDTH
    """
    if not iobytes:
    -   if isinstance(inputs, torch.Tensor):
    +   if not isinstance(inputs, Iterable):

A reviewer (Collaborator):

Why is this changed?

@jjsjann123 (Collaborator Author):

The code below is:

            for inp in inputs:
                if isinstance(inp, torch.Tensor):
                    iobytes += inp.element_size() * inp.numel()

So what we are really checking here is whether inputs is an Iterable. The same applies to outputs.

For the backward checks we have outputs=None, which would cause an exception when we don't provide iobytes. That no longer applies here, since we are providing iobytes for RoPE, but I think it's still a legitimate fix.
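
For reference, the relevant accounting looks roughly like this (helper and argument names are illustrative; the real logic lives in the benchmark utilities):

    from collections.abc import Iterable

    import torch

    def compute_iobytes(inputs, outputs):
        iobytes = 0
        # Wrap non-iterable values (e.g. None) in a list so the loops below can
        # treat inputs/outputs uniformly.
        if not isinstance(inputs, Iterable):
            inputs = [inputs]
        if not isinstance(outputs, Iterable):
            outputs = [outputs]
        for inp in inputs:
            if isinstance(inp, torch.Tensor):
                iobytes += inp.element_size() * inp.numel()
        for out in outputs:
            if isinstance(out, torch.Tensor):
                iobytes += out.element_size() * out.numel()
        return iobytes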

Inline comment on the diff:

    @pytest.mark.parametrize(
        "executor", ["eager", "torchcompile", "thunder", "thunder-torchcompile"]
    )
    def test_rope_variations_fwd_benchmark(

A reviewer (Collaborator):

IIUC, we also have test_rope_benchmark, which is separate from test_rope_variations_fwd_benchmark and test_rope_variations_bwd_benchmark. What does test_rope_benchmark evaluate?

@jjsjann123 (Collaborator Author):

I think those came from @wujingyue's experiments with RoPE performance earlier last year. We just have some manual definitions of the RoPE module in the llama examples.

I'm not sure if those benchmarks are still relevant? Tagging @wujingyue.

A reviewer (Collaborator):

Apparently, the new benchmarks cover more variations and are more realistic (captured from Thunder traces), so I don't see a reason to keep test_rope_benchmark, which covers only one variation and is forward-only. The without_cat variation could be useful, but since no model implementations have adopted this trick and nvFuser is getting better at cat, I don't see a reason to keep that either.
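
(For context, the cat in question comes from the standard HF-style rotate_half used by these RoPE variants; the sketch below illustrates the pattern and is not the benchmark's code.)

    import torch

    def rotate_half(x):
        # The torch.cat here is the op the old without_cat variation avoided.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope(q, cos, sin):
        return (q * cos) + (rotate_half(q) * sin)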

@wujingyue (Collaborator) left a review:

Some naming nits

(Two review threads on benchmarks/python/test_rope.py were marked outdated and resolved.)
Inline comment on the diff:

    for i in range(1, len(outputs)):
        output += outputs[i]

    # NOTE: the iobytes is computed based on how thunder autograd worked. So this is just
A reviewer (Collaborator):

I agree with your previous point about the IObytes computation potentially being different between executors. I discussed this with @kevinstephano, but offhand I am not sure of a more robust way. I will open an issue to look into it. For this PR, the nvFuser-definition-based computation LGTM since that is consistent with the other benchmarks.

@Priya2698 (Collaborator):

> Another question, maybe to @Priya2698: Can we enable the result verification by default? I remember there's still a tolerance issue, but for these RoPE benchmarks since there's almost no reduction (there's some in the backward cases), maybe verification would work fine?

Any benchmark can add validation to the benchmark. For nvfuser (manual fusion definitions), we validate against torch reference by default (it can be turned off using --disable-validation). Since these benchmarks are generated using the torch definitions, we don't have a baseline result to validate against here.

@Priya2698 (Collaborator) left a review:

Minor comments. The benchmarks look good to me overall.
For the variations: are there any good references we can add as comments?
Please also post the final numbers and the numbers from Kevin's script in the PR description for posterity.

"resid_pdrop": 0.0,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"long_factor": [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do these numbers come from? Can you add a reference if there is one?

@jjsjann123 (Collaborator Author):

Those are just the config values dumped from the model that we pulled:

    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3.5-mini-instruct")
    print(model.config)

So we don't have to pull the real model. Does this sound right to you, @kevinstephano?
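
(A lighter-weight option, if we only ever need the config values and not the weights, might be AutoConfig; just a thought, not what the dump script used:)

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("microsoft/Phi-3.5-mini-instruct")
    print(config.rope_scaling["long_factor"])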

@naoyam (Collaborator) commented Jan 13, 2025:

> > Another question, maybe to @Priya2698: Can we enable the result verification by default? I remember there's still a tolerance issue, but for these RoPE benchmarks since there's almost no reduction (there's some in the backward cases), maybe verification would work fine?

> Any benchmark can add validation to the benchmark. For nvfuser (manual fusion definitions), we validate against torch reference by default (it can be turned off using --disable-validation). Since these benchmarks are generated using the torch definitions, we don't have a baseline result to validate against here.

We run the benchmarks in eager mode as well. Why can't we just use those results as the reference?

@jjsjann123 (Collaborator Author):

!test

@wujingyue wujingyue removed their request for review January 14, 2025 18:04
@naoyam (Collaborator) left a review:

LGTM. Thanks @jjsjann123.

@Priya2698 I still would like to have result validations. Resize indexing in these fusions tends to be complex, and I suspect compute-sanitizer would not be very helpful.

@xwang233, @jjsjann123 It would be really good to have these new benchmark cases showing up in the performance dashboard.

@jjsjann123 merged commit c8817e0 into main on Jan 14, 2025 (43 of 44 checks passed).
@jjsjann123 deleted the jjsjann123/rope_benchmark branch on January 14, 2025 at 22:58.
@Priya2698 (Collaborator):

> @Priya2698 I still would like to have result validations. Resize indexing in these fusions tends to be complex, and I suspect compute-sanitizer would not be very helpful.

We can add validation here, same as we do in nvfuser benchmarks:

    if not disable_validation:
        eager_output = rope(inputs)
        fd.validate(inputs, [eager_output])

Instead of fd.validate, we directly use a torch assert for the other executors. Are you referring to any other addition to the infrastructure? The disable_validation parameter is there to disable these validations when needed; otherwise, validation runs by default.

We do need to compute the eager output for the other executors and cannot reuse it from the eager executor, since all tests run independently and sharing variables between tests in pytest is not straightforward.
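
For illustration, the per-executor validation could look roughly like this (fixture and argument names are stand-ins for the actual benchmark plumbing, and torch.testing.assert_close stands in for whatever tolerance check we settle on):

    import torch

    def maybe_validate(jitted_fn, eager_fn, inputs, disable_validation):
        if not disable_validation:
            eager_out = eager_fn(*inputs)   # reference computed in eager mode
            jit_out = jitted_fn(*inputs)    # result from the benchmarked executor
            torch.testing.assert_close(jit_out, eager_out, rtol=1e-3, atol=1e-3)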
