"""RoPE frequency generation plus autograd wrappers around two fused Triton
RoPE implementations (v1 in ``rope_triton``, v2 in ``rope_triton_v2``)."""
from typing import Optional, Tuple, Union

import torch

from rope_triton import (
    rope_forward as forward_v1,
    rope_backward as backward_v1,
)
from rope_triton_v2 import (
    rope_forward as forward_v2,
    rope_backward as backward_v2,
)


class RotaryPositionEmbeddingHalf(torch.nn.Module):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        dim: int
            Rotary embedding dimension.
        rotary_percent: float
            Percent of rotary dimension to use for rotary position embeddings.
        seq_len_interpolation_factor: int
            If not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595.
        pretrained_max_position_embeddings: int
            Pre-trained max_position_embeddings before position interpolation.
        """
        super().__init__()
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        inv_freq = 1.0 / (
            10000
            ** (
                torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                / dim
            )
        )
        self.register_buffer('inv_freq', inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings

    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies.

        Parameters
        ----------
        max_seq_len: int
            Sequence length of a sample.
        offset: int, default = 0
            Fixed offset for frequencies.
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )
        if (self.pretrained_max_position_embeddings is not None
                and self.seq_len_interpolation_factor is not None):
            if (max_seq_len >
                    self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor):
                # dynamic linear scaling (length > positions we have learned)
                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
            else:
                # fixed linear scaling
                seq *= 1 / self.seq_len_interpolation_factor
        freqs = torch.einsum('i , j -> i j', seq, self.inv_freq)
        # freqs: [seq_length, 1, 1, dim / 2]
        return freqs.reshape(freqs.size(0), 1, 1, freqs.size(1))
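

# Illustrative reference, not part of the fused kernels: a minimal, unfused
# sketch of how the frequencies produced above are typically consumed, assuming
# the "rotate half" convention suggested by the class name and the RoFormer
# paper. The helper names below are hypothetical and only document the expected
# numerics; the exact convention of rope_triton / rope_triton_v2 should be
# confirmed against the kernels themselves.
def _rotate_half_reference(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension by half: (x1, x2) -> (-x2, x1)."""
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_reference(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """
    Apply rotary embeddings to ``t`` (sbhd layout) with plain PyTorch ops.

    ``freqs`` has shape [seq, 1, 1, rot_dim / 2]; only the first ``rot_dim``
    channels of ``t`` are rotated, the rest pass through unchanged.
    """
    rot_dim = freqs.size(-1) * 2
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    emb = torch.cat((freqs, freqs), dim=-1)  # [seq, 1, 1, rot_dim]
    t_rot = t_rot * emb.cos() + _rotate_half_reference(t_rot) * emb.sin()
    return torch.cat((t_rot, t_pass), dim=-1)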


class FusedRoPEFuncV1(torch.autograd.Function):
    """Autograd wrapper around the v1 fused Triton RoPE kernels (rope_triton)."""

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
    ) -> torch.Tensor:
        if tensor_format == "sbhd":
            output = forward_v1(t, freqs, False)
        elif tensor_format == "bshd":
            # bshd inputs are transposed to sbhd before the kernel call
            # and transposed back afterwards.
            output = forward_v1(
                t.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        ctx.save_for_backward(freqs)
        ctx.tensor_format = tensor_format
        return output

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        freqs, = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = backward_v1(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = backward_v1(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")
        # No gradients for freqs or tensor_format.
        return grad_input, None, None
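

# Hedged usage sketch for the v1 wrapper. The shapes and CUDA placement below
# are illustrative assumptions, not requirements stated in this file (beyond
# the CUDA device used when building inv_freq above). ``apply`` is the standard
# torch.autograd.Function entry point.
def _example_fused_rope_v1() -> None:
    seq_len, batch, heads, head_dim = 128, 2, 8, 64
    rope = RotaryPositionEmbeddingHalf(head_dim)
    freqs = rope(seq_len)  # [seq_len, 1, 1, head_dim / 2]
    t = torch.randn(seq_len, batch, heads, head_dim, device="cuda", requires_grad=True)
    out = FusedRoPEFuncV1.apply(t, freqs, "sbhd")  # fused forward
    out.sum().backward()  # fused backward populates t.grad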


class FusedRoPEFuncV2(torch.autograd.Function):
    """Autograd wrapper around the v2 fused Triton RoPE kernels (rope_triton_v2)."""

    @staticmethod
    def forward(
        ctx,
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
    ) -> torch.Tensor:
        if tensor_format == "sbhd":
            output = forward_v2(t, freqs, False)
        elif tensor_format == "bshd":
            # bshd inputs are transposed to sbhd before the kernel call
            # and transposed back afterwards.
            output = forward_v2(
                t.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        else:
            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
        ctx.save_for_backward(freqs)
        ctx.tensor_format = tensor_format
        return output

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        freqs, = ctx.saved_tensors
        if ctx.tensor_format == "sbhd":
            grad_input = backward_v2(grad_output, freqs, False)
        elif ctx.tensor_format == "bshd":
            grad_input = backward_v2(
                grad_output.transpose(0, 1), freqs, True
            ).transpose(0, 1)
        else:
            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")
        # No gradients for freqs or tensor_format.
        return grad_input, None, None
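

# Hedged self-check sketch: the two wrappers expose the same interface over
# different Triton kernels, so comparing their outputs (and the unfused
# reference above, if its rotate-half convention matches the kernels) makes a
# natural smoke test. Shapes and tolerances are assumptions, not part of the
# original file.
if __name__ == "__main__":
    torch.manual_seed(0)
    seq_len, batch, heads, head_dim = 64, 2, 4, 32
    rope = RotaryPositionEmbeddingHalf(head_dim)
    freqs = rope(seq_len)
    t = torch.randn(seq_len, batch, heads, head_dim, device="cuda")
    out_v1 = FusedRoPEFuncV1.apply(t, freqs, "sbhd")
    out_v2 = FusedRoPEFuncV2.apply(t, freqs, "sbhd")
    torch.testing.assert_close(out_v1, out_v2)
    torch.testing.assert_close(out_v1, apply_rope_reference(t, freqs))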