Add initial ring flash attention support #1266

Draft · wants to merge 1 commit into base: main

megatron/data/data_utils.py (11 additions & 2 deletions)

@@ -311,8 +311,12 @@ def build_train_valid_test_data_iterators(neox_args):
     else:
         pipe_load = True

-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0 and pipe_load:
+    # Data loader only on rank 0 of each model/sequence parallel group.
+    if (
+        mpu.get_model_parallel_rank() == 0
+        and pipe_load
+        and mpu.get_seq_parallel_rank() == 0
+    ):
         # Number of train/valid/test samples.
         train_iters = neox_args.train_iters
         eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
@@ -441,6 +445,11 @@ def build_train_valid_test_data_iterators(neox_args):
             mpu.get_model_parallel_src_rank(),
             group=mpu.get_model_parallel_group(),
         )
+        torch.distributed.broadcast(
+            flags,
+            mpu.get_seq_parallel_src_rank(),
+            group=mpu.get_seq_parallel_group(),
+        )
     neox_args.do_train = flags[0].item()
     neox_args.do_valid = flags[1].item()
     neox_args.do_test = flags[2].item()
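
Note: the `mpu.get_seq_parallel_*` helpers used above are added elsewhere in this PR (the `megatron/mpu` files are not shown in this view). For orientation, a minimal sketch of what such helpers conventionally look like in Megatron-style code, assuming each sequence-parallel group occupies a contiguous block of global ranks; names and layout here are illustrative, not copied from the PR:

```python
import torch

_SEQ_PARALLEL_GROUP = None  # would be created once during mpu initialization


def get_seq_parallel_group():
    """Process group of the ranks that jointly hold one sequence."""
    assert _SEQ_PARALLEL_GROUP is not None, "sequence parallel group is not initialized"
    return _SEQ_PARALLEL_GROUP


def get_seq_parallel_rank():
    """Rank of the caller within its sequence-parallel group."""
    return torch.distributed.get_rank(group=get_seq_parallel_group())


def get_seq_parallel_world_size():
    """Number of ranks a single sequence is sharded across."""
    return torch.distributed.get_world_size(group=get_seq_parallel_group())


def get_seq_parallel_src_rank():
    """Global rank used as the broadcast source for this sequence-parallel group."""
    global_rank = torch.distributed.get_rank()
    local_world_size = get_seq_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size
```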

megatron/initialize.py (12 additions & 8 deletions)

@@ -158,16 +158,18 @@ def _initialize_distributed(neox_args):
     # Setup 3D topology.
     pp = neox_args.pipe_parallel_size if neox_args.pipe_parallel_size >= 1 else 1
     mp = neox_args.model_parallel_size if neox_args.model_parallel_size >= 1 else 1
+    sp = (
+        neox_args.sequence_parallel_size if neox_args.sequence_parallel_size >= 1 else 1
+    )
     assert (
-        neox_args.world_size % (pp * mp) == 0
-    ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}"
-    dp = neox_args.world_size // (pp * mp)
-
-    from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
+        neox_args.world_size % (pp * mp * sp) == 0
+    ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}, sp={sp}"
+    dp = neox_args.world_size // (pp * mp * sp)
+    from deepspeed.runtime.pipe.topology import ProcessTopology

-    # this does pipe on the most outside, then data, then model.
-    # PipeModelDataParallelTopology is just a wrapper over ProcessTopology that predefines this order.
-    topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp)
+    # With 4D parallelism, we have 4 dimensions: pipe, data, model, sequence
+    # So we need to define it manually...
+    topo = ProcessTopology(axes=["pipe", "data", "model", "seq"], dims=[pp, dp, mp, sp])

     # Offset base seeds for the interior pipeline stages.
     # TODO: adjust last stage too once IO is improved.
@@ -186,6 +188,8 @@ def _initialize_distributed(neox_args):
     else:
         mpu.initialize_model_parallel(
             neox_args.model_parallel_size,
+            neox_args.pipe_parallel_size,
+            neox_args.sequence_parallel_size,
             topology=topo,
             fp32_allreduce=neox_args.fp32_allreduce,
         )
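
For reference, the axis order passed to ProcessTopology above makes "pipe" the outermost (slowest-varying) coordinate and "seq" the innermost one. A small standalone check with pp = dp = mp = sp = 2 (16 ranks, illustrative only, requires DeepSpeed):

```python
from deepspeed.runtime.pipe.topology import ProcessTopology

# Same axis order as the diff: pipe, data, model, seq.
pp, dp, mp, sp = 2, 2, 2, 2
topo = ProcessTopology(axes=["pipe", "data", "model", "seq"], dims=[pp, dp, mp, sp])

for rank in range(topo.world_size()):
    print(rank, topo.get_coord(rank))
# 0 ProcessCoord(pipe=0, data=0, model=0, seq=0)
# 1 ProcessCoord(pipe=0, data=0, model=0, seq=1)
# 2 ProcessCoord(pipe=0, data=0, model=1, seq=0)
# ...
# 15 ProcessCoord(pipe=1, data=1, model=1, seq=1)
```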

megatron/model/gpt2_model.py (15 additions & 1 deletion)

@@ -74,7 +74,21 @@ def cross_entropy(output, labels, _fp16=False):
     else:
         losses = mpu.vocab_parallel_cross_entropy(output.float().contiguous(), labels)
     loss_mask = loss_mask.view(-1)
-    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+    loss_mask_sum = loss_mask.sum()
+    if mpu.get_seq_parallel_world_size() > 1:
+        torch.distributed.all_reduce(
+            loss_mask_sum,
+            op=torch.distributed.ReduceOp.SUM,
+            group=mpu.get_seq_parallel_group(),
+        )
+        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum
+        torch.distributed.all_reduce(
+            loss,
+            op=torch.distributed.ReduceOp.SUM,
+            group=mpu.get_seq_parallel_group(),
+        )
+    else:
+        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum
     return loss
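
The branch above computes a masked mean over a sequence that is sharded across the sequence-parallel group: every rank divides its local masked-loss sum by the group-wide mask sum, and the second all-reduce adds those partial terms back together. A single-process sketch that simulates the two all-reduces by summing over shards shows the equivalence:

```python
import torch

# Reference: masked mean over the full (unsharded) sequence.
losses = torch.tensor([2.0, 4.0, 6.0, 8.0])
mask = torch.tensor([1.0, 1.0, 0.0, 1.0])
reference = (losses * mask).sum() / mask.sum()  # (2 + 4 + 8) / 3

# Sequence-parallel view: two ranks each hold half of the tokens.
shards = [(losses[:2], mask[:2]), (losses[2:], mask[2:])]

# all_reduce(loss_mask_sum, SUM): every rank ends up with the global mask sum.
global_mask_sum = sum(m.sum() for _, m in shards)

# Each rank divides its local masked-loss sum by the *global* denominator;
# all_reduce(loss, SUM) then adds the partial results together.
partials = [(l * m).sum() / global_mask_sum for l, m in shards]
sharded = sum(partials)

assert torch.allclose(reference, sharded)
```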



megatron/model/positional_embeddings.py (22 additions & 1 deletion)

@@ -14,6 +14,7 @@

 import torch
 import math
+import megatron.mpu as mpu


 class SinusoidalPositionalEmbedding(torch.nn.Module):
@@ -37,7 +38,13 @@ def forward(self, x, seq_dim=1):

 class RotaryEmbedding(torch.nn.Module):
     def __init__(
-        self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False
+        self,
+        dim,
+        max_seq_len,
+        base=10000,
+        precision=torch.half,
+        save_inv_freqs=False,
+        zigzag=True,
     ):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
@@ -49,6 +56,7 @@ def __init__(
         self.max_seq_len = max_seq_len
         self.base = base
         self.dim = dim
+        self.zigzag = zigzag  # seq parallel zigzag

         # precompute cos_cached, sin_cached in fp32
         cos_cached, sin_cached, inv_freq = self._prepare_cache(
@@ -64,6 +72,19 @@ def _prepare_cache(self, seq_len, precision, base):
         inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim))

         t = torch.arange(seq_len).type_as(inv_freq)
+        if mpu.get_seq_parallel_world_size() > 1:
+            if not self.zigzag:
+                t_chunks = torch.chunk(t, mpu.get_seq_parallel_world_size())
+                t = t_chunks[mpu.get_seq_parallel_rank()].contiguous()
+            else:
+                t_chunks = torch.chunk(t, 2 * mpu.get_seq_parallel_world_size())
+                t = torch.cat(
+                    (
+                        t_chunks[mpu.get_seq_parallel_rank()],
+                        t_chunks[-(mpu.get_seq_parallel_rank() + 1)],
+                    ),
+                    dim=0,
+                ).contiguous()
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
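
For illustration, with a full sequence of 16 positions and a sequence-parallel world size of 2, the zigzag branch above gives rank 0 the head and tail chunks and rank 1 the middle, matching the head/tail pairing that zigzag ring attention uses to balance causal-attention work across ranks (standalone sketch, not part of the diff):

```python
import torch

seq_len, world_size = 16, 2
t = torch.arange(seq_len)
t_chunks = torch.chunk(t, 2 * world_size)  # four chunks of four positions each

for rank in range(world_size):
    local_t = torch.cat((t_chunks[rank], t_chunks[-(rank + 1)]), dim=0)
    print(rank, local_t.tolist())
# 0 [0, 1, 2, 3, 12, 13, 14, 15]
# 1 [4, 5, 6, 7, 8, 9, 10, 11]
```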


megatron/model/transformer.py (101 additions & 2 deletions)

@@ -464,6 +464,7 @@ def __init__(
         self.rope_fusion = neox_args.rope_fusion
         self.attention_type = neox_args.attention_config[layer_number]
         self.use_flash_attention = self.attention_type == "flash"
+        self.use_ring_attention = self.attention_type == "ring"
         self.use_triton = (
             self.use_flash_attention
             and self.pos_emb == "alibi"
@@ -472,7 +473,7 @@ def __init__(
                 >= packaging.version.Version("2.4.0.post1")
             )
         )
-        self.sparse = self.attention_type not in ("global", "flash")
+        self.sparse = self.attention_type not in ("global", "flash", "ring")

         if self.gqa:
             assert not self.sparse
@@ -501,6 +502,12 @@ def __init__(
             self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton
             self.flash_qkv_fn = flash_attn_func
             self.flash_varlen_qkv_fn = flash_attn_varlen_func
+        elif self.use_ring_attention:
+            from ring_flash_attn.zigzag_ring_flash_attn import (
+                zigzag_ring_flash_attn_func,
+            )
+
+            self.ring_attn_fn = zigzag_ring_flash_attn_func
         else:
             self.scale_mask_softmax = FusedScaleMaskSoftmax(
                 input_in_fp16=self.fp16,
@@ -748,6 +755,96 @@ def flash_attention(self, query_layer, key_layer, value_layer):

         return matmul_result

+    def ring_attention(self, query_layer, key_layer, value_layer):
+        # [b, np, sq, sk]
+        output_size = (
+            query_layer.size(1),
+            query_layer.size(2),
+            query_layer.size(0),
+            key_layer.size(0),
+        )
+
+        # [sk, b, np, hn] -> [b, sk, np, hn] -> [b * sk, 1, np, hn]
+        key_layer = key_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
+        )
+        value_layer = value_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
+        )
+
+        # [sq, b, np, hn] -> [b, sq, np, hn]
+        query_layer = query_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[2], output_size[1], -1
+        )
+
+        # only pass in window_size or alibi_slopes kwarg
+        # if we use Sliding Window Attention / AliBi.
+        # Flash attn defaults to (-1,-1), or
+        # does not have this kwarg prior to v2.3.0
+        extra_kwargs = (
+            {"window_size": (self.sliding_window_width, -1)}
+            if self.sliding_window_width is not None
+            else {}
+        )
+        if self.pos_emb == "alibi":
+            extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to(
+                query_layer.device
+            ).to(torch.float32)
+
+        if not self.training:
+            batch_size = output_size[0]
+            max_seqlen_q = output_size[2]
+            max_seqlen_k = output_size[3]
+
+            cu_seqlens_q = torch.arange(
+                0,
+                (batch_size + 1) * max_seqlen_q,
+                step=max_seqlen_q,
+                dtype=torch.int32,
+                device=query_layer.device,
+            )
+
+            cu_seqlens_k = torch.arange(
+                0,
+                (batch_size + 1) * max_seqlen_k,
+                step=max_seqlen_k,
+                dtype=torch.int32,
+                device=key_layer.device,
+            )
+
+            q_shape = query_layer.shape
+            k_shape = key_layer.shape
+            v_shape = value_layer.shape
+            is_causal = max_seqlen_q == max_seqlen_k
+            output = self.ring_attn_fn(
+                query_layer,
+                key_layer,
+                value_layer,
+                0.0,
+                softmax_scale=None,
+                causal=is_causal,
+                group=mpu.get_seq_parallel_group(),
+                **extra_kwargs,
+            )
+            output = output.reshape(q_shape)
+        else:
+            output = self.ring_attn_fn(
+                query_layer,
+                key_layer,
+                value_layer,
+                self.dropout_p if self.training else 0.0,
+                softmax_scale=None,
+                causal=True,
+                group=mpu.get_seq_parallel_group(),
+                **extra_kwargs,
+            )
+
+        matmul_result = output
+        # [b, sq, np, hn] -> [b, np, sq, hn]
+        matmul_result = matmul_result.transpose(1, 2)
+
+        return matmul_result

     def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
         # TODO: sparse attn dropout?
         # TODO: pad to block size
@@ -843,7 +940,7 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None):
         value_layer = value_layer.view(*new_kv_shape)

         # if not using Flash attention, we repeat K/V heads to match Q head counts
-        if not self.use_flash_attention:
+        if not (self.use_flash_attention or self.use_ring_attention):
             key_layer = torch.repeat_interleave(
                 key_layer,
                 repeats=int(
@@ -957,6 +1054,8 @@ def forward(self, hidden_states, attention_mask, layer_past=None):

         if self.use_flash_attention:
             context_layer = self.flash_attention(query_layer, key_layer, value_layer)
+        elif self.use_ring_attention:
+            context_layer = self.ring_attention(query_layer, key_layer, value_layer)
         elif not self.sparse:
             context_layer = self.attention(
                 query_layer, key_layer, value_layer, layer_past, attention_mask
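
The `ring_attention` method above only rearranges layouts around the ring kernel: inputs arrive as `[sq, b, np, hn]` (with `sq` being this rank's local sequence shard), the kernel consumes and returns `[b, sq, np, hn]`, and the result is transposed to `[b, np, sq, hn]`. A standalone sketch with the kernel stubbed out, shapes only; the real call is `zigzag_ring_flash_attn_func(q, k, v, dropout_p, softmax_scale=..., causal=..., group=...)` as in the diff, and GQA models would additionally use fewer KV heads:

```python
import torch

sq, b, np_, hn = 8, 2, 4, 16  # local sequence shard, batch, query heads, head dim

query = torch.randn(sq, b, np_, hn)  # layout produced upstream: [sq, b, np, hn]
key = torch.randn(sq, b, np_, hn)
value = torch.randn(sq, b, np_, hn)

# [sq, b, np, hn] -> [b, sq, np, hn], the layout the flash/ring kernels expect
q = query.transpose(0, 1).reshape(b, sq, np_, -1)
k = key.transpose(0, 1).reshape(b, sq, np_, -1)
v = value.transpose(0, 1).reshape(b, sq, np_, -1)


def stub_ring_attn(q, k, v):
    # Stand-in for the ring kernel: output keeps q's [b, sq, np, hn] layout.
    return torch.zeros_like(q)


context = stub_ring_attn(q, k, v).transpose(1, 2)  # -> [b, np, sq, hn]
print(context.shape)  # torch.Size([2, 4, 8, 16])
```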

megatron/model/utils.py (10 additions)

@@ -22,6 +22,16 @@
 from megatron.model.fused_softmax import SoftmaxFusionTypes
 from types import GeneratorType
 import torch.distributed as dist
+from megatron.mpu import (
+    get_seq_parallel_group,
+    get_seq_parallel_src_rank,
+    get_seq_parallel_rank,
+    get_seq_parallel_world_size,
+)
+from megatron.mpu.mappings import (
+    _GatherFromSeqParallelRegion,
+    _ScatterToSeqParallelRegion,
+)
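
The `_ScatterToSeqParallelRegion` / `_GatherFromSeqParallelRegion` mappings imported above are defined in the PR's `megatron/mpu/mappings.py` changes, which this view does not show. In Megatron-style code such mappings are autograd functions that shard a tensor along the sequence dimension in forward and all-gather the gradient in backward (or the reverse). A simplified sketch of the plain, non-zigzag scatter direction, with the sequence dimension and all other details chosen here for illustration only:

```python
import torch
from megatron.mpu import (
    get_seq_parallel_group,
    get_seq_parallel_rank,
    get_seq_parallel_world_size,
)


class _ScatterToSeqParallelRegionSketch(torch.autograd.Function):
    """Keep only this rank's sequence shard; all-gather gradients in backward."""

    @staticmethod
    def forward(ctx, input_, seq_dim=0):
        ctx.seq_dim = seq_dim
        world_size = get_seq_parallel_world_size()
        if world_size == 1:
            return input_
        chunks = torch.chunk(input_, world_size, dim=seq_dim)
        return chunks[get_seq_parallel_rank()].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = get_seq_parallel_world_size()
        if world_size == 1:
            return grad_output, None
        gathered = [torch.empty_like(grad_output) for _ in range(world_size)]
        torch.distributed.all_gather(
            gathered, grad_output.contiguous(), group=get_seq_parallel_group()
        )
        return torch.cat(gathered, dim=ctx.seq_dim), None
```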


def get_params_for_weight_decay_optimization(module, neox_args):