Skip to content

Commit

Permalink
Wrap DistributedTensor as a torch.Tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Jan 19, 2025
1 parent ae7a84e commit c58c2b4
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 15 deletions.
2 changes: 1 addition & 1 deletion csrc/python_frontend/multidevice_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void bindDeviceMesh(py::module& nvfuser) {

void bindDistributedTensor(py::module& nvfuser) {
py::class_<DistributedTensor> distributed_tensor(
nvfuser, "DistributedTensor");
nvfuser, "_DistributedTensor");
distributed_tensor.def(
"local", &DistributedTensor::local, "Returns the local torch.Tensor.");
distributed_tensor.def(
Expand Down
49 changes: 45 additions & 4 deletions nvfuser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import warnings

import torch
from torch.utils._pytree import tree_map

# This is needed when libnvfuser.so is patched and doesn't have the pytorch library location available.
pytorch_lib_dir = os.path.join(os.path.dirname(torch.__file__), "lib")
Expand Down Expand Up @@ -50,6 +51,43 @@ def disable_automatic_serialization():
atexit.unregister(_C.serialize)


class DistributedTensor(torch.Tensor):
_dtensor: _C._DistributedTensor

@staticmethod
def __new__(cls, dtensor: _C._DistributedTensor):
t = dtensor.local()
return torch.Tensor._make_wrapper_subclass(
cls,
t.shape,
strides=t.stride(),
storage_offset=t.storage_offset(),
device=t.device,
layout=t.layout,
requires_grad=t.requires_grad,
dtype=t.dtype,
)

def __init__(self, dtensor: _C._DistributedTensor):
self._dtensor = dtensor

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
def unwrap(t):
if isinstance(t, DistributedTensor):
return t._dtensor.local()
return t

return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

@property
def mesh(self) -> DeviceMesh:
return self._dtensor.mesh()

def axis_sharded_on(self, parallel_type: ParallelType) -> int:
return self._dtensor.axis_sharded_on(parallel_type)


class FusionDefinition(_C._FusionDefinition):
def __init__(self, id=None, max_length=1024):
super(FusionDefinition, self).__init__(id, max_length)
Expand Down Expand Up @@ -198,7 +236,7 @@ def execute(
save_repro_inputs=False,
_enable_options: list[str] = [],
_disable_options: list[str] = [],
) -> list[torch.Tensor | _C.DistributedTensor]:
) -> list[torch.Tensor]:
"""
Executes an nvFuser set of kernels for a given Fusion
Expand Down Expand Up @@ -314,7 +352,7 @@ def execute(
"Reset the FusionCache manually to avoid reusing kernels when re-executing the fusion definition with different options."
)

out_tensors = self._execute(
out_dtensors: Iterable[_C.DistributedTensor] = self._execute(
inputs,
device=device,
override_user_schedule=override_user_schedule,
Expand All @@ -323,9 +361,12 @@ def execute(
_enable_options=_enable_options,
_disable_options=_disable_options,
)
for i, out_dtensor in enumerate(out_tensors):
out_tensors = []
for out_dtensor in out_dtensors:
if out_dtensor.mesh().size() == 0:
out_tensors[i] = out_dtensor.local()
out_tensors.append(out_dtensor.local())
else:
out_tensors.append(DistributedTensor(out_dtensor))
return out_tensors
except Exception as err:
logger.exception(self._repro_error_str("executing", inputs))
Expand Down
17 changes: 7 additions & 10 deletions tests/python/test_multidevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def multidevice_schedule(self):
sharded_input = multidevice_test.shard_tensor(unsharded_input, 0, mesh)

fd = Model()
outputs: list[DistributedTensor] = fd.execute([sharded_input])
torch.testing.assert_close(outputs[0].local().cpu(), unsharded_input.relu() * 2)
outputs = fd.execute([sharded_input])
torch.testing.assert_close(outputs[0].cpu(), unsharded_input.relu() * 2)
assert outputs[0].axis_sharded_on(nvfuser.ParallelType.mesh_x) == -1


Expand Down Expand Up @@ -109,7 +109,7 @@ def multidevice_schedule(self):
# rtol is the same as the default for fp32. atol is slightly increased.
assert out_tensors[0].axis_sharded_on(nvfuser.ParallelType.mesh_x) == 0
torch.testing.assert_close(
out_tensors[0].local(), expected_out_tensor, rtol=1.3e-6, atol=1e-3
out_tensors[0], expected_out_tensor, rtol=1.3e-6, atol=1e-3
)


Expand Down Expand Up @@ -171,7 +171,7 @@ def multidevice_schedule(self):
expected_out_tensor = multidevice_test.shard_tensor(unsharded_out_tensor, -1, mesh)
# rtol is the same as the default for fp32. atol is slightly increased.
torch.testing.assert_close(
out_tensors[0].local(), expected_out_tensor, rtol=1.3e-6, atol=1e-3
out_tensors[0], expected_out_tensor, rtol=1.3e-6, atol=1e-3
)


Expand Down Expand Up @@ -222,9 +222,7 @@ def multidevice_schedule(self) -> None:
(in_grad,) = fd.execute([out_grad.cuda(), weight.cuda()])
# Use the default rtol for half because the output, although being float32,
# is a straight cast from half.
torch.testing.assert_close(
in_grad.local().cpu(), expected_in_grad, rtol=1e-3, atol=1e-2
)
torch.testing.assert_close(in_grad.cpu(), expected_in_grad, rtol=1e-3, atol=1e-2)


class QkvFormat(Enum):
Expand Down Expand Up @@ -339,7 +337,6 @@ def head_parallelize(t: torch.Tensor) -> torch.Tensor:
out, q_grad, k_grad, v_grad = outs

def assert_close(actual, expected):
actual = actual.local()
match qkv_format:
case QkvFormat.BHSE:
assert actual.is_contiguous()
Expand Down Expand Up @@ -751,8 +748,8 @@ def multidevice_schedule(self):
def _assert_shape_dtype(
t: DistributedTensor, expected_sizes: list[int], expected_dtype: torch.dtype
) -> None:
assert t.local().shape == torch.Size(expected_sizes)
assert t.local().dtype == expected_dtype
assert t.shape == torch.Size(expected_sizes)
assert t.dtype == expected_dtype


@pytest.mark.skipif(
Expand Down

0 comments on commit c58c2b4

Please sign in to comment.