Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type issue introduced by #28 #39

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

HaileyStorm
Copy link
Contributor

@HaileyStorm HaileyStorm commented Oct 7, 2024

Commit #28 changed apply_rotary_embed to have dtype parameter with default float32, and forces attention softmax to be done float32. Since attention doesn't specify the dtype parameter when calling apply_rotary_embed, and the output matmul doesn't convert back from float32 to match the values type, this is an issue if you're running BF16.

This specifies the existing xq.dtype for the dtype parameter when calling apply_rotary_embed (alternatively, we could cast keys to float32 in scores = torch.matmul(xq, keys)), and casts scores to match values at the output matmul.

Commit xjdr-alt#28 changed `apply_rotary_embed` to have dtype parameter with default float32, and forces attention softmax to be done float32. Since `attention` doesn't specify the dtype parameter when calling `apply_rotary_embed`, and output matmul doesn't convert back from float32 to match the values type, this is an issue if you're running BF16.

This specifies the existing xq.dtype for the dtype parameter when calling `apply_rotary_embed` (alternatively, we could cast keys to float32 in `scores = torch.matmul(xq, keys)`), and converts the scores to match values at the output matmul.
@xjdr-alt
Copy link
Owner

xjdr-alt commented Oct 8, 2024

@Arrabonae could you take a look

Copy link
Contributor

@citizenhicks citizenhicks left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tested this, works well. thanks for spotting this issue!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants