DID loop split for SDPA (#3711)
In this PR, I explicitly parallelize the outputs `attn` and `log_sumexp` of
`sdpfa_fwd`, because sharding propagation for loop split does not yet handle
this case correctly.
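
For reference, the scheduling pattern amounts to sharding the head dimension of every SDPA input and output by hand. Below is a minimal sketch of that idea, using only the scheduling calls that appear in the test added in this diff; it assumes `mesh`, `d`, and the fusion's tensors are defined as in `test_sdpa_loop_split`:

```python
# Inside Model.multidevice_schedule(): the logical shape is [b, h, s, e // h].
# Shard the head axis across the mesh for the inputs and, crucially, for the
# sdpfa_fwd outputs attn and log_sumexp, which propagation would otherwise miss.
for t in [self.q, self.k, self.v, self.attn, self.log_sumexp]:
    self.sched._set_device_mesh(t, mesh)  # place t on the device mesh
    self.sched.split(t, 1, d, False)  # h -> [d, h // d] (outer split)
    self.sched.parallelize(t, 1, nvfuser.ParallelType.mesh_x)  # bind outer axis to DIDx
    self.sched.set_allocation_as_loop(t)  # allocate the per-device shard
```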

---------

Co-authored-by: Jingyue Wu <[email protected]>
Priya2698 and wujingyue authored Jan 21, 2025
1 parent f6a3b4d commit 7e3001b
Showing 1 changed file with 139 additions and 0 deletions.
139 changes: 139 additions & 0 deletions tests/python/test_multidevice.py
@@ -407,6 +407,145 @@ def assert_close(actual, expected):
    assert_close(v_grad, head_parallelize(expected_v_grad))


@pytest.mark.skipif(
    utils.is_pre_ampere(),
    reason="Flash Attention is only supported on Ampere and newer devices.",
)
@pytest.mark.parametrize("qkv_format", [QkvFormat.BHSE, QkvFormat.BSHE])
@pytest.mark.mpi
def test_sdpa_loop_split(multidevice_test, qkv_format: QkvFormat):
    d, b, s, h, e = multidevice_test.size, 2, 1024, 12, 768

    if h % d != 0:
        pytest.skip(f"We only support even split, so {h} has to be divisible by {d}.")
    mesh = nvfuser.DeviceMesh(range(d))

    class Model(FusionDefinition):
        def __init__(self, qkv_format: QkvFormat):
            super().__init__()
            self._qkv_format = qkv_format

        def definition(self) -> None:
            match self._qkv_format:
                case QkvFormat.BHSE:
                    stride_order = [3, 2, 1, 0]
                case QkvFormat.BSHE:
                    stride_order = [3, 1, 2, 0]

            self.q, self.k, self.v, self.out_grad = [
                self.define_tensor(
                    shape=[b, h, s, e // h],
                    dtype=DataType.BFloat16,
                    stride_order=stride_order,
                )
                for _ in range(4)
            ]

            # TODO(#3123): support sharded dropout and change this to a
            # positive probability.
            dropout_p = self.define_scalar(0.0, dtype=DataType.Double)
            is_causal = self.define_scalar(True, dtype=DataType.Bool)
            self.attn, self.log_sumexp, seed, offset = self.ops.sdpfa_fwd(
                self.q, self.k, self.v, dropout_p, is_causal, scale=None
            )

            self.q_grad, self.k_grad, self.v_grad = self.ops.sdpfa_bwd(
                self.out_grad,
                self.q,
                self.k,
                self.v,
                self.attn,
                self.log_sumexp,
                dropout_p,
                is_causal,
                seed,
                offset,
                scale=None,
            )

            self.add_output(self.attn)
            for grad in [self.q_grad, self.k_grad, self.v_grad]:
                self.add_output(grad)

        def multidevice_schedule(self) -> None:
            for t in [
                self.q,
                self.k,
                self.v,
                self.attn,
                self.log_sumexp,
                self.out_grad,
                self.q_grad,
                self.k_grad,
                self.v_grad,
            ]:
                self.sched._set_device_mesh(t, mesh)
                self.sched.split(t, 1, d, False)
                self.sched.parallelize(t, 1, nvfuser.ParallelType.mesh_x)
                if self._qkv_format == QkvFormat.BSHE:
                    # The loop domain is: {i{B}, i{DIDx}, i{H//D}, i{S}, i{E//H}}.
                    # Reorder i{S} ahead of i{H//D} so the allocation domain matches
                    # BSHE: {i{B}, i{DIDx}, i{S}, i{H//D}, i{E//H}}.
                    self.sched.reorder(t, {2: 3, 3: 2})
                self.sched.set_allocation_as_loop(t)

    torch.cuda.set_device(multidevice_test.local_rank)

    def make_unsharded_tensor() -> torch.Tensor:
        return torch.randn(b, h, s, e // h, dtype=torch.bfloat16, device="cpu")

    q, k, v = [make_unsharded_tensor().requires_grad_() for _ in range(3)]
    out_grad = make_unsharded_tensor()
    sharded_q, sharded_k, sharded_v, sharded_out_grad = [
        multidevice_test.shard_tensor(t, 1, mesh) for t in [q, k, v, out_grad]
    ]

    with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        expected_out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=0.0, is_causal=True, scale=None
        )
        expected_out.backward(out_grad)
        expected_q_grad, expected_k_grad, expected_v_grad = q.grad, k.grad, v.grad

    fd = Model(qkv_format)

    def reformat_tensor(t: torch.Tensor) -> torch.Tensor:
        match qkv_format:
            case QkvFormat.BHSE:
                return t
            case QkvFormat.BSHE:
                return t.transpose(1, 2).contiguous().transpose(1, 2)

    attn, q_grad, k_grad, v_grad = fd.execute(
        [
            reformat_tensor(sharded_q).requires_grad_(),
            reformat_tensor(sharded_k).requires_grad_(),
            reformat_tensor(sharded_v).requires_grad_(),
            reformat_tensor(sharded_out_grad),
        ]
    )

    def assert_close(actual, expected):
        match qkv_format:
            case QkvFormat.BHSE:
                assert actual.is_contiguous()
            case QkvFormat.BSHE:
                assert actual.transpose(1, 2).is_contiguous()

        # Use the default rtol for bfloat16 and a relaxed atol.
        torch.testing.assert_close(
            actual,
            multidevice_test.shard_tensor(expected, 1, mesh),
            rtol=1.6e-2,
            atol=1e-2,
        )

    for actual, expected in zip(
        [attn, q_grad, k_grad, v_grad],
        [expected_out, expected_q_grad, expected_k_grad, expected_v_grad],
    ):
        assert_close(actual, expected)


def get_benchmark_fn(func, /, profile: bool):
    def wrapper(*args, **kwargs):
        if profile:
