Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add layernorm backward pass #697

Merged
merged 5 commits into from
Jan 15, 2025
Merged

add layernorm backward pass #697

merged 5 commits into from
Jan 15, 2025

Conversation

scxiao
Copy link

@scxiao scxiao commented Jan 9, 2025

This PR is to add the backward pass to the layernorm op for the perf kernel of the amd rocm fork. The following changes are made:

  • Added a backward pass of LayerNorm
  • Enhanced the backward kernel to handle more input shapes
  • Implemented Layernorm class following the standard interface
  • Added unit tests to verify the backward kernel correctness.
  • Added input shapes retrieved from model_config.json file

As for performance, the forward pass becomes slower, due to the following reasons:

  • Added arguments "mean" and "rstd" and time needed to store results to these tensors
  • The forward pass need time to create these two tensors
  • The forward pass need time to backup a few tensors in the function "save_for_backward"
  • If removing these two arguments, kernel ISA is the same as before

The core Triton is a small number of people, and we receive many PRs (thank
you!). To help us review your code more quickly, if you are a new
contributor (less than 3 PRs merged) we ask that you complete the following
tasks and include the filled-out checklist in your PR description.

Complete the following tasks before sending your PR, and replace [ ] with
[x] to indicate you have done them.

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

Copy link
Collaborator

@vgokhale vgokhale left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Shucai!

@vgokhale vgokhale merged commit 1006241 into main_perf Jan 15, 2025
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants