Skip to content

Commit

Permalink
test with stream parallel type and host IR
Browse files Browse the repository at this point in the history
  • Loading branch information
samnordmann committed Jan 16, 2025
1 parent a2b1650 commit e037ee5
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 18 deletions.
34 changes: 19 additions & 15 deletions bench/test
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
#!/bin/bash
EXPERIMENT=Dummy_profile_msgsize256m_float16_M128k_K128k_N32k_UCC_IB
EXPERIMENT=StreamParallelType_tests
DATE=$(date +%Y%m%d-%H%M)
LOG_BASE="/opt/pytorch/Fuser/bench/logs"

NP=8
BACKEND=UCC
M=131072 #32768
K=131072
N=32768 #1024
M=32768
K=32768
N=1024

S=8
Streams=8
Streams=3
Pgs=1

L=1048576 #268435456 #67108864 #131072
PRE_COMM="_pre_comm"
# M=131072 #32768
# K=131072
# N=32768 #1024
# L=1048576 #268435456 #67108864 #131072
# PRE_COMM="_pre_comm"
# POST_COMM="_post_comm"
# UNFUSE="_unfused"
# GRAPH="_WithCudaGraph"
# cuStreamWrite=WithcuStreamWriteValue32_
# GTEST_PREFIX="OverlapBenchmark.PipelinedAGMatmulBenchmark/"
GTEST_PREFIX="DummyOverlapBenchmark.PipelinedAGMatmulBenchmark/"
# GTEST_POSTFIX="${BACKEND}_S${S}_M${M}_K${K}_N${N}_Streams${Streams}_${cuStreamWrite}Pgs${Pgs}${UNFUSE}${GRAPH}"
GTEST_POSTFIX="${BACKEND}_M${M}_K${K}_N${N}_L${L}${PRE_COMM}${POST_COMM}"
# GTEST_PREFIX="DummyOverlapBenchmark.PipelinedAGMatmulBenchmark/"
GTEST_PREFIX="OverlapBenchmark.PipelinedAGMatmulBenchmarkStreamParallelType/"
GTEST_POSTFIX="${BACKEND}_S${S}_M${M}_K${K}_N${N}_Streams${Streams}_${cuStreamWrite}Pgs${Pgs}${UNFUSE}${GRAPH}"
# GTEST_POSTFIX="${BACKEND}_M${M}_K${K}_N${N}_L${L}${PRE_COMM}${POST_COMM}"
export GTEST_FILTER="${GTEST_PREFIX}${GTEST_POSTFIX}"
echo "gtest filter: $GTEST_FILTER" | tee -a $LOG_FILE_INFO

Expand All @@ -32,7 +36,7 @@ MPIFLAGS=" -np $NP"
# MPIFLAGS+=" -x NCCL_DEBUG=TRACE" #INFO
# MPIFLAGS+=" -x NCCL_MAX_NCHANNELS=1"

# MPIFLAGS+=" -x UCC_CL_BASIC_TLS=nccl"
MPIFLAGS+=" -x UCC_CL_BASIC_TLS=nccl"
# MPIFLAGS+=" -x UCC_TL_NCCL_SYNC=event"

# MPIFLAGS+=" -x UCC_CL_BASIC_TLS=cuda"
Expand All @@ -47,15 +51,15 @@ MPIFLAGS=" -np $NP"
# MPIFLAGS+=" -x UCC_EC_CUDA_EXEC_COPY_LARGE_THRESH=1M"
# MPIFLAGS+=" -x UCC_EC_CUDA_EXEC_NUM_THREADS=512"

MPIFLAGS+=" -x UCC_CL_BASIC_TLS=ucp"
MPIFLAGS+=" -x UCX_RNDV_THRESH=0 -x UCX_TLS=ib,cuda_copy"
# MPIFLAGS+=" -x UCC_CL_BASIC_TLS=ucp"
# MPIFLAGS+=" -x UCX_RNDV_THRESH=0 -x UCX_TLS=ib,cuda_copy"
# MPIFLAGS+=" -x UCX_RNDV_SCHEME=put_zcopy"
MPIFLAGS+=" -x UCX_RNDV_SCHEME=get_zcopy"
# MPIFLAGS+=" -x UCX_RNDV_SCHEME=get_zcopy"


MPIFLAGS+=" -x UCX_NET_DEVICES=mlx5_0:1"
# MPIFLAGS+=" -x UCC_CL_BASIC_TLS=^sharp,mlx5"
# MPIFLAGS+=" -x UCC_COLL_TRACE=info"
MPIFLAGS+=" -x UCC_COLL_TRACE=info"
# MPIFLAGS+=" -x UCC_LOG_LEVEL=debug"
# MPIFLAGS+=" -x TORCH_NCCL_AVOID_RECORD_STREAMS=1"
# MPIFLAGS+=" -x CUDA_DEVICE_MAX_CONNECTIONS=2"
Expand Down
111 changes: 108 additions & 3 deletions tests/cpp/test_multidevice_overlap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,15 +345,120 @@ TEST_P(OverlapBenchmark, PipelinedAGMatmulBenchmark) {
}
}

// Benchmarks a pipelined AllGather + Matmul overlap expressed with the
// Stream parallel type (lowered through Host IR), and validates the
// multi-device result against an unsharded single-device reference matmul.
//
// Parameterization is shared with PipelinedAGMatmulBenchmark; combinations
// that only make sense for the manually-pipelined variant are skipped here.
TEST_P(OverlapBenchmark, PipelinedAGMatmulBenchmarkStreamParallelType) {
  // Warmup iterations are excluded from the averaged timing; the CUDA
  // profiler is enabled only for a small window of iterations to keep
  // profiles manageable.
  constexpr int64_t number_of_warmups = 50;
  constexpr int64_t number_of_iterations = 200;
  constexpr int64_t iteration_profiler_start = 10;
  constexpr int64_t iteration_profiler_end = 15;

  const int64_t D = communicator_->size();
  auto [backend,
        S,
        M,
        K,
        N,
        number_of_streams,
        add_cuStreamWriteValue32,
        number_of_pgs,
        unfuse_loops,
        use_cuda_graph] = GetParam();

  // The row dimension M is tiled as [S, D, M/(S*D)]; it must divide evenly.
  if (M % (D * S) != 0) {
    GTEST_SKIP() << "M must be a multiple of D * S, but got M = " << M
                 << ", D = " << D << ", S = " << S;
  }
  // The remaining parameters are exercised by the manually-pipelined sibling
  // test and do not apply to the StreamParallelType lowering.
  if (add_cuStreamWriteValue32) {
    GTEST_SKIP() << "cuStreamWriteValue32 not supported with StreamParallelType";
  }
  if (number_of_pgs > 1) {
    GTEST_SKIP() << "StreamParallelType not supported with multiple process groups";
  }
  if (unfuse_loops) {
    GTEST_SKIP() << "StreamParallelType not supported with unfused loops";
  }
  if (use_cuda_graph) {
    GTEST_SKIP() << "StreamParallelType not supported with cuda graphs";
  }

  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  TensorView* a = makeContigTensor(4); //[S, DIDx(D), M/(S*D), K]
  TensorView* b = makeContigTensor(2); //[K, N]
  TensorView* c = matmul(a, b); //[S, D, M/(S*D), N]

  fusion->addInput(a);
  fusion->addInput(b);
  fusion->addOutput(c);

  auto mesh = DeviceMesh::createForNumDevices(D);
  a->setDeviceMesh(mesh);
  b->setDeviceMesh(mesh);
  c->setDeviceMesh(mesh);

  // Shard the input over devices and parallelize the output's outermost
  // (pipeline-stage) axis over streams.
  a->axis(1)->parallelize(ParallelType::DIDx);
  c->axis(0)->parallelize(ParallelType::Stream);

  communicator_->setDefaultBackend(backend);

  hir::HostIrEvaluatorParams params;
  params.number_of_streams = number_of_streams;
  MultiDeviceExecutor executor(std::move(fusion), *communicator_, params);

  auto tensor_options =
      at::TensorOptions().dtype(at::kFloat).device(communicator_->device());
  at::Tensor ta_unsharded = at::randn({S, D, M / (S * D), K}, tensor_options);
  // Each rank holds only its slice of the DIDx-sharded axis.
  at::Tensor ta = ta_unsharded.slice(
      1, communicator_->deviceId(), communicator_->deviceId() + 1);
  at::Tensor tb = at::randn({K, N}, tensor_options);
  at::Tensor tc_ref = at::matmul(ta_unsharded, tb);

  std::vector<c10::IValue> inputs = {ta, tb};
  at::Tensor tc;

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  for (const auto& iteration :
       c10::irange(number_of_warmups + number_of_iterations)) {
    if (iteration == iteration_profiler_start) {
      cudaProfilerStart();
    }
    // Start timing only once warmups are done.
    if (iteration == number_of_warmups) {
      cudaEventRecord(start);
    }

    tc = executor.runWithInput(inputs).at(0);

    if (iteration == iteration_profiler_end) {
      cudaProfilerStop();
    }
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);
  float milliseconds = 0;
  cudaEventElapsedTime(&milliseconds, start, stop);
  milliseconds /= number_of_iterations;
  // Events are no longer needed once the elapsed time has been read.
  cudaEventDestroy(start);
  cudaEventDestroy(stop);

  std::string test_name =
      ::testing::UnitTest::GetInstance()->current_test_info()->name();
  times.insert({test_name, milliseconds});
  std::cout << "rank " << communicator_->deviceId() << ", " << test_name
            << " : " << milliseconds << std::endl;

  // Loose tolerances: results accumulate across streams/devices in float.
  EXPECT_TRUE(torch::allclose(tc_ref, tc, 1e-1, 1e-1));
}

INSTANTIATE_TEST_SUITE_P(
,
OverlapBenchmark,
testing::Combine(
testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kUcc),
/*S=*/testing::Values(1,2,4,8, 16, 32),
/*M=*/testing::Values(pow(2,10), pow(2,15)),
/*K=*/testing::Values(pow(2,10), pow(2,15)),
/*N=*/testing::Values(pow(2,10)),
/*M=*/testing::Values(pow(2,10), pow(2,15), pow(2,18)),
/*K=*/testing::Values(pow(2,10), pow(2,15), pow(2,18)),
/*N=*/testing::Values(pow(2,10), pow(2,15)),
/*number_of_streams=*/testing::Values(3, 8, 32),
/*add_cuStreamWriteValue32*/testing::Values(false, true),
/*number_of_pgs=*/testing::Values(1, 2, 4, 8),
Expand Down

0 comments on commit e037ee5

Please sign in to comment.