Introduce DistributedTensor
wujingyue committed Jan 19, 2025
1 parent 3b67d72 commit ae7a84e
Showing 10 changed files with 166 additions and 25 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -273,6 +273,7 @@ endif()

if(BUILD_PYTHON)
list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/python_frontend/distributed_tensor.cpp
${NVFUSER_SRCS_DIR}/python_frontend/fusion_cache.cpp
${NVFUSER_SRCS_DIR}/python_frontend/fusion_definition.cpp
${NVFUSER_SRCS_DIR}/python_frontend/fusion_state.cpp
25 changes: 25 additions & 0 deletions csrc/python_frontend/distributed_tensor.cpp
@@ -0,0 +1,25 @@
#include <exceptions.h>
#include <python_frontend/distributed_tensor.h>
#include <utils.h>

namespace nvfuser::python_frontend {

void DistributedTensor::setAxisIsShardedOn(
const int64_t axis,
const ParallelType parallel_type) {
const auto i = axis_sharded_on_.find(parallel_type);
NVF_CHECK(
i == axis_sharded_on_.end(),
"Parallel type ",
parallel_type,
" was already used to shard axis ",
i->second);
axis_sharded_on_[parallel_type] = axis;
}

int64_t DistributedTensor::axisShardedOn(
const ParallelType parallel_type) const {
return getOrDefault(axis_sharded_on_, parallel_type, -1L);
}

} // namespace nvfuser::python_frontend
46 changes: 46 additions & 0 deletions csrc/python_frontend/distributed_tensor.h
@@ -0,0 +1,46 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

#include <ATen/core/TensorBody.h>

#include <multidevice/device_mesh.h>
#include <type.h>

namespace nvfuser::python_frontend {

class DistributedTensor {
public:
explicit DistributedTensor(
at::Tensor local_tensor,
const DeviceMesh& mesh = DeviceMesh())
: local_(local_tensor), mesh_(mesh) {}
DistributedTensor(const DistributedTensor&) = delete;
DistributedTensor& operator=(const DistributedTensor&) = delete;
DistributedTensor(DistributedTensor&&) = default;
DistributedTensor& operator=(DistributedTensor&&) = default;

const DeviceMesh& mesh() const {
return mesh_;
}

at::Tensor local() const {
return local_;
}

void setAxisIsShardedOn(int64_t axis, ParallelType parallel_type);

int64_t axisShardedOn(ParallelType parallel_type) const;

private:
at::Tensor local_;
DeviceMesh mesh_;
std::unordered_map<ParallelType, int64_t> axis_sharded_on_;
};

} // namespace nvfuser::python_frontend
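
For reference, a minimal usage sketch of the class above, assuming it is compiled against this commit's headers; the sketch function and its NVF_ERROR assertions are illustrative only, not part of the commit:

#include <exceptions.h>
#include <python_frontend/distributed_tensor.h>

namespace nvfuser::python_frontend {

// Illustrative only: `local` is this rank's shard, `mesh` the device mesh.
void distributed_tensor_sketch(at::Tensor local, const DeviceMesh& mesh) {
  DistributedTensor dtensor(local, mesh);

  // Unregistered parallel types report -1, i.e. replicated.
  NVF_ERROR(dtensor.axisShardedOn(ParallelType::DIDx) == -1);

  // Record that logical axis 0 is sharded on DIDx, then query it back.
  dtensor.setAxisIsShardedOn(/*axis=*/0, ParallelType::DIDx);
  NVF_ERROR(dtensor.axisShardedOn(ParallelType::DIDx) == 0);

  // Registering DIDx a second time would trip the NVF_CHECK in
  // distributed_tensor.cpp: each parallel type shards at most one axis.
}

} // namespace nvfuser::python_frontend

Because copying is deleted and moving is defaulted, instances are meant to be built once and returned by value (e.g. from FusionDefinition::execute) rather than shared.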
53 changes: 47 additions & 6 deletions csrc/python_frontend/fusion_definition.cpp
@@ -8,12 +8,15 @@
#include <debug.h>
#include <fusion_profiler.h>
#include <instrumentation.h>
#include <multidevice/utils.h>
#include <options.h>
#include <preseg_passes/pre_segmenter.h>
#include <python_frontend/distributed_tensor.h>
#include <python_frontend/fusion_cache.h>
#include <python_frontend/fusion_definition.h>
#include <python_frontend/translation.h>
#include <runtime/executor_kernel_arg.h>
#include <runtime/fusion_kernel_runtime.h>
#include <scheduler/compile_time_info.h>
#include <scheduler/scheduler_types.h>
#include <utils.h>
@@ -344,7 +347,7 @@ void FusionDefinition::print(std::ostream& os) const {
os << std::endl;
}

std::vector<at::Tensor> FusionDefinition::execute(
std::vector<DistributedTensor> FusionDefinition::execute(
const at::ArrayRef<c10::IValue>& inputs,
std::optional<int8_t> selected_device,
bool override_user_schedule,
@@ -399,9 +402,9 @@ std::vector<at::Tensor> FusionDefinition::execute(
};
const auto* user_sched = find_user_schedule();

std::vector<at::Tensor> outputs;
std::vector<at::Tensor> out_tensors;
if (user_sched == nullptr) {
outputs = scheds->auto_gen_schedules->runFusionWithInputs(
out_tensors = scheds->auto_gen_schedules->runFusionWithInputs(
inputs, std::nullopt, selected_device);
} else {
if (isProfilerEnabledWithCupti()) {
@@ -417,7 +420,7 @@ std::vector<at::Tensor> FusionDefinition::execute(
user_sched->executor->compile(
user_sched->scheduled_fusion.get(), inputs);
}
outputs = user_sched->executor->run(inputs);
out_tensors = user_sched->executor->run(inputs);
} else {
// Automatic scheduler was used for UserSchedule.
// Pass launch and compile params to compileFusion and runFusion.
@@ -430,7 +433,7 @@ std::vector<at::Tensor> FusionDefinition::execute(
user_sched->heuristic_params->cparams,
user_sched->heuristic_params->scheduler_type);
}
outputs = user_sched->executor->run(
out_tensors = user_sched->executor->run(
inputs,
user_sched->heuristic_params->lparams,
user_sched->heuristic_params->cparams);
@@ -453,7 +456,45 @@ std::vector<at::Tensor> FusionDefinition::execute(
debug_output_ = debug_ss.str();
}

return outputs;
std::vector<DistributedTensor> out_dtensors;
out_dtensors.reserve(out_tensors.size());
if (user_sched == nullptr) {
FusionKernelRuntime* runtime =
scheds->auto_gen_schedules->getMostRecentKernelRuntime();
NVF_ERROR(runtime != nullptr);
Fusion* fusion = runtime->fusionSegments()->completeFusion();
NVF_ERROR(fusion != nullptr);

int64_t i = 0;
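// out_tensors contains only the non-hidden outputs, so `i` advances only when one is consumed below.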
for (Val* out_val : fusion->outputs()) {
auto* out_tv = out_val->as<TensorView>();
if (fusion->getOutputAlias(out_tv).hide_output) {
continue;
}

const at::Tensor& out_tensor = out_tensors.at(i);
i++;
const DeviceMesh& mesh = out_tv->getDeviceMesh();
DistributedTensor out_dtensor(out_tensor, mesh);

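// Only outputs assigned a non-empty device mesh carry sharding info; an empty mesh marks a non-distributed tensor.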
if (mesh.size() > 0) {
for (const ParallelType parallel_type : kParallelTypeDIDs) {
if (const auto axis = getShardedLogicalAxis(out_tv, parallel_type);
axis != -1) {
out_dtensor.setAxisIsShardedOn(axis, parallel_type);
}
}
}

out_dtensors.push_back(std::move(out_dtensor));
}
NVF_ERROR(out_dtensors.size() == out_tensors.size());
} else {
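// With a user schedule, outputs are wrapped with the default empty mesh, i.e. treated as non-distributed.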
for (const auto& out_tensor : out_tensors) {
out_dtensors.emplace_back(out_tensor);
}
}
return out_dtensors;
}

std::string FusionDefinition::fusionIr() {
5 changes: 3 additions & 2 deletions csrc/python_frontend/fusion_definition.h
@@ -7,11 +7,12 @@
// clang-format on
#pragma once

#include <exceptions.h>
#include <functional>
#include <iostream>
#include <unordered_map>

#include <exceptions.h>
#include <python_frontend/distributed_tensor.h>
#include <python_frontend/fusion_state.h>
#include <python_frontend/segmentation.h>
#include <visibility.h>
@@ -193,7 +194,7 @@ class NVF_API FusionDefinition : public FusionState {
//! Prints a python function representing the definition
NVF_API void print(std::ostream& os) const;
//! Executes a fusion if a valid definition or cache lookup occurred prior
NVF_API std::vector<at::Tensor> execute(
NVF_API std::vector<DistributedTensor> execute(
const at::ArrayRef<c10::IValue>& inputs,
std::optional<int8_t> device,
bool override_user_schedule,
18 changes: 18 additions & 0 deletions csrc/python_frontend/multidevice_bindings.cpp
@@ -71,11 +71,29 @@ void bindDeviceMesh(py::module& nvfuser) {
py::arg("device_id"));
}

void bindDistributedTensor(py::module& nvfuser) {
py::class_<DistributedTensor> distributed_tensor(
nvfuser, "DistributedTensor");
distributed_tensor.def(
"local", &DistributedTensor::local, "Returns the local torch.Tensor.");
distributed_tensor.def(
"mesh",
&DistributedTensor::mesh,
"Returns the device mesh.",
py::return_value_policy::reference);
distributed_tensor.def(
"axis_sharded_on",
&DistributedTensor::axisShardedOn,
"Returns the axis sharded on the given parallel type. If the distributed tensor is replicated on that parallel type, returns -1.",
py::arg("parallel_type"));
}

} // namespace

void bindMultidevice(py::module& nvfuser) {
bindCommunicator(nvfuser);
bindDeviceMesh(nvfuser);
bindDistributedTensor(nvfuser);
}

} // namespace nvfuser::python_frontend
7 changes: 5 additions & 2 deletions csrc/utils.h
@@ -543,9 +543,12 @@ NVF_API char* getNvFuserEnv(const char* env_name);

// Returns the mapped value or the default.
template <typename K, typename V>
V getOrDefault(const std::unordered_map<K, V>& map, const K& key) {
const V& getOrDefault(
const std::unordered_map<K, V>& map,
const K& key,
const V& default_value = V()) {
const auto i = map.find(key);
return i == map.end() ? V() : i->second;
return i == map.end() ? default_value : i->second;
}

size_t deviceAvailableSharedMemoryBytes();
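
To illustrate the extended helper, a minimal self-contained sketch; the map, keys, and values are made up for illustration, and the helper is restated verbatim so the snippet compiles on its own:

#include <cstdint>
#include <string>
#include <unordered_map>

// Returns the mapped value or the caller-supplied default.
template <typename K, typename V>
const V& getOrDefault(
    const std::unordered_map<K, V>& map,
    const K& key,
    const V& default_value = V()) {
  const auto i = map.find(key);
  return i == map.end() ? default_value : i->second;
}

int main() {
  const std::unordered_map<std::string, int64_t> axis_map = {{"DIDx", 0}};
  // Present key: the mapped value is returned.
  const int64_t hit = getOrDefault(axis_map, std::string("DIDx"), int64_t{-1});
  // Absent key: the caller's default is returned; this is how
  // DistributedTensor::axisShardedOn maps "not sharded" to -1.
  const int64_t miss = getOrDefault(axis_map, std::string("DIDy"), int64_t{-1});
  // Copy the result immediately: when the default is a temporary, keeping a
  // reference to the returned value past this statement would dangle.
  return (hit == 0 && miss == -1) ? 0 : 1;
}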
11 changes: 6 additions & 5 deletions nvfuser/__init__.py
@@ -198,7 +198,7 @@ def execute(
save_repro_inputs=False,
_enable_options: list[str] = [],
_disable_options: list[str] = [],
):
) -> list[torch.Tensor | _C.DistributedTensor]:
"""
Executes the nvFuser kernels for a given Fusion
@@ -306,8 +306,6 @@ def execute(
if hasattr(self, "segments") and len(self.segments) > 0:
return self._execute_segments(inputs, device=device, profile=profile)

results = None

try:
if print_repro:
print(self.repro_script_for(inputs))
@@ -316,7 +314,7 @@
"Reset the FusionCache manually to avoid reusing kernels when re-executing the fusion definition with different options."
)

results = self._execute(
out_tensors = self._execute(
inputs,
device=device,
override_user_schedule=override_user_schedule,
@@ -325,7 +323,10 @@
_enable_options=_enable_options,
_disable_options=_disable_options,
)
return results
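# Outputs with an empty mesh are not distributed; unwrap them to plain torch.Tensors.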
for i, out_dtensor in enumerate(out_tensors):
if out_dtensor.mesh().size() == 0:
out_tensors[i] = out_dtensor.local()
return out_tensors
except Exception as err:
logger.exception(self._repro_error_str("executing", inputs))
raise
2 changes: 1 addition & 1 deletion tests/python/multidevice_fixtures.py
@@ -51,4 +51,4 @@ def multidevice_test():
fixture = MultideviceTest()
yield fixture
# Sync all ranks after each test for isolation.
fixture.communicator.barrier()
# fixture.communicator.barrier()
23 changes: 14 additions & 9 deletions tests/python/test_multidevice.py
@@ -10,7 +10,7 @@
import multidevice_fixtures
import nvfuser
import utils
from nvfuser import DataType, FusionDefinition
from nvfuser import DataType, FusionDefinition, DistributedTensor


multidevice_test = multidevice_fixtures.multidevice_test
@@ -54,8 +54,9 @@ def multidevice_schedule(self):
sharded_input = multidevice_test.shard_tensor(unsharded_input, 0, mesh)

fd = Model()
outputs = fd.execute([sharded_input])
torch.testing.assert_close(outputs[0].cpu(), unsharded_input.relu() * 2)
outputs: list[DistributedTensor] = fd.execute([sharded_input])
torch.testing.assert_close(outputs[0].local().cpu(), unsharded_input.relu() * 2)
assert outputs[0].axis_sharded_on(nvfuser.ParallelType.mesh_x) == -1


@pytest.mark.mpi
@@ -106,8 +107,9 @@ def multidevice_schedule(self):
rank : rank + 1
]
# rtol is the same as the default for fp32. atol is slightly increased.
assert out_tensors[0].axis_sharded_on(nvfuser.ParallelType.mesh_x) == 0
torch.testing.assert_close(
out_tensors[0], expected_out_tensor, rtol=1.3e-6, atol=1e-3
out_tensors[0].local(), expected_out_tensor, rtol=1.3e-6, atol=1e-3
)


@@ -169,7 +171,7 @@ def multidevice_schedule(self):
expected_out_tensor = multidevice_test.shard_tensor(unsharded_out_tensor, -1, mesh)
# rtol is the same as the default for fp32. atol is slightly increased.
torch.testing.assert_close(
out_tensors[0], expected_out_tensor, rtol=1.3e-6, atol=1e-3
out_tensors[0].local(), expected_out_tensor, rtol=1.3e-6, atol=1e-3
)


@@ -220,7 +222,9 @@ def multidevice_schedule(self) -> None:
(in_grad,) = fd.execute([out_grad.cuda(), weight.cuda()])
# Use the default rtol for half because the output, although being float32,
# is a straight cast from half.
torch.testing.assert_close(in_grad.cpu(), expected_in_grad, rtol=1e-3, atol=1e-2)
torch.testing.assert_close(
in_grad.local().cpu(), expected_in_grad, rtol=1e-3, atol=1e-2
)


class QkvFormat(Enum):
@@ -335,6 +339,7 @@ def head_parallelize(t: torch.Tensor) -> torch.Tensor:
out, q_grad, k_grad, v_grad = outs

def assert_close(actual, expected):
actual = actual.local()
match qkv_format:
case QkvFormat.BHSE:
assert actual.is_contiguous()
@@ -744,10 +749,10 @@ def multidevice_schedule(self):
# TODO(#2962): validate the numbers as well. Currently, the numbers are off
# by a lot, making comparison infeasible.
def _assert_shape_dtype(
t: torch.Tensor, expected_sizes: list[int], expected_dtype: torch.dtype
t: DistributedTensor, expected_sizes: list[int], expected_dtype: torch.dtype
) -> None:
assert t.shape == torch.Size(expected_sizes)
assert t.dtype == expected_dtype
assert t.local().shape == torch.Size(expected_sizes)
assert t.local().dtype == expected_dtype


@pytest.mark.skipif(
