test_te_rope.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
from typing import Callable, Dict, Tuple, Union
from transformer_engine.pytorch.attention import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)


def get_tol(dtype: torch.dtype) -> Dict:
    if dtype == torch.bfloat16:
        return dict(atol=1e-2, rtol=1e-2)
    elif dtype == torch.float16:
        return dict(atol=1e-3, rtol=1e-3)
    return dict(atol=1e-5, rtol=1.3e-6)


# Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    return output.sum() * 2


# Gradient is a full tensor
def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    t = torch.ones_like(output)
    return torch.sum(output * t)


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2048, 4096])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
def test_fused_rope(
    dtype: torch.dtype,
    seq_length: int,
    hidden_size: int,
    rotary_percent: float,
    margin: int,
    transpose: Union[Tuple, None],
    tensor_format: str,
    loss_func: Callable,
) -> None:
    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
    # The input is `margin` positions shorter than the rotary table built below.
    t = torch.rand(
        (seq_length - margin, batch_size, head_num, hidden_size),
        dtype=dtype,
        device=device,
    )
    if tensor_format == "bshd":
        t = t.transpose(0, 1).contiguous()
    if transpose:
        # Round-trip transpose keeps the logical shape but leaves the memory
        # layout non-contiguous, so non-contiguous inputs are covered as well.
        t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
    emb = rotary_pos_emb(seq_length)

    # unfused (reference) path
    output_unfused = apply_rotary_pos_emb(
        t, emb, tensor_format=tensor_format, fused=False
    )
    loss_unfused = loss_func(output_unfused)
    loss_unfused.backward()
    grad_unfused = t.grad.detach().clone()
    t.grad = None

    # fused kernel path
    output_fused = apply_rotary_pos_emb(
        t,
        emb,
        tensor_format=tensor_format,
        fused=True,
    )
    loss_fused = loss_func(output_fused)
    loss_fused.backward()
    grad_fused = t.grad.detach().clone()
    t.grad = None

    torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype))
    torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype))
    assert output_fused.is_contiguous()
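
For orientation, here is a minimal standalone sketch of the API exercised by this test. It assumes a CUDA device and an installed Transformer Engine; the shapes, dtype, and variable names are illustrative and mirror the calls already made in the test rather than documenting any additional API surface.

import torch
from transformer_engine.pytorch.attention import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

# Build rotary frequencies for a maximum sequence length, then apply them to a
# query/key-like tensor in "sbhd" (sequence, batch, head, head-dim) layout.
seq_len, batch, heads, dim = 128, 2, 8, 64
rope = RotaryPositionEmbedding(dim)   # default rotary_percent rotates the full head dim
freqs = rope(seq_len)                 # frequency tensor covering seq_len positions
q = torch.rand(seq_len, batch, heads, dim, device="cuda", dtype=torch.bfloat16)
q_rope = apply_rotary_pos_emb(q, freqs, tensor_format="sbhd", fused=True)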