diff --git a/tests/python/test_multidevice.py b/tests/python/test_multidevice.py index ec8c9c5a29c..14cf7104cf9 100644 --- a/tests/python/test_multidevice.py +++ b/tests/python/test_multidevice.py @@ -407,6 +407,145 @@ def assert_close(actual, expected): assert_close(v_grad, head_parallelize(expected_v_grad)) +@pytest.mark.skipif( + utils.is_pre_ampere(), + reason="Flash Attention is only supported on Ampere and newer devices.", +) +@pytest.mark.parametrize("qkv_format", [QkvFormat.BHSE, QkvFormat.BSHE]) +@pytest.mark.mpi +def test_sdpa_loop_split(multidevice_test, qkv_format: QkvFormat): + d, b, s, h, e = multidevice_test.size, 2, 1024, 12, 768 + + if h % d != 0: + pytest.skip(f"We only support even split, so {h} has to be divisible by {d}.") + mesh = nvfuser.DeviceMesh(range(d)) + + class Model(FusionDefinition): + def __init__(self, qkv_format: QkvFormat): + super().__init__() + self._qkv_format = qkv_format + + def definition(self) -> None: + match self._qkv_format: + case QkvFormat.BHSE: + stride_order = [3, 2, 1, 0] + case QkvFormat.BSHE: + stride_order = [3, 1, 2, 0] + + self.q, self.k, self.v, self.out_grad = [ + self.define_tensor( + shape=[b, h, s, e // h], + dtype=DataType.BFloat16, + stride_order=stride_order, + ) + for _ in range(4) + ] + + # TODO(#3123): support sharded dropout and change this to a + # positive probability. 
+ dropout_p = self.define_scalar(0.0, dtype=DataType.Double) + is_causal = self.define_scalar(True, dtype=DataType.Bool) + self.attn, self.log_sumexp, seed, offset = self.ops.sdpfa_fwd( + self.q, self.k, self.v, dropout_p, is_causal, scale=None + ) + + self.q_grad, self.k_grad, self.v_grad = self.ops.sdpfa_bwd( + self.out_grad, + self.q, + self.k, + self.v, + self.attn, + self.log_sumexp, + dropout_p, + is_causal, + seed, + offset, + scale=None, + ) + + self.add_output(self.attn) + for grad in [self.q_grad, self.k_grad, self.v_grad]: + self.add_output(grad) + + def multidevice_schedule(self) -> None: + for t in [ + self.q, + self.k, + self.v, + self.attn, + self.log_sumexp, + self.out_grad, + self.q_grad, + self.k_grad, + self.v_grad, + ]: + self.sched._set_device_mesh(t, mesh) + self.sched.split(t, 1, d, False) + self.sched.parallelize(t, 1, nvfuser.ParallelType.mesh_x) + if self._qkv_format == QkvFormat.BSHE: + # The loop domain is: {i{B}, i{DIDx}, i{H//D}, i{S}, i{E//H}} + # Reorder i{S} in the allocation domain for BSHE: {i{DIDx}, i{B}, i{S}, i{H//D}, i{E//H}} + self.sched.reorder(t, {2: 3, 3: 2}) + self.sched.set_allocation_as_loop(t) + + torch.cuda.set_device(multidevice_test.local_rank) + + def make_unsharded_tensor() -> torch.Tensor: + return torch.randn(b, h, s, e // h, dtype=torch.bfloat16, device="cpu") + + q, k, v = [make_unsharded_tensor().requires_grad_() for _ in range(3)] + out_grad = make_unsharded_tensor() + sharded_q, sharded_k, sharded_v, sharded_out_grad = [ + multidevice_test.shard_tensor(t, 1, mesh) for t in [q, k, v, out_grad] + ] + + with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): + expected_out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, dropout_p=0.0, is_causal=True, scale=None + ) + expected_out.backward(out_grad) + expected_q_grad, expected_k_grad, expected_v_grad = q.grad, k.grad, v.grad + + fd = Model(qkv_format) + + def reformat_tensor(t: torch.Tensor) -> torch.Tensor: + match qkv_format: + case 
QkvFormat.BHSE: + return t + case QkvFormat.BSHE: + return t.transpose(1, 2).contiguous().transpose(1, 2) + + attn, q_grad, k_grad, v_grad = fd.execute( + [ + reformat_tensor(sharded_q).requires_grad_(), + reformat_tensor(sharded_k).requires_grad_(), + reformat_tensor(sharded_v).requires_grad_(), + reformat_tensor(sharded_out_grad), + ] + ) + + def assert_close(actual, expected): + match qkv_format: + case QkvFormat.BHSE: + assert actual.is_contiguous() + case QkvFormat.BSHE: + assert actual.transpose(1, 2).is_contiguous() + + # Use the default rtol for bfloat16 and a relaxed atol. + torch.testing.assert_close( + actual, + multidevice_test.shard_tensor(expected, 1, mesh), + rtol=1.6e-2, + atol=1e-2, + ) + + for actual, expected in zip( + [attn, q_grad, k_grad, v_grad], + [expected_out, expected_q_grad, expected_k_grad, expected_v_grad], + ): + assert_close(actual, expected) + + def get_benchmark_fn(func, /, profile: bool): def wrapper(*args, **kwargs): if profile: