From f097e58ea1ae9d2c70115ce6dfabc7b27763fcb3 Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Mon, 28 Oct 2024 12:28:33 -0700 Subject: [PATCH 01/17] Presegmentation pass to force segment breaks when inplace update can cause RW race (#2999) Issue #2664 A RW race can occur when an intermediate tensorview is aliased to a fusion input and the intermediate tensorview or the aliased input is a producer/consumer of a broadcast op. This presegmentation pass traverses the fusion to find such inplace updates, and inserts `segmet_set + set` to force the inplace update into a separate copy kernel. This ensures that the write to the fusion input only occurs when all the reads of that fusion input have concluded. --------- Co-authored-by: jjsjann123 --- CMakeLists.txt | 1 + csrc/preseg_passes/pre_segmenter.cpp | 2 + csrc/preseg_passes/segment_inplace_update.cpp | 156 +++++++++++++++ csrc/preseg_passes/segment_inplace_update.h | 27 +++ tests/cpp/test_alias.cpp | 40 +++- tests/python/test_pointwise.py | 178 ++++++++++++++++++ 6 files changed, 403 insertions(+), 1 deletion(-) create mode 100644 csrc/preseg_passes/segment_inplace_update.cpp create mode 100644 csrc/preseg_passes/segment_inplace_update.h diff --git a/CMakeLists.txt b/CMakeLists.txt index abb96e2e596..a6eb9fed679 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -195,6 +195,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/preseg_passes/remove_bcast_squeeze.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp + ${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp ${NVFUSER_SRCS_DIR}/rng.cpp ${NVFUSER_SRCS_DIR}/runtime/allocations.cpp ${NVFUSER_SRCS_DIR}/runtime/executor.cpp diff --git a/csrc/preseg_passes/pre_segmenter.cpp b/csrc/preseg_passes/pre_segmenter.cpp index e63ab07ffff..2ad82f9dd20 100644 --- a/csrc/preseg_passes/pre_segmenter.cpp +++ b/csrc/preseg_passes/pre_segmenter.cpp @@ -24,6 +24,7 @@ #include #include #include +#include namespace nvfuser::preseg_passes { @@ -65,6 +66,7 @@ namespace nvfuser::preseg_passes { OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); } } // namespace nvfuser::preseg_passes diff --git a/csrc/preseg_passes/segment_inplace_update.cpp b/csrc/preseg_passes/segment_inplace_update.cpp new file mode 100644 index 00000000000..73fd24ad8f0 --- /dev/null +++ b/csrc/preseg_passes/segment_inplace_update.cpp @@ -0,0 +1,156 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace nvfuser::preseg_passes { +// When an intermediate tensorview is aliased to a fusion input, +// a RW race occurs, when the intermediate tensorview or +// the aliased input is in path of a broadcast. +// This preseg pass : +// 1. Finds any tensorviews used in inplace updates (AllocationType:ReuseBuffer) +// in the fusion +// 2. Traverses the fusion graph starting from broadcast ops and stores all +// direct/indirect producer/consumer tensorviews. +// 3. For all aliased tensorviews, if the aliased tensorview or the aliased +// input is present in the set of visited tensorviews in step 2, we insert a +// segment set and set to force a separate copy kernel. Additionally, +// we check for implict broadcasts if any aliased input already has a +// broadcast dimension that is concretized later in the fusion. This ensures +// that all write operations to the fusion inputs occur after the read +// operations have completed. See Issue #2664: https:// +// github.com/NVIDIA/Fuser/issues/2664 +namespace { +void insertSegmentSet(Fusion* fusion) { + std::vector aliased_tvs; + + // Find all tensorviews which are used in inplace updates. + // Aliases will always be fusion outputs. + for (Val* out : fusion->outputs()) { + if (fusion->getOutputAlias(out->as()).type == + AllocationType::ReuseBuffer) { + aliased_tvs.push_back(out->as()); + } + } + + // Return early if there is no inplace update + if (aliased_tvs.empty()) { + return; + } + + // fusion->exprs() is a topologically sorted list. Filter out the broadcast + // ops from the list. + auto all_exprs = fusion->exprs(); + auto all_bcast_ops = ir_utils::filterByType(all_exprs); + + // Traverse and store all direct/indirect consumer tensorviews of these + // broadcast nodes. If the tensorview has been visited, return --> this means + // that we have already traversed that branch + std::unordered_set visited_tvs; + for (auto bcast_op : all_bcast_ops) { + std::deque tvs_to_visit; + tvs_to_visit.push_back(bcast_op->output(0)->as()); + while (!tvs_to_visit.empty()) { + TensorView* current_tv = tvs_to_visit.front(); + tvs_to_visit.pop_front(); + if (visited_tvs.count(current_tv)) { + continue; + } + visited_tvs.insert(current_tv); + std::vector current_tv_uses = current_tv->uses(); + for (Expr* use : current_tv_uses) { + for (auto output_tv : + ir_utils::filterByType(use->outputs())) { + tvs_to_visit.push_back(output_tv->as()); + } + } + } + } + + // Traverse and store the direct/indirect producer tensorviews of these + // broadcast nodes If that tensorview has been visited, return. + for (auto bcast_op : all_bcast_ops) { + std::deque tvs_to_visit; + tvs_to_visit.push_back(bcast_op->input(0)->as()); + while (!tvs_to_visit.empty()) { + TensorView* current_tv = tvs_to_visit.front(); + tvs_to_visit.pop_front(); + if (visited_tvs.count(current_tv)) { + continue; + } + visited_tvs.insert(current_tv); + auto definition = current_tv->definition(); + if (definition != nullptr) { + for (auto input_tv : + ir_utils::filterByType(definition->inputs())) { + tvs_to_visit.push_back(input_tv->as()); + } + } + } + } + + // Use permissive IdModel graph to identify any concretized broadcast + // iterdomain in any aliased input. + auto id_model = IdModel(fusion, /*build_graphs=*/false); + id_model.buildPermissiveGraph(); + const ValGraph& permissive_graph = + id_model.idGraph(IdMappingMode::PERMISSIVE); + + auto hasConcretizedBroadcast = [&](TensorView* tv) -> bool { + if (!tv->hasBroadcast()) { + return false; + } + for (IterDomain* id : tv->getLogicalDomain()) { + if (!id->isBroadcast()) { + continue; + } + if (!permissive_graph.hasGroup(id)) { + continue; + } + const ValGroup& val_group = permissive_graph.toGroup(id); + for (auto other_id : val_group.get()->vector()) { + if (!other_id->as()->isBroadcast()) { + return true; + } + } + } + return false; + }; + + // For all aliased tensorviews: + // 1) if that tv or the corresponding aliased input is a producer/consumer of + // a broadcast op, or 2) the aliased input has a concretized broadcast, insert + // a (segment_set + set) to force the inplace update into a separate copy + // kernel. NOTE: We cannot use a segment_set alone. Since, there will be no + // data flow across this segment_set (the output of segment_set is an output + // of given fusion with no uses), it will be merged with other segments. + // https://github.com/NVIDIA/Fuser/blob/92b635125ae509cc6b2ccbe29e957586a9cbb059/csrc/fusion_segmenter.cpp#L2331-L2346 + for (auto aliased_tv : aliased_tvs) { + TensorView* aliased_input = + fusion->getOutputAlias(aliased_tv).aliased_io->as(); + if (visited_tvs.count(aliased_tv) || visited_tvs.count(aliased_input) || + hasConcretizedBroadcast(aliased_input)) { + TensorView* alias_seg = segment_set(aliased_tv); + TensorView* alias_copy = set(alias_seg); + fusion->replaceOutput(aliased_tv, alias_copy); + } + } +} +} // namespace + +void SegmentInplaceUpdatePass::runPass(Fusion* fusion) { + insertSegmentSet(fusion); +} +} // namespace nvfuser::preseg_passes diff --git a/csrc/preseg_passes/segment_inplace_update.h b/csrc/preseg_passes/segment_inplace_update.h new file mode 100644 index 00000000000..5966e714e25 --- /dev/null +++ b/csrc/preseg_passes/segment_inplace_update.h @@ -0,0 +1,27 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +namespace nvfuser::preseg_passes { + +//! RemoveEmptyPass removes intermediate empty tensors (those with at least one +//! extent zero thar are neither a fusion output or input). +class SegmentInplaceUpdatePass + : public OptimizationPass { + friend class OptimizationPass; + + protected: + static void runPass(Fusion* fusion); + static std::string name() { + return "SegmentInplaceUpdate"; + } +}; + +} // namespace nvfuser::preseg_passes diff --git a/tests/cpp/test_alias.cpp b/tests/cpp/test_alias.cpp index 893fac13050..68337688656 100644 --- a/tests/cpp/test_alias.cpp +++ b/tests/cpp/test_alias.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -1016,8 +1017,11 @@ TEST_F(AliasTest, ReuseBuffer_AliasAcrossSegments) { testValidate( fec.fusion(), outputs, {original_t0, t1, t2}, __LINE__, __FILE__); + // https://github.com/NVIDIA/Fuser/pull/2999 will cause 3 segments instead of + // the optimal 2 segments. Change back to 2 segments once + // https://github.com/NVIDIA/Fuser/issues/3251 is resolved. EXPECT_EQ( - fec.getMostRecentKernelRuntime()->fusionSegments()->groups().size(), 2) + fec.getMostRecentKernelRuntime()->fusionSegments()->groups().size(), 3) << "segmentation didn't happen as expected"; auto t3 = original_t0.add(1.0); @@ -1426,4 +1430,38 @@ TEST_F(AliasTest, Bookend_Issue2375) { HeuristicIs(SchedulerType::InnerPersistent))); } +// Repro for https://github.com/NVIDIA/Fuser/issues/2664 +TEST_F(AliasTest, Issue2664) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + constexpr int64_t n = 4194304; + const DataType dtype = DataType::Float; + const std::vector input_shape = {n}; + + auto tv1 = makeContigTensor(1, dtype); + auto tv2 = makeContigTensor(0, dtype); + fusion->addInput(tv1); + fusion->addInput(tv2); + + auto s3 = IrBuilder::create(1.0); + auto tv4 = add(tv2, s3); + auto tv5 = broadcast(tv4, {true}); + auto tv7 = expand(tv5, {tv1->axis(0)->extent()}); + auto tv8 = mul(tv1, tv7); + fusion->aliasOutputToInput(tv4, tv2, AllocationType::ReuseBuffer); + fusion->addOutput(tv8); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + auto t1 = at::randn(input_shape, options); + auto t2 = at::randn({}, options); + auto aten_out = (t2 + 1.0) * t1; + + FusionExecutorCache fec(std::move(fusion)); + auto out_tensors = fec.runFusionWithInputs({t1, t2}); + testValidate( + fec.fusion(), out_tensors, {t1, t2}, {aten_out}, __LINE__, __FILE__); +} + } // namespace nvfuser diff --git a/tests/python/test_pointwise.py b/tests/python/test_pointwise.py index 09b3c51d38a..49d23c6981f 100644 --- a/tests/python/test_pointwise.py +++ b/tests/python/test_pointwise.py @@ -157,3 +157,181 @@ def nvfuser_fusion(fd: FusionDefinition) -> None: nvf_out = fd.execute(inputs) torch.testing.assert_close(nvf_out[0], inputs[0] + inputs[2]) torch.testing.assert_close(nvf_out[1], inputs[1] + inputs[2]) + + +# Example 1: Repro from https://github.com/NVIDIA/Fuser/issues/2664 +# T4 (scalar) is broadcasted and used in mul computation. It is also used to update T2 inplace. +# This causes a RW race. In this case, the aliased tensor is a producer of bcast op. +def test_inplace_issue2664(): + def nvfuser_fusion_id0(fd: FusionDefinition) -> None: + T1 = fd.define_tensor( + shape=[-1], + contiguity=[True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[0], + ) + T2 = fd.define_tensor( + shape=[], contiguity=[], dtype=DataType.Float, is_cpu=False + ) + S3 = fd.define_scalar(1.00000, dtype=DataType.Double) + T4 = fd.ops.add(T2, S3) + S5 = fd.define_scalar(4194304, dtype=DataType.Int) + V6 = fd.define_vector([S5], dtype=DataType.Int) + T7 = fd.ops.broadcast_in_dim(T4, shape=V6, broadcast_dims=[]) + T8 = fd.ops.mul(T1, T7) + fd.add_output(T4, T2) + fd.add_output(T8) + + with FusionDefinition() as fd: + nvfuser_fusion_id0(fd) + + inputs = [ + torch.randn((4194304,), dtype=torch.float32, device="cuda:0").as_strided( + (4194304,), (1,) + ), + torch.randn((1,), dtype=torch.float32, device="cuda:0").as_strided((), ()), + ] + # Reference out = T4 (aliased to inputs[-1]), T8 + ref_out = [inputs[-1] + 1.0, (inputs[-1] + 1.0) * inputs[0]] + + out = fd.execute(inputs, profile=True) + + assert fd.profile().segments == 2 + + torch.testing.assert_close(inputs[-1], ref_out[0]) + torch.testing.assert_close(out[0], ref_out[1]) + + +# Example 2 for Issue 2664: +# T2 is broadcasted and used in mul/add compute. It is also summed (T8) and used to inplace update T2. +# In this case, the aliased tensor (T8) is a consumer of the bcast op. +def test_inplace_post_bcast(): + def fusion_func(fd: FusionDefinition) -> None: + T1 = fd.define_tensor( + shape=[-1], + contiguity=[True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[0], + ) + T2 = fd.define_tensor( + shape=[], contiguity=[], dtype=DataType.Float, is_cpu=False + ) + S5 = fd.define_scalar(4194304, dtype=DataType.Int) + V6 = fd.define_vector([S5], dtype=DataType.Int) + T7 = fd.ops.broadcast_in_dim(T2, shape=V6, broadcast_dims=[]) + T8 = fd.ops.sum(T7, dims=[0], keepdim=False) + T9 = fd.ops.mul(T1, T7) + T10 = fd.ops.add(T1, T7) + fd.add_output(T8, T2) + fd.add_output(T9) + fd.add_output(T10) + + with FusionDefinition() as fd: + fusion_func(fd) + + inputs = [ + torch.randn((4194304,), dtype=torch.float32, device="cuda:0").as_strided( + (4194304,), (1,) + ), + torch.randn((1,), dtype=torch.float32, device="cuda:0").as_strided((), ()), + ] + + # Reference out = T8 (aliased to inputs[-1]), T9, T10 + ref_out = [ + inputs[-1] * inputs[0].size(0), + inputs[-1] * inputs[0], + inputs[0] + inputs[1], + ] + + out = fd.execute(inputs, profile=True) + + assert fd.profile().segments == 2 + + torch.testing.assert_close(inputs[-1], ref_out[0]) + torch.testing.assert_close(out[0], ref_out[1]) + torch.testing.assert_close(out[1], ref_out[2]) + + +# Example 3 for Issue 2664: This case involves two inplace updates. +# T7 is aliased to T2: T7 is not a producer/consumer of the bcast op, but the aliased input T2 is a producer of the bcast op. +# T6 is aliased to T3: T6 is a consumer of the bcast op. +def test_multi_inplace(): + def fusion_func(fd: FusionDefinition) -> None: + T1 = fd.define_tensor( + shape=[-1], + contiguity=[True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[0], + ) + T2 = fd.define_tensor( + shape=[], contiguity=[], dtype=DataType.Float, is_cpu=False + ) + T3 = fd.define_tensor( + shape=[], contiguity=[], dtype=DataType.Float, is_cpu=False + ) + T4 = fd.ops.broadcast_in_dim(T2, shape=T1.shape(), broadcast_dims=[]) + T5 = fd.ops.add(T1, T4) + T6 = fd.ops.sum(T5, dims=[0], keepdim=False) + S0 = fd.define_scalar(1.00000, dtype=DataType.Double) + T7 = fd.ops.add(T3, S0) + fd.add_output(T6, T3) + fd.add_output(T7, T2) + + with FusionDefinition() as fd: + fusion_func(fd) + + inputs = [ + torch.randn((4194304,), dtype=torch.float32, device="cuda:0").as_strided( + (4194304,), (1,) + ), + torch.randn((1,), dtype=torch.float32, device="cuda:0").as_strided((), ()), + torch.randn((1,), dtype=torch.float32, device="cuda:0").as_strided((), ()), + ] + + # Reference out = T6 (aliased to inputs[2]), T7 (aliased to inputs[1]) + ref_out = [inputs[-1] + 1.0, (inputs[0] + inputs[1]).sum(dim=-1)] + + fd.execute(inputs, profile=True) + assert fd.profile().segments == 4 + + torch.testing.assert_close(inputs[1], ref_out[0]) + torch.testing.assert_close(inputs[2], ref_out[1]) + + +# Example 4 for Issue 2664: There is no explicit broadcast. However, the aliased input has a broadcast dimension that is concretized in the fusion. +# T0 has a implicit broadcast which is used in add(T3) and neg (T4). T4 is used to inplace update T0, which causes RW race. +def test_implicit_bcast_inplace(): + def fusion_func(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, 1], + contiguity=[True, None], + dtype=DataType.Float, + is_cpu=False, + stride_order=[1, 0], + ) + T1 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[1, 0], + ) + T3 = fd.ops.add(T1, T0) + T4 = fd.ops.neg(T0) + fd.add_output(T3) + fd.add_output(T4, T0) + + inputs = [ + torch.randn((4194304, 1), dtype=torch.float32, device="cuda:0"), + torch.randn((4194304, 128), dtype=torch.float32, device="cuda:0"), + ] + with FusionDefinition() as fd: + fusion_func(fd) + ref_out = [inputs[0] + inputs[1], -inputs[0]] + out = fd.execute(inputs) + + torch.testing.assert_close(ref_out[0], out[0]) + torch.testing.assert_close(ref_out[1], inputs[0]) From 0d169fea2a1cdb770a31f9e51edf6a12a0b687ed Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 28 Oct 2024 15:43:33 -0700 Subject: [PATCH 02/17] Fix elect sync predicate (#3295) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR fixes https://github.com/NVIDIA/Fuser/issues/3199 Perf: ```C++ Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name -------- --------------- --------- -------- -------- -------- -------- ----------- ---------------------------------------------------------------------------------------------------- 47.8 247326 1 247326.0 247326.0 247326 247326 0.0 ::nvfuser_none_f0_c0_r0_g0(::Tensor<::__half, (int)3, (int)3>, … 17.0 88191 1 88191.0 88191.0 88191 88191 0.0 nvjet_hsh_256x128_64x4_1x2_h_bz_coopA_NTT ``` Perf nvFuser/cuBLAS: `35.6%` Strangely, elect-sync hurt instead of help perf. I need to look into this, but anyway, this PR is a bug fix, not a perf improvement. If elect-sync does not work, we should disable it, instead of enabling it and rely on a bug to avoid it hurting perf.
Generated code ```CUDA __global__ void nvfuser_none_f0_c0_r0_g0(Tensor<__half, 3, 3> T0, Tensor<__half, 3, 3> T1, const __grid_constant__ TensorMap var0, const __grid_constant__ TensorMap var1, Tensor<__half, 2, 2> T3) { alignas(16) extern __shared__ char array[]; const unsigned smem_offset = 0; nvfuser_index_t i2; i2 = ceilDiv(T0.logical_size[0LL], 16); nvfuser_index_t i3; i3 = -3 + i2; const TensorMap* ptr4; ptr4 = &var0; nvfuser_index_t i5; i5 = 256 * ((nvfuser_index_t)blockIdx.x); __half* T5 = reinterpret_cast<__half*>(array + smem_offset + 16512); unsigned i6; i6 = toSmem(T5); const TensorMap* ptr7; ptr7 = &var1; nvfuser_index_t i8; i8 = 128 * ((nvfuser_index_t)blockIdx.y); __half* T4 = reinterpret_cast<__half*>(array + smem_offset + 128); unsigned i9; i9 = toSmem(T4); unsigned i10; i10 = i9 + (2048 * ((nvfuser_index_t)threadIdx.y)); nvfuser_index_t i11; i11 = ((nvfuser_index_t)threadIdx.x) / 4; nvfuser_index_t i12; i12 = 2 * (((nvfuser_index_t)threadIdx.x) % 4); nvfuser_index_t i13; i13 = i11 / 8; nvfuser_index_t i14; i14 = i11 % 8; nvfuser_index_t i15; i15 = ((((i12 + ((16 * T1.logical_size[2LL]) * i13)) + (T1.logical_size[2LL] * i14)) + ((64 * T1.logical_size[2LL]) * ((nvfuser_index_t)threadIdx.y))) + i5) + ((128 * T1.logical_size[2LL]) * ((nvfuser_index_t)blockIdx.y)); nvfuser_index_t i16; i16 = 8 * T1.logical_size[2LL]; bool b17; b17 = ((((nvfuser_index_t)threadIdx.x) < 32ULL) && (((nvfuser_index_t)threadIdx.y) == 0ULL)) && (((nvfuser_index_t)threadIdx.z) == 0ULL); nvfuser_index_t i18; i18 = ((1 - T1.logical_size[2LL]) + i12) + i5; nvfuser_index_t i19; i19 = ((((-T0.logical_size[1LL]) + (16 * i13)) + i14) + (64 * ((nvfuser_index_t)threadIdx.y))) + i8; float T2[128]; ((*reinterpret_cast*>(&T2[0]))).set(0); asm volatile("wgmma.fence.sync.aligned;\n"); asm volatile("fence.proxy.async;\n"); uint64_t* T7 = reinterpret_cast(array + smem_offset + 0); #pragma unroll for(nvfuser_index_t i20 = 0; i20 < 4; ++i20) { if ((b17 && Hopper::electSync(4294967295U))) { mbarrier::init(toSmem((&T7[i20])), 512U); } } __syncthreads(); #pragma unroll for(nvfuser_index_t i21 = 0; i21 < 3; ++i21) { nvfuser_index_t i22; i22 = 16 * i21; unsigned i23; i23 = i6 + (8192 * i21); unsigned i24; i24 = i9 + (4096 * i21); if ((b17 && Hopper::electSync(4294967295U))) { mbarrier::arriveExpectTX(toSmem((&T7[i21])), 8192U); #pragma unroll for(nvfuser_index_t i25 = 0; i25 < 4; ++i25) { Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr4, (Array{(i5 + (64 * i25)), i22}), toSmem((&T7[i21])) }), (i23 + (2048 * i25))); } } else { mbarrier::arrive(toSmem((&T7[i21]))); } if ((b17 && Hopper::electSync(4294967295U))) { mbarrier::arriveExpectTX(toSmem((&T7[i21])), 4096U); #pragma unroll for(nvfuser_index_t i26 = 0; i26 < 2; ++i26) { Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr7, (Array{(i8 + (64 * i26)), i22}), toSmem((&T7[i21])) }), (i24 + (2048 * i26))); } } else { mbarrier::arrive(toSmem((&T7[i21]))); } } #pragma unroll 4 for(nvfuser_index_t i27 = 0; i27 < i3; ++i27) { nvfuser_index_t i28; i28 = 48 + (16 * i27); nvfuser_index_t i29; i29 = (3 + i27) % 4; unsigned i30; i30 = i6 + (8192 * i29); unsigned i31; i31 = i9 + (4096 * i29); nvfuser_index_t i32; i32 = i27 % 4; unsigned i33; i33 = i10 + (4096 * i32); unsigned i34; i34 = i6 + (8192 * i32); if ((b17 && Hopper::electSync(4294967295U))) { mbarrier::arriveExpectTX(toSmem((&T7[((3 + i27) % 4)])), 8192U); #pragma unroll for(nvfuser_index_t i25 = 0; i25 < 4; ++i25) { Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr4, (Array{(i5 + (64 * i25)), i28}), toSmem((&T7[((3 + i27) % 4)])) }), (i30 + (2048 * i25))); } } else { mbarrier::arrive(toSmem((&T7[((3 + i27) % 4)]))); } if ((b17 && Hopper::electSync(4294967295U))) { mbarrier::arriveExpectTX(toSmem((&T7[((3 + i27) % 4)])), 4096U); #pragma unroll for(nvfuser_index_t i26 = 0; i26 < 2; ++i26) { Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr7, (Array{(i8 + (64 * i26)), i28}), toSmem((&T7[((3 + i27) % 4)])) }), (i31 + (2048 * i26))); } } else { mbarrier::arrive(toSmem((&T7[((3 + i27) % 4)]))); } mbarrier::waitParity(toSmem((&T7[(i27 % 4)])), (((uint32_t)(i27) / 4U) % 2U)); asm volatile( "{\n" " .reg .pred p0; \n" " setp.ne.b32 p0, %130, 0;\n" " wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, %128, %129, p0, %131, %132, %133, %134;\n" "}\n" :"+f"((*reinterpret_cast*>(&T2[0]))[0]), "+f"((*reinterpret_cast*>(&T2[0]))[1]), "+f"((*reinterpret_cast*>(&T2[0]))[2]), "+f"((*reinterpret_cast*>(&T2[0]))[3]), "+f"((*reinterpret_cast*>(&T2[0]))[4]), "+f"((*reinterpret_cast*>(&T2[0]))[5]), "+f"((*reinterpret_cast*>(&T2[0]))[6]), "+f"((*reinterpret_cast*>(&T2[0]))[7]), "+f"((*reinterpret_cast*>(&T2[0]))[8]), "+f"((*reinterpret_cast*>(&T2[0]))[9]), "+f"((*reinterpret_cast*>(&T2[0]))[10]), "+f"((*reinterpret_cast*>(&T2[0]))[11]), "+f"((*reinterpret_cast*>(&T2[0]))[12]), "+f"((*reinterpret_cast*>(&T2[0]))[13]), "+f"((*reinterpret_cast*>(&T2[0]))[14]), "+f"((*reinterpret_cast*>(&T2[0]))[15]), "+f"((*reinterpret_cast*>(&T2[0]))[16]), "+f"((*reinterpret_cast*>(&T2[0]))[17]), "+f"((*reinterpret_cast*>(&T2[0]))[18]), "+f"((*reinterpret_cast*>(&T2[0]))[19]), "+f"((*reinterpret_cast*>(&T2[0]))[20]), "+f"((*reinterpret_cast*>(&T2[0]))[21]), "+f"((*reinterpret_cast*>(&T2[0]))[22]), "+f"((*reinterpret_cast*>(&T2[0]))[23]), "+f"((*reinterpret_cast*>(&T2[0]))[24]), "+f"((*reinterpret_cast*>(&T2[0]))[25]), "+f"((*reinterpret_cast*>(&T2[0]))[26]), "+f"((*reinterpret_cast*>(&T2[0]))[27]), "+f"((*reinterpret_cast*>(&T2[0]))[28]), "+f"((*reinterpret_cast*>(&T2[0]))[29]), "+f"((*reinterpret_cast*>(&T2[0]))[30]), "+f"((*reinterpret_cast*>(&T2[0]))[31]), "+f"((*reinterpret_cast*>(&T2[0]))[32]), "+f"((*reinterpret_cast*>(&T2[0]))[33]), "+f"((*reinterpret_cast*>(&T2[0]))[34]), "+f"((*reinterpret_cast*>(&T2[0]))[35]), "+f"((*reinterpret_cast*>(&T2[0]))[36]), "+f"((*reinterpret_cast*>(&T2[0]))[37]), "+f"((*reinterpret_cast*>(&T2[0]))[38]), "+f"((*reinterpret_cast*>(&T2[0]))[39]), "+f"((*reinterpret_cast*>(&T2[0]))[40]), "+f"((*reinterpret_cast*>(&T2[0]))[41]), "+f"((*reinterpret_cast*>(&T2[0]))[42]), "+f"((*reinterpret_cast*>(&T2[0]))[43]), "+f"((*reinterpret_cast*>(&T2[0]))[44]), "+f"((*reinterpret_cast*>(&T2[0]))[45]), "+f"((*reinterpret_cast*>(&T2[0]))[46]), "+f"((*reinterpret_cast*>(&T2[0]))[47]), "+f"((*reinterpret_cast*>(&T2[0]))[48]), "+f"((*reinterpret_cast*>(&T2[0]))[49]), "+f"((*reinterpret_cast*>(&T2[0]))[50]), "+f"((*reinterpret_cast*>(&T2[0]))[51]), "+f"((*reinterpret_cast*>(&T2[0]))[52]), "+f"((*reinterpret_cast*>(&T2[0]))[53]), "+f"((*reinterpret_cast*>(&T2[0]))[54]), "+f"((*reinterpret_cast*>(&T2[0]))[55]), "+f"((*reinterpret_cast*>(&T2[0]))[56]), "+f"((*reinterpret_cast*>(&T2[0]))[57]), "+f"((*reinterpret_cast*>(&T2[0]))[58]), "+f"((*reinterpret_cast*>(&T2[0]))[59]), "+f"((*reinterpret_cast*>(&T2[0]))[60]), "+f"((*reinterpret_cast*>(&T2[0]))[61]), "+f"((*reinterpret_cast*>(&T2[0]))[62]), "+f"((*reinterpret_cast*>(&T2[0]))[63]), "+f"((*reinterpret_cast*>(&T2[0]))[64]), "+f"((*reinterpret_cast*>(&T2[0]))[65]), "+f"((*reinterpret_cast*>(&T2[0]))[66]), "+f"((*reinterpret_cast*>(&T2[0]))[67]), "+f"((*reinterpret_cast*>(&T2[0]))[68]), "+f"((*reinterpret_cast*>(&T2[0]))[69]), "+f"((*reinterpret_cast*>(&T2[0]))[70]), "+f"((*reinterpret_cast*>(&T2[0]))[71]), "+f"((*reinterpret_cast*>(&T2[0]))[72]), "+f"((*reinterpret_cast*>(&T2[0]))[73]), "+f"((*reinterpret_cast*>(&T2[0]))[74]), "+f"((*reinterpret_cast*>(&T2[0]))[75]), "+f"((*reinterpret_cast*>(&T2[0]))[76]), "+f"((*reinterpret_cast*>(&T2[0]))[77]), "+f"((*reinterpret_cast*>(&T2[0]))[78]), "+f"((*reinterpret_cast*>(&T2[0]))[79]), "+f"((*reinterpret_cast*>(&T2[0]))[80]), "+f"((*reinterpret_cast*>(&T2[0]))[81]), "+f"((*reinterpret_cast*>(&T2[0]))[82]), "+f"((*reinterpret_cast*>(&T2[0]))[83]), "+f"((*reinterpret_cast*>(&T2[0]))[84]), "+f"((*reinterpret_cast*>(&T2[0]))[85]), "+f"((*reinterpret_cast*>(&T2[0]))[86]), "+f"((*reinterpret_cast*>(&T2[0]))[87]), "+f"((*reinterpret_cast*>(&T2[0]))[88]), "+f"((*reinterpret_cast*>(&T2[0]))[89]), "+f"((*reinterpret_cast*>(&T2[0]))[90]), "+f"((*reinterpret_cast*>(&T2[0]))[91]), "+f"((*reinterpret_cast*>(&T2[0]))[92]), "+f"((*reinterpret_cast*>(&T2[0]))[93]), "+f"((*reinterpret_cast*>(&T2[0]))[94]), "+f"((*reinterpret_cast*>(&T2[0]))[95]), "+f"((*reinterpret_cast*>(&T2[0]))[96]), "+f"((*reinterpret_cast*>(&T2[0]))[97]), "+f"((*reinterpret_cast*>(&T2[0]))[98]), "+f"((*reinterpret_cast*>(&T2[0]))[99]), "+f"((*reinterpret_cast*>(&T2[0]))[100]), "+f"((*reinterpret_cast*>(&T2[0]))[101]), "+f"((*reinterpret_cast*>(&T2[0]))[102]), "+f"((*reinterpret_cast*>(&T2[0]))[103]), "+f"((*reinterpret_cast*>(&T2[0]))[104]), "+f"((*reinterpret_cast*>(&T2[0]))[105]), "+f"((*reinterpret_cast*>(&T2[0]))[106]), "+f"((*reinterpret_cast*>(&T2[0]))[107]), "+f"((*reinterpret_cast*>(&T2[0]))[108]), "+f"((*reinterpret_cast*>(&T2[0]))[109]), "+f"((*reinterpret_cast*>(&T2[0]))[110]), "+f"((*reinterpret_cast*>(&T2[0]))[111]), "+f"((*reinterpret_cast*>(&T2[0]))[112]), "+f"((*reinterpret_cast*>(&T2[0]))[113]), "+f"((*reinterpret_cast*>(&T2[0]))[114]), "+f"((*reinterpret_cast*>(&T2[0]))[115]), "+f"((*reinterpret_cast*>(&T2[0]))[116]), "+f"((*reinterpret_cast*>(&T2[0]))[117]), "+f"((*reinterpret_cast*>(&T2[0]))[118]), "+f"((*reinterpret_cast*>(&T2[0]))[119]), "+f"((*reinterpret_cast*>(&T2[0]))[120]), "+f"((*reinterpret_cast*>(&T2[0]))[121]), "+f"((*reinterpret_cast*>(&T2[0]))[122]), "+f"((*reinterpret_cast*>(&T2[0]))[123]), "+f"((*reinterpret_cast*>(&T2[0]))[124]), "+f"((*reinterpret_cast*>(&T2[0]))[125]), "+f"((*reinterpret_cast*>(&T2[0]))[126]), "+f"((*reinterpret_cast*>(&T2[0]))[127]) :"l"((4611686293305294848ULL | ((262143ULL & (uint64_t)(i33)) >> 4ULL))), "l"((4611686293313683456ULL | ((262143ULL & (uint64_t)(i34)) >> 4ULL))), "n"((uint32_t)(true)), "n"(1), "n"(1), "n"(1), "n"(1) ); __syncthreads(); asm volatile("wgmma.commit_group.sync.aligned;\n"); asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory"); } #pragma unroll 3 for(nvfuser_index_t i35 = (i2 - 3); i35 < i2; ++i35) { nvfuser_index_t i36; i36 = i35 % 4; unsigned i37; i37 = i10 + (4096 * i36); unsigned i38; i38 = i6 + (8192 * i36); mbarrier::waitParity(toSmem((&T7[(i35 % 4)])), (((uint32_t)(i35) / 4U) % 2U)); asm volatile( "{\n" " .reg .pred p0; \n" " setp.ne.b32 p0, %130, 0;\n" " wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, %128, %129, p0, %131, %132, %133, %134;\n" "}\n" :"+f"((*reinterpret_cast*>(&T2[0]))[0]), "+f"((*reinterpret_cast*>(&T2[0]))[1]), "+f"((*reinterpret_cast*>(&T2[0]))[2]), "+f"((*reinterpret_cast*>(&T2[0]))[3]), "+f"((*reinterpret_cast*>(&T2[0]))[4]), "+f"((*reinterpret_cast*>(&T2[0]))[5]), "+f"((*reinterpret_cast*>(&T2[0]))[6]), "+f"((*reinterpret_cast*>(&T2[0]))[7]), "+f"((*reinterpret_cast*>(&T2[0]))[8]), "+f"((*reinterpret_cast*>(&T2[0]))[9]), "+f"((*reinterpret_cast*>(&T2[0]))[10]), "+f"((*reinterpret_cast*>(&T2[0]))[11]), "+f"((*reinterpret_cast*>(&T2[0]))[12]), "+f"((*reinterpret_cast*>(&T2[0]))[13]), "+f"((*reinterpret_cast*>(&T2[0]))[14]), "+f"((*reinterpret_cast*>(&T2[0]))[15]), "+f"((*reinterpret_cast*>(&T2[0]))[16]), "+f"((*reinterpret_cast*>(&T2[0]))[17]), "+f"((*reinterpret_cast*>(&T2[0]))[18]), "+f"((*reinterpret_cast*>(&T2[0]))[19]), "+f"((*reinterpret_cast*>(&T2[0]))[20]), "+f"((*reinterpret_cast*>(&T2[0]))[21]), "+f"((*reinterpret_cast*>(&T2[0]))[22]), "+f"((*reinterpret_cast*>(&T2[0]))[23]), "+f"((*reinterpret_cast*>(&T2[0]))[24]), "+f"((*reinterpret_cast*>(&T2[0]))[25]), "+f"((*reinterpret_cast*>(&T2[0]))[26]), "+f"((*reinterpret_cast*>(&T2[0]))[27]), "+f"((*reinterpret_cast*>(&T2[0]))[28]), "+f"((*reinterpret_cast*>(&T2[0]))[29]), "+f"((*reinterpret_cast*>(&T2[0]))[30]), "+f"((*reinterpret_cast*>(&T2[0]))[31]), "+f"((*reinterpret_cast*>(&T2[0]))[32]), "+f"((*reinterpret_cast*>(&T2[0]))[33]), "+f"((*reinterpret_cast*>(&T2[0]))[34]), "+f"((*reinterpret_cast*>(&T2[0]))[35]), "+f"((*reinterpret_cast*>(&T2[0]))[36]), "+f"((*reinterpret_cast*>(&T2[0]))[37]), "+f"((*reinterpret_cast*>(&T2[0]))[38]), "+f"((*reinterpret_cast*>(&T2[0]))[39]), "+f"((*reinterpret_cast*>(&T2[0]))[40]), "+f"((*reinterpret_cast*>(&T2[0]))[41]), "+f"((*reinterpret_cast*>(&T2[0]))[42]), "+f"((*reinterpret_cast*>(&T2[0]))[43]), "+f"((*reinterpret_cast*>(&T2[0]))[44]), "+f"((*reinterpret_cast*>(&T2[0]))[45]), "+f"((*reinterpret_cast*>(&T2[0]))[46]), "+f"((*reinterpret_cast*>(&T2[0]))[47]), "+f"((*reinterpret_cast*>(&T2[0]))[48]), "+f"((*reinterpret_cast*>(&T2[0]))[49]), "+f"((*reinterpret_cast*>(&T2[0]))[50]), "+f"((*reinterpret_cast*>(&T2[0]))[51]), "+f"((*reinterpret_cast*>(&T2[0]))[52]), "+f"((*reinterpret_cast*>(&T2[0]))[53]), "+f"((*reinterpret_cast*>(&T2[0]))[54]), "+f"((*reinterpret_cast*>(&T2[0]))[55]), "+f"((*reinterpret_cast*>(&T2[0]))[56]), "+f"((*reinterpret_cast*>(&T2[0]))[57]), "+f"((*reinterpret_cast*>(&T2[0]))[58]), "+f"((*reinterpret_cast*>(&T2[0]))[59]), "+f"((*reinterpret_cast*>(&T2[0]))[60]), "+f"((*reinterpret_cast*>(&T2[0]))[61]), "+f"((*reinterpret_cast*>(&T2[0]))[62]), "+f"((*reinterpret_cast*>(&T2[0]))[63]), "+f"((*reinterpret_cast*>(&T2[0]))[64]), "+f"((*reinterpret_cast*>(&T2[0]))[65]), "+f"((*reinterpret_cast*>(&T2[0]))[66]), "+f"((*reinterpret_cast*>(&T2[0]))[67]), "+f"((*reinterpret_cast*>(&T2[0]))[68]), "+f"((*reinterpret_cast*>(&T2[0]))[69]), "+f"((*reinterpret_cast*>(&T2[0]))[70]), "+f"((*reinterpret_cast*>(&T2[0]))[71]), "+f"((*reinterpret_cast*>(&T2[0]))[72]), "+f"((*reinterpret_cast*>(&T2[0]))[73]), "+f"((*reinterpret_cast*>(&T2[0]))[74]), "+f"((*reinterpret_cast*>(&T2[0]))[75]), "+f"((*reinterpret_cast*>(&T2[0]))[76]), "+f"((*reinterpret_cast*>(&T2[0]))[77]), "+f"((*reinterpret_cast*>(&T2[0]))[78]), "+f"((*reinterpret_cast*>(&T2[0]))[79]), "+f"((*reinterpret_cast*>(&T2[0]))[80]), "+f"((*reinterpret_cast*>(&T2[0]))[81]), "+f"((*reinterpret_cast*>(&T2[0]))[82]), "+f"((*reinterpret_cast*>(&T2[0]))[83]), "+f"((*reinterpret_cast*>(&T2[0]))[84]), "+f"((*reinterpret_cast*>(&T2[0]))[85]), "+f"((*reinterpret_cast*>(&T2[0]))[86]), "+f"((*reinterpret_cast*>(&T2[0]))[87]), "+f"((*reinterpret_cast*>(&T2[0]))[88]), "+f"((*reinterpret_cast*>(&T2[0]))[89]), "+f"((*reinterpret_cast*>(&T2[0]))[90]), "+f"((*reinterpret_cast*>(&T2[0]))[91]), "+f"((*reinterpret_cast*>(&T2[0]))[92]), "+f"((*reinterpret_cast*>(&T2[0]))[93]), "+f"((*reinterpret_cast*>(&T2[0]))[94]), "+f"((*reinterpret_cast*>(&T2[0]))[95]), "+f"((*reinterpret_cast*>(&T2[0]))[96]), "+f"((*reinterpret_cast*>(&T2[0]))[97]), "+f"((*reinterpret_cast*>(&T2[0]))[98]), "+f"((*reinterpret_cast*>(&T2[0]))[99]), "+f"((*reinterpret_cast*>(&T2[0]))[100]), "+f"((*reinterpret_cast*>(&T2[0]))[101]), "+f"((*reinterpret_cast*>(&T2[0]))[102]), "+f"((*reinterpret_cast*>(&T2[0]))[103]), "+f"((*reinterpret_cast*>(&T2[0]))[104]), "+f"((*reinterpret_cast*>(&T2[0]))[105]), "+f"((*reinterpret_cast*>(&T2[0]))[106]), "+f"((*reinterpret_cast*>(&T2[0]))[107]), "+f"((*reinterpret_cast*>(&T2[0]))[108]), "+f"((*reinterpret_cast*>(&T2[0]))[109]), "+f"((*reinterpret_cast*>(&T2[0]))[110]), "+f"((*reinterpret_cast*>(&T2[0]))[111]), "+f"((*reinterpret_cast*>(&T2[0]))[112]), "+f"((*reinterpret_cast*>(&T2[0]))[113]), "+f"((*reinterpret_cast*>(&T2[0]))[114]), "+f"((*reinterpret_cast*>(&T2[0]))[115]), "+f"((*reinterpret_cast*>(&T2[0]))[116]), "+f"((*reinterpret_cast*>(&T2[0]))[117]), "+f"((*reinterpret_cast*>(&T2[0]))[118]), "+f"((*reinterpret_cast*>(&T2[0]))[119]), "+f"((*reinterpret_cast*>(&T2[0]))[120]), "+f"((*reinterpret_cast*>(&T2[0]))[121]), "+f"((*reinterpret_cast*>(&T2[0]))[122]), "+f"((*reinterpret_cast*>(&T2[0]))[123]), "+f"((*reinterpret_cast*>(&T2[0]))[124]), "+f"((*reinterpret_cast*>(&T2[0]))[125]), "+f"((*reinterpret_cast*>(&T2[0]))[126]), "+f"((*reinterpret_cast*>(&T2[0]))[127]) :"l"((4611686293305294848ULL | ((262143ULL & (uint64_t)(i37)) >> 4ULL))), "l"((4611686293313683456ULL | ((262143ULL & (uint64_t)(i38)) >> 4ULL))), "n"((uint32_t)(true)), "n"(1), "n"(1), "n"(1), "n"(1) ); __syncthreads(); } #pragma unroll for(nvfuser_index_t i39 = 0; i39 < 4; ++i39) { if ((b17 && Hopper::electSync(4294967295U))) { mbarrier::inval(toSmem((&T7[i39]))); } } asm volatile("wgmma.commit_group.sync.aligned;\n"); asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory"); #pragma unroll for(nvfuser_index_t i40 = 0; i40 < 32; ++i40) { nvfuser_index_t i41; i41 = 4 * i40; nvfuser_index_t i42; i42 = 8 * i40; nvfuser_index_t i43; i43 = i15 + i42; bool b44; b44 = i18 < (-i42); #pragma unroll for(nvfuser_index_t i45 = 0; i45 < 2; ++i45) { nvfuser_index_t i46; i46 = i41 + (2 * i45); Array<__half, 2, 2> T6; #pragma unroll for(nvfuser_index_t i47 = 0; i47 < 2; ++i47) { T6[i47] = __float2half(T2[(i46 + i47)]); } if ((b44 && (i19 < (-(8 * i45))))) { loadLocalToGlobal<__half, /*vec_size=*/2, /*is_volatile=*/false>( &T3[(i43 + (i16 * i45))], &T6[0]); } } } } ```
--- csrc/device_lower/pass/scalar_hoist.cpp | 14 ++++++++++---- csrc/ir/base_nodes.cpp | 10 +--------- csrc/ir/utils.cpp | 15 +++++++++++++++ csrc/ir/utils.h | 4 ++++ 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/csrc/device_lower/pass/scalar_hoist.cpp b/csrc/device_lower/pass/scalar_hoist.cpp index 898e1f4d5ec..c2528e4b6fd 100644 --- a/csrc/device_lower/pass/scalar_hoist.cpp +++ b/csrc/device_lower/pass/scalar_hoist.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -119,6 +120,9 @@ ForLoop* getLoopAtPos(const std::vector& loops, int64_t position) { // Check if in the definition of from, there is a subexpression equivalent to // reference. If found, then return this subexpression. Val* findRefAsSubexprOf(Val* from, Val* reference, bool exact) { + if (!ir_utils::isFunctional(reference)) { + return nullptr; + } if (exact) { if (from == reference) { return from; @@ -267,7 +271,9 @@ std::pair CommonScalarMap::hoistScalarImpl( // `common_scalar_map_` only if it can be hoisted to outer loops. if (!has_tensor_index_dependency && (is_given || my_pos < parent_pos)) { common_scalar_map_[my_loop].emplace_back(value); - if (my_pos < parent_pos) { + // We never hoist non-functional values because each call returns a + // different value, therefore non-hoistable. + if (my_pos < parent_pos && ir_utils::isFunctional(value)) { hoisted_or_reused_.emplace(value); } } @@ -386,10 +392,10 @@ Val* CommonScalarMap::reuseScalarIfAlreadyComputed(Val* value, ForLoop* loop) { if (it != common_scalar_map_.end()) { auto& scalars = it->second; for (auto it = scalars.begin(); it != scalars.end(); it++) { - auto idx = *it; - auto common_subexpr = findRefAsSubexprOf(idx, value, false); + auto scalar = *it; + auto common_subexpr = findRefAsSubexprOf(scalar, value, false); if (common_subexpr != nullptr) { - if (common_subexpr != idx) { + if (common_subexpr != scalar) { // If the reuse is a subexpression instead of the complete // expression, we split this subexpression out and allocate it // separately. diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp index b1aa2a2513b..6e7e53e0f4d 100644 --- a/csrc/ir/base_nodes.cpp +++ b/csrc/ir/base_nodes.cpp @@ -190,17 +190,9 @@ std::string Val::toInlineString(int indent_size) const { } bool Val::isConstScalar() const { - if (!isScalar()) { + if (!isScalar() || !ir_utils::isFunctional(this)) { return false; } - // elect.sync ptx picks a leader thread from membermask. - // It cannot be evaluated at compile-time. - if (Expr* def = definition()) { - if (def->isA() && - def->as()->getUnaryOpType() == UnaryOpType::ElectSync) { - return false; - } - } return ir_utils::dependenciesSatisfied(this); } diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index a9b0dfb3d45..d52c7858923 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1189,6 +1189,21 @@ std::string nullOrToInlineString(const Statement* id) { return id ? id->toInlineString() : "nullptr"; } +bool isFunctional(const Val* v) { + auto def = v->definition(); + if (def == nullptr) { + return true; + } + if (auto uop = dynamic_cast(def)) { + // ElectSync is not functional, it does not return the same value + // every time it is called, so we do not want to reuse it. + if (uop->getUnaryOpType() == UnaryOpType::ElectSync) { + return false; + } + } + return std::all_of(def->inputs().begin(), def->inputs().end(), isFunctional); +} + } // namespace nvfuser::ir_utils namespace nvfuser::MmaOpUtils { diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index 8c674e211f0..74dcf5abb9d 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -726,4 +726,8 @@ std::string nullOrToString(const Statement* stmt); //! toInlineString() std::string nullOrToInlineString(const Statement* stmt); +//! Check if the given value is functional. A functional value is one that +//! always returns the same result when called with the same inputs. +bool isFunctional(const Val* v); + } // namespace nvfuser::ir_utils From fcd5c781dd52cd11ed02653db16c5204ae676383 Mon Sep 17 00:00:00 2001 From: Meghan Cowan Date: Mon, 28 Oct 2024 19:03:25 -0700 Subject: [PATCH 03/17] Sequence Parallel MLP (#3259) Sequence parallel MLP test Kernel dump: ``` void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor, at::detail::Array>(int, T2, T3) ncclDevKernel_AllGather_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>) [Linear0] ampere_bf16_s16816gemm_bf16_64x64_sliced1x2_ldg8_relu_f2f_stages_64x6_tn void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor, at::detail::Array>(int, T2, T3) [GeLU] ::nvfuser_pointwise_f0_c1_r0_g7(::Tensor<::__bfloat, (int)3, (int)3>, ::Tensor<::__bfloat, (int)3, (int)3>) [Linear1-matmul]ampere_bf16_s16816gemm_bf16_64x64_sliced1x2_ldg8_f2f_stages_64x6_tn void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor, at::detail::Array>(int, T2, T3) [Linear1-reduction]ncclDevKernel_ReduceScatter_Sum_bf16_RING_LL(ncclDevKernelArgsStorage<(unsigned long)4096>) void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor, at::detail::Array>(int, T2, T3) void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor, at::detail::Array>(int, T2, T3) [Linear1-bias add, dropout]::nvfuser_pointwise_f0_c1_r0_g8(::Tensor<::__bfloat, (int)1, (int)1>, ::Tensor<::__bfloat, (int)3, (int)3>, long long *, long long, long long *, long long, ::Tensor, ::Tensor) ``` --- tests/cpp/test_multidevice_transformer.cpp | 142 +++++++++++++++++++-- 1 file changed, 132 insertions(+), 10 deletions(-) diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index a789365477f..f9d5d96c5da 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -278,29 +278,64 @@ std::vector mlp( TensorView* b0, TensorView* w1, TensorView* b1, - const DeviceMesh& mesh) { + const DeviceMesh& mesh, + bool sequence_parallel = false, + TensorView* mask = nullptr) { const DataType dtype = w0->dtype(); - // // Linear 0 + + if (sequence_parallel) { + // Input arrives sharded and must be allgathered back + x->setDeviceMesh(mesh); + x->axis(0)->parallelize(ParallelType::DIDx); + x = set(x); // allgather + x->axis(0)->parallelize(ParallelType::Serial); + // Reshape back to 2D for linearOp + auto D = w0->axis(0)->extent()->value().as(); + x = reshape(x, {D, B * S / D, E}, {B * S, E}); + } + // Linear 0 TensorView* linear0 = linear(x, w0, b0); // GeLU TensorView* gelu = tanh_gelu(castOp(DataType::Float, linear0)); gelu = castOp(dtype, gelu); // Linear 1 TensorView* local_matmul1 = matmul(gelu, transpose(w1, 1, 2)); - TensorView* matmul1 = sum(local_matmul1, {0}); // Allreduce - TensorView* linear1 = add(matmul1, broadcast(b1, {true, false})); + if (sequence_parallel) { + // Remove after https://github.com/NVIDIA/Fuser/issues/2563 + // Reshape to explicitly pull the sharded axis into the logical domain + auto D = w0->axis(0)->extent()->value().as(); + local_matmul1 = reshape(local_matmul1, {D, B * S, E}, {D, D, B * S / D, E}); + } + TensorView* matmul1 = sum(local_matmul1, {0}); // Allreduce or Reduce scatter + TensorView* linear1 = add(matmul1, broadcast(b1, {true, true, false})); // Dropout Val* prob = IrBuilder::create(1.0 - kDropoutProb); Val* scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); - auto dropout_result = dropout(linear1, prob, scale).output; + if (mask == nullptr) { + auto rand_vals = rand_like(linear1); + mask = lt(rand_vals, prob); + } + auto apply_mask = mul(linear1, mask); + auto dropout_result = mul(apply_mask, scale); - // Manual sharding annotations - for (auto tv : {x, b1, linear1, dropout_result}) { + // Tensor parallel shardings + for (auto* tv : {w0, b0, w1, linear0, gelu}) { tv->setDeviceMesh(mesh); + tv->axis(0)->parallelize(ParallelType::DIDx); } - for (auto tv : {w0, b0, w1, linear0, gelu}) { + for (auto* tv : {x, b1, linear1, dropout_result}) { tv->setDeviceMesh(mesh); - tv->axis(0)->parallelize(ParallelType::DIDx); + } + + // Sequence parallel shardings + if (sequence_parallel) { + for (auto* tv : {linear1, dropout_result}) { + tv->axis(0)->parallelize(ParallelType::DIDx); + } + matmul1->setDeviceMesh(mesh); + matmul1->axis(1)->parallelize(ParallelType::DIDx); + mask->setDeviceMesh(mesh); + mask->axis(0)->parallelize(ParallelType::DIDx); } return {linear0, gelu, linear1, dropout_result}; @@ -669,6 +704,93 @@ TEST_P(DistributedTransformerTest, MLP_Layer) { validate(expected_outputs, outputs, {0.01, 0.01, 0.02, 0.02}); } +TEST_P(DistributedTransformerTest, Sequence_Parallel_MLP_Layer) { + // TODO: Reshapes that form device axes when D=1 get optimized away causing + // failures. This won't be a problem after + // https://github.com/NVIDIA/Fuser/issues/2563. + if (D == 1) { + GTEST_SKIP() << "Requires >1 devices, D=" << D; + } + if ((4 * E) % D != 0) { + GTEST_SKIP() << "Requires number of devices=" << D + << " evenly divide 4*E=" << 4 * E; + } + if ((B * S) % D != 0) { + GTEST_SKIP() << "Requires number of devices=" << D + << " evenly divide B*S=" << B * S; + } + DataType dtype = GetParam(); + at::ScalarType at_dtype = data_type_to_aten(dtype); + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + const auto mesh = DeviceMesh::createForNumDevices(D); + + TensorView* x = makeContigConcreteTensor({D, B * S / D, E}, dtype); + TensorView* w0 = makeContigConcreteTensor({D, 4 * E / D, E}, dtype); + TensorView* b0 = makeContigConcreteTensor({D, 4 * E / D}, dtype); + TensorView* w1 = makeContigConcreteTensor({D, E, 4 * E / D}, dtype); + TensorView* b1 = makeContigConcreteTensor({E}, dtype); + TensorView* mask = + makeContigConcreteTensor({D, B * S / D, E}, DataType::Bool); + + // Input x is sharded on B*S dimension. + // Note only the sequence (S) dimension that is sharded + // but to avoid DID parallelizations of inner logical axes + // B*S is sharded. + auto tvsout = mlp(x, w0, b0, w1, b1, mesh, true, mask); + + fusion->addInput(x); + fusion->addInput(w0); + fusion->addInput(b0); + fusion->addInput(w1); + fusion->addInput(b1); + fusion->addInput(mask); + + for (auto* tv : tvsout) { + fusion->addOutput(tv); + } + + // Ensure broadcasts of bias are sharded. + shardBetween({b1}, {tvsout[2]}, tvsout[2]); + // Needed to ensure that rand_like is sharded initially. + // sharding from linear1 to dropout like dropout + shardBetween({tvsout[2]}, {tvsout[3]}, tvsout[3]); + + auto options = + at::TensorOptions().dtype(at_dtype).device(communicator_->device()); + auto x_ = at::randn({B * S, E}, options); + auto w0_ = at::randn({4 * E, E}, options) * kParamScale; + auto b0_ = at::randn({4 * E}, options) * kParamScale; + auto w1_ = at::randn({E, 4 * E}, options) * kParamScale; + auto b1_ = at::randn({E}, options) * kParamScale; + + // Dropout is sharded among devices. + // For validation against ATen the sharded reference dropout mask is an input + // to the Fusion, but in regular setting it would be generated. + std::vector reference_outs = + reference_mlp(x_, w0_, b0_, w1_, b1_); + auto mask_ = reference_outs[4]; + + std::vector inputs = { + shardTensor(x_, 0, mesh), + shardTensor(w0_, 0, mesh), + shardTensor(b0_, 0, mesh), + shardTensor(w1_, 1, mesh), + b1_, + shardTensor(mask_, 0, mesh)}; + + std::vector expected_outputs = { + shardTensor(reference_outs[0], 1, mesh), + shardTensor(reference_outs[1], 1, mesh), + shardTensor(reference_outs[2], 0, mesh), + shardTensor(reference_outs[3], 0, mesh)}; + + FusionExecutorCache fec(std::move(fusion)); + at::manual_seed(getATenRandomSeed()); + auto outputs = fec.runFusionWithInputs(inputs); + validate(expected_outputs, outputs, {0.01, 0.01, 0.02, 0.02}); +} + TEST_P(DistributedTransformerTest, MultiheadAttention) { if (H % D != 0) { GTEST_SKIP() << "Requires number of devices=" << D @@ -903,7 +1025,7 @@ TEST_P(DistributedTransformerTest, MHA_Backward) { at::manual_seed(getATenRandomSeed()); auto out = fec.runFusionWithInputs(inputs); validate( - expected_outputs, out, {1e-5, 0.02, 1e-5, .01, .01, 0.1, 0.1, 0.1, 0.01}); + expected_outputs, out, {1e-5, 0.02, 1e-5, .01, .02, 0.2, 0.2, 0.2, 0.02}); } TEST_P(DistributedTransformerTest, Forward) { From b2d8b318f274ce5ed9e1bb6be224f593fa42744d Mon Sep 17 00:00:00 2001 From: Meghan Cowan Date: Mon, 28 Oct 2024 19:30:33 -0700 Subject: [PATCH 04/17] Only support mul-sum distributed matmul test for ampere and hopper (#3296) --- tests/cpp/test_multidevice_matmul.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp index 222fc79d4c3..3032db30b94 100644 --- a/tests/cpp/test_multidevice_matmul.cpp +++ b/tests/cpp/test_multidevice_matmul.cpp @@ -62,6 +62,7 @@ TEST_F(DistributedMatmulTest, MulSum_LayoutTN_NoComms) { // MmaLayout::TN A(T), B(N), C(T) // A and C are sharded on dimension M // Tests local matmul with no communication + NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 10, 0); auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto mesh = DeviceMesh::createForNumDevices(num_devices_); From f6975f37eab197052e7ee59bf2bc8c78c1491dbf Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Mon, 28 Oct 2024 22:22:05 -0700 Subject: [PATCH 05/17] Adaptive layernorm host benchmark (#3229) This benchmark adds host benchmarking for the `adaptive layernorm forward` fusion. Screenshot 2024-10-21 at 2 23 47 PM --------- Co-authored-by: Kevin Stephano --- .../python/test_adaptive_layernorm_host.py | 140 ++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 benchmarks/python/test_adaptive_layernorm_host.py diff --git a/benchmarks/python/test_adaptive_layernorm_host.py b/benchmarks/python/test_adaptive_layernorm_host.py new file mode 100644 index 00000000000..7e3c67b6d8b --- /dev/null +++ b/benchmarks/python/test_adaptive_layernorm_host.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +import pytest +from nvfuser import FusionDefinition, DataType +from nvfuser.pytorch_utils import clear_cuda_cache +from .core import run_benchmark +import torch + + +def adaptive_layernorm_fwd_fusion(fd: FusionDefinition, eps: float = 1e-6) -> None: + T0 = fd.define_tensor( + shape=[-1, -1, -1], + contiguity=[True, True, True], + dtype=DataType.Half, + is_cpu=False, + stride_order=[2, 1, 0], + ) + T1 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.Half, + is_cpu=False, + stride_order=[1, 0], + ) + T2 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.Half, + is_cpu=False, + stride_order=[1, 0], + ) + T3 = fd.ops.cast(T0, dtype=DataType.Float) + T4, T5 = fd.ops.var_mean(T3, dims=[2], correction=0, keepdim=False) + T10 = fd.ops.broadcast_in_dim( + T4, shape=[T0.size(0), T0.size(1), 1], broadcast_dims=[0, 1] + ) + T15 = fd.ops.broadcast_in_dim( + T5, shape=[T0.size(0), T0.size(1), 1], broadcast_dims=[0, 1] + ) + S16 = fd.define_scalar(eps, dtype=DataType.Double) + T17 = fd.ops.add(T10, S16) + T22 = fd.ops.broadcast_in_dim(T15, shape=T0.shape(), broadcast_dims=[0, 1, 2]) + T23 = fd.ops.rsqrt(T17) + T24 = fd.ops.sub(T3, T22) + T29 = fd.ops.broadcast_in_dim(T23, shape=T0.shape(), broadcast_dims=[0, 1, 2]) + T30 = fd.ops.mul(T24, T29) + T35 = fd.ops.reshape(T1, new_shape=[T1.size(0), 1, T1.size(1)]) + T36 = fd.ops.cast(T35, dtype=DataType.Float) + S37 = fd.define_scalar(1.00000, dtype=DataType.Double) + T38 = fd.ops.add(S37, T36) + T39 = fd.ops.cast(T38, dtype=DataType.Half) + T44 = fd.ops.broadcast_in_dim(T39, shape=T0.shape(), broadcast_dims=[0, 1, 2]) + T45 = fd.ops.cast(T44, dtype=DataType.Float) + T46 = fd.ops.mul(T30, T45) + T51 = fd.ops.reshape(T2, new_shape=[T2.size(0), 1, T2.size(1)]) + T56 = fd.ops.broadcast_in_dim(T51, shape=T0.shape(), broadcast_dims=[0, 1, 2]) + T57 = fd.ops.cast(T56, dtype=DataType.Float) + T58 = fd.ops.add(T46, T57) + T59 = fd.ops.cast(T58, dtype=DataType.Half) + fd.add_output(T5) + fd.add_output(T23) + fd.add_output(T59) + + +# This benchmark is to particularly track nvFuser host overhead for shape +# change (dynamic shape support) in the adapative layernorm case. Running a +# new shape on this fusion without recompiling a new kernel can have significant overhead. +@pytest.mark.parametrize("host_bench_mode", ["compile", "steady", "dynamic"]) +def test_adaptive_layernorm_fwd_benchmark( + benchmark, + host_bench_mode: str, + disable_validation: bool, + disable_benchmarking: bool, +): + clear_cuda_cache() + + B = 1 + T = 30 * 1024 + D = 1024 + inputs = [ + torch.randn(B, T, D, device="cuda", dtype=torch.float16, requires_grad=True), + torch.randn(B, D, device="cuda", dtype=torch.float16, requires_grad=True), + torch.randn(B, D, device="cuda", dtype=torch.float16, requires_grad=True), + ] + + # Generate multiple inputs to measure dynamic shape overhead. + if host_bench_mode == "dynamic": + inputs = [] + for B in range(1, 3, 1): + for T in range(30 * 1024, 30 * 1024 + 5 * 128, 128): + inputs.append( + [ + torch.randn( + B, + T, + D, + device="cuda", + dtype=torch.float16, + requires_grad=True, + ), + torch.randn( + B, D, device="cuda", dtype=torch.float16, requires_grad=True + ), + torch.randn( + B, D, device="cuda", dtype=torch.float16, requires_grad=True + ), + ] + ) + + with FusionDefinition() as fd: + adaptive_layernorm_fwd_fusion(fd) + + def validate(input): + eps = 1e-6 + in_tensor, scale, shift = input + norm_state = torch.nn.LayerNorm(D, elementwise_affine=False, eps=eps) + norm_out = norm_state(in_tensor) + mean = in_tensor.to(torch.float).mean(dim=-1) + variance = in_tensor.to(torch.float).var(dim=-1, unbiased=False) + invstd = (1.0 / torch.sqrt(variance + eps)).unsqueeze(-1) + eager_output = norm_out * (1 + scale.view(-1, 1, D)) + shift.view(-1, 1, D) + fd.validate(input, [mean, invstd, eager_output]) + + if not disable_validation: + if host_bench_mode == "dynamic": + # Run validate for all input sizes. + for input in inputs: + validate(input) + else: + validate(inputs) + + if not disable_benchmarking: + run_benchmark( + benchmark, + None, + inputs, + device=f"host:{host_bench_mode}", + fusion_fn=adaptive_layernorm_fwd_fusion, + ) From 81d166789baffa1bd0c476fe1b7af30ee5a65253 Mon Sep 17 00:00:00 2001 From: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com> Date: Tue, 29 Oct 2024 08:22:13 -0400 Subject: [PATCH 06/17] check vectorization factor of shared memory consumers to avoid illegal vectorization size (#3271) **Issue** InnerOuter persistent scheduler uses shared memory to store persistent buffers, the data flow is `input in gmem ---> async copy to smem --> vectorized load to registers (smem consumers)`, the `-->` are simply `LoadStoreOp` and same vectorization factors of these two copies are used. [CI](https://nv/e2E/118278383) found a case where the shared memory persistent buffers have a data type of fp32 while the inputs are fp16 (when there are view ops, project to inputs is not used). The vectorization factor is set to 8 and caused 32 bytes vectorization when loading from shared memory to registers. **Changes**: (1) Added code to handle the vectorization of smem consumers. Add an additional split if `smem --> regs` copy leads to vectorization larger than 16 bytes. (2) Added a test **Results**: Ensure vectorizations are <= 16 bytes. **Following works** See issue https://github.com/NVIDIA/Fuser/issues/3272 --------- Co-authored-by: Naoya Maruyama --- csrc/scheduler/normalization_inner_outer.cpp | 16 ++-- csrc/scheduler/normalization_utils.cpp | 1 + csrc/scheduler/reduction.cpp | 1 + csrc/scheduler/reduction_utils.cpp | 77 +++++++++++++++---- csrc/scheduler/reduction_utils.h | 24 +++++- .../test_combined_inner_outer_reduction.cpp | 76 +++++++++++++++++- 6 files changed, 163 insertions(+), 32 deletions(-) diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index 625bd95985e..6dd34f4cab9 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -967,11 +967,7 @@ void scheduleInnerOuterPersistentKernel( scheduler_utils::getAllTvsFrom(inner_reduction_tvs, boundaryNodesSet); const auto& unroll_vectorizable_cached_tvs = reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( - inner_reference_tv, - is_vectorize, - cached_inputs, - cached_outputs, - smem_consumers); + inner_reference_tv, is_vectorize, cached_inputs, cached_outputs); reduction_scheduler_utils::propagateParallelization( inner_reduction_tvs[0], inner_reference_tv, @@ -998,8 +994,7 @@ void scheduleInnerOuterPersistentKernel( outer_reference_tvs[i], is_vectorize, cached_inputs, - cached_outputs, - smem_consumers); + cached_outputs); reduction_scheduler_utils::propagateParallelization( outer_reduction_tvs[i], outer_reference_tvs[i], @@ -1044,6 +1039,13 @@ void scheduleInnerOuterPersistentKernel( } } + // Needs special handling of vectorized loading from shared memory due to + // potential different data types of inputs and shared memory tensor. + if (is_vectorize) { + reduction_scheduler_utils::sharedMemoryConsumerVectorization( + smem_consumers, rparams->unroll_factor_inner_reduction); + } + // Remove dummy outputs as they can inadvertently affect CA positions for (auto output : dummy_outputs) { fusion->removeOutput(output); diff --git a/csrc/scheduler/normalization_utils.cpp b/csrc/scheduler/normalization_utils.cpp index 7bf100adca3..2601fcc469c 100644 --- a/csrc/scheduler/normalization_utils.cpp +++ b/csrc/scheduler/normalization_utils.cpp @@ -1420,6 +1420,7 @@ void schedulePersistentKernel( unroll, vectorize, is_outer_grid_persistence, + rparams->unroll_factor_inner_reduction, reduction_tvs, cached_inputs, cached_outputs, diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index 8cca314f9ae..87f9d2bffad 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -1259,6 +1259,7 @@ void scheduleReduction(Fusion* fusion, const ReductionParams* rparams) { unroll, vectorize, use_iter_grouped_reduction, + rparams->unroll_factor_inner_reduction, reduction_tvs, cached_inputs, cached_outputs); diff --git a/csrc/scheduler/reduction_utils.cpp b/csrc/scheduler/reduction_utils.cpp index f0d03e02ed8..34db3133da7 100644 --- a/csrc/scheduler/reduction_utils.cpp +++ b/csrc/scheduler/reduction_utils.cpp @@ -5,14 +5,14 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include - #include #include #include #include #include +#include #include +#include #include #include #include @@ -348,6 +348,7 @@ void multiReductionInliner( const bool is_unroll_or_vectorization, const bool vectorize, const bool use_grouped_reduction, + const int64_t vectorizatoin_factor, std::vector reduction_tvs, std::vector cached_inputs, std::vector> cached_outputs, @@ -361,7 +362,7 @@ void multiReductionInliner( } const auto& unroll_vectorizable_cached_tvs = getCachedTvsToUnrollOrVectorize( - reference_tv, vectorize, cached_inputs, cached_outputs, smem_consumers); + reference_tv, vectorize, cached_inputs, cached_outputs); reduction_scheduler_utils::propagateParallelization( reduction_tv, reference_tv, @@ -370,6 +371,13 @@ void multiReductionInliner( reduction_tvs, unroll_vectorizable_cached_tvs); + // Needs special handling of vectorized loading from shared memory due to + // potential different data types of inputs and shared memory tensor. + if (vectorize) { + reduction_scheduler_utils::sharedMemoryConsumerVectorization( + smem_consumers, vectorizatoin_factor); + } + // Remove dummy outputs as they can inadvertently affect CA positions for (auto output : dummy_outputs) { fusion->removeOutput(output); @@ -428,8 +436,7 @@ std::unordered_set getCachedTvsToUnrollOrVectorize( TensorView* reference_tv, bool vectorize, const std::vector& cached_inputs, - const std::vector>& cached_outputs, - const std::vector& smem_consumers) { + const std::vector>& cached_outputs) { auto reduced_tv = ir_utils::getSoleProducerTv(reference_tv); // Grab all tensor views that should be vectorized auto vectorizable_inputs_outputs = @@ -469,18 +476,6 @@ std::unordered_set getCachedTvsToUnrollOrVectorize( } } - if (vectorize) { - for (auto tv : smem_consumers) { - // smem_consumers were added in schedule process - // movePersistentBufferToSmem() using cacheAfter() - NVF_ERROR( - vectorizable_expr(tv->definition()), - "Expected a vectorizable expression, but got: ", - tv->definition()->toString()); - unroll_vectorizable_tvs.emplace(tv); - } - } - return unroll_vectorizable_tvs; } @@ -1009,5 +1004,53 @@ std::ostream& operator<<(std::ostream& os, ReductionType reduction_type) { return os; } +void sharedMemoryConsumerVectorization( + std::vector& smem_consumers, + int64_t io_vectorization_factor) { + for (auto tv : smem_consumers) { + // they were creatd with cacheAfter. + NVF_ERROR( + tv->definition()->isA(), + "smem consumers should be LoadStoreOp. Got: ", + tv->definition()->toString()); + + // non-concretized broadcast domains are moved to the innermost before + // transform propagation, should skip these axes. + int64_t vect_axis_pos = -1; + while (tv->axis(vect_axis_pos)->isBroadcast()) { + vect_axis_pos--; + NVF_ERROR( + vect_axis_pos + tv->nDims() >= 0, + "Out of bound access when visiting dim ", + vect_axis_pos, + " in Tv: ", + tv->toString()); + } + // they were transformed with innermost axis has extent equal to + // vectorization factor set for io tvs. + NVF_ERROR( + tv->axis(vect_axis_pos)->extent()->isConst(), + "Extent of the innermost axis of smem consumers should be constant. Got: ", + tv->toString()); + auto innermost_extent = + tv->axis(vect_axis_pos)->extent()->evaluate().as(); + NVF_ERROR( + innermost_extent == io_vectorization_factor, + "Extent of the innermost axis of smem consumers should be equal to the vectorization factor of fuion inputs and outputs. Got: ", + innermost_extent, + ", expected: ", + io_vectorization_factor); + auto dtype_bytes = dataTypeSize(tv->getDataType().value()); + auto max_vect_factor = + SchedulerRuntimeInfo::max_alignment_size_in_byte / dtype_bytes; + // additional split is added if the innermost extent is greater than max + // vectorization factor. + if (innermost_extent > max_vect_factor) { + tv->split(vect_axis_pos, max_vect_factor); + } + tv->axis(vect_axis_pos)->parallelize(ParallelType::Vectorize); + } +} + } // namespace reduction_scheduler_utils } // namespace nvfuser diff --git a/csrc/scheduler/reduction_utils.h b/csrc/scheduler/reduction_utils.h index 78096210afb..713c399c03b 100644 --- a/csrc/scheduler/reduction_utils.h +++ b/csrc/scheduler/reduction_utils.h @@ -35,6 +35,7 @@ void multiReductionInliner( const bool unroll, const bool vectorize, const bool use_grouped_reduction, + const int64_t vectorizatoin_factor, std::vector reduction_tvs, std::vector cached_inputs, std::vector> cached_outputs, @@ -65,14 +66,11 @@ void propagateRFactor( // is_vectorize: Indicates if vectorization is applied in the scheduler. // cached_inputs: Inputs cached in registers or shared memory. // cached_outputs: Outputs cached in registers. -// smem_consumers: Consumers of shared memory persistent buffers, they are -// register cached Tvs after the shared memory tv. NVF_API std::unordered_set getCachedTvsToUnrollOrVectorize( TensorView* reference_tv, bool is_vectorize, const std::vector& cached_inputs, - const std::vector>& cached_outputs, - const std::vector& smem_consumers); + const std::vector>& cached_outputs); // Propagate parallelization from the reference TensorView to other TensorViews. // Unroll, Vectorize, and MisalignedVectorize types are explicitly handled for @@ -139,5 +137,23 @@ std::string toString(ReductionType reduction_type); ReductionType getReductionType(Fusion* fusion); ReductionType getReductionType(const std::vector& reduction_tvs); +/** + * @brief Vectorize shared memory consumers + * + * Applies vectorization to shared memory consumers. + * If extent of the last dim multiples vectorization factor exceeds hardware + * limitations, additional split is added. + * + * @param smem_consumers Vector of TensorView pointers representing shared + * memory consumers + * @param io_vectorization_factor Vectorization factor set for fusion inputs and + * outputs + * @note TODO: Optimize writing to shared memory and address bank conflicts for + * float32 with innermost extent of 8 + */ +void sharedMemoryConsumerVectorization( + std::vector& smem_consumers, + const int64_t io_vectorization_factor); + } // namespace reduction_scheduler_utils } // namespace nvfuser diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp index c3ff5928742..2071aeb0e86 100644 --- a/tests/cpp/test_combined_inner_outer_reduction.cpp +++ b/tests/cpp/test_combined_inner_outer_reduction.cpp @@ -610,7 +610,7 @@ TEST_F(CombinedSchedulerTest, CombinedReduction) { false, inner_reduction_tvs, reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( - reference_tv_inner, true, cached_inputs, cached_outputs, {})); + reference_tv_inner, true, cached_inputs, cached_outputs)); reduction_scheduler_utils::propagateParallelization( outer_reduction_tv, reference_tv_outer, @@ -618,7 +618,7 @@ TEST_F(CombinedSchedulerTest, CombinedReduction) { false, outer_reduction_tvs, reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( - reference_tv_outer, true, cached_inputs, cached_outputs, {})); + reference_tv_outer, true, cached_inputs, cached_outputs)); inlineMost(); LaunchParams launch_constraints; @@ -773,7 +773,7 @@ TEST_F(CombinedSchedulerTest, CombinedReductionMultiPerBlock) { false, inner_reduction_tvs, reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( - reference_tv_inner, true, cached_inputs, cached_outputs, {}), + reference_tv_inner, true, cached_inputs, cached_outputs), {selected_tvs_inner.begin(), selected_tvs_inner.end()}); const auto& selected_tvs_outer = @@ -787,7 +787,7 @@ TEST_F(CombinedSchedulerTest, CombinedReductionMultiPerBlock) { false, outer_reduction_tvs, reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( - reference_tv_outer, true, cached_inputs, cached_outputs, {}), + reference_tv_outer, true, cached_inputs, cached_outputs), {selected_tvs_outer.begin(), selected_tvs_outer.end()}); std::vector cached_gmem_temp{partialResult}; @@ -926,4 +926,72 @@ TEST_F(CombinedSchedulerTest, InnerOuterNoOuterBroadcastTv) { "", persistent_params->lparams); } + +// Reproduce error found in: +// thunder/tests/test_torch_compile_executor.py::test_torch_compile_cat_nvfuser_phi2_tanh +// Only happens when shared memory persistent is used. +TEST_F(CombinedSchedulerTest, SharedMemoryPersistentVectFactor) { + Fusion fusion; + FusionGuard fg(&fusion); + // When the input is float16, the vectorization factor is set to 8. + // If the persistent buffer tv1 is stored in shared memory and is not + // projected to inputs, the scheduler adds a cacheAfter to load tv1 from + // shared memory to registers in a vectorized manner, avoiding bank conflicts. + // However, since tv1 is float32, we can't directly use the vectorization + // factor set for float16 inputs because the maximum allowed vectorization + // width is 16 bytes. + const int dim0 = 1024; + const int dim1 = 4096; + auto dtype = DataType::Half; + auto tv0 = makeContigTensor(2, dtype); + fusion.addInput(tv0); + auto tv1 = castOp(DataType::Float, tv0); + auto tv2 = sum(tv1, {1}); + auto tv3 = broadcast(tv2, {false, true}); + auto tv4 = add(tv3, tv1); + auto tv5 = sum(tv1, {0}); + auto tv6 = castOp(DataType::Half, tv4); + auto tv7 = castOp(DataType::Half, tv5); + fusion.addOutput(tv6); + fusion.addOutput(tv7); + + Fusion fusion_copy = fusion; + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({dim0, dim1}, options); + std::vector aten_inputs = {t0}; + + SchedulerRuntimeInfo runtime_info(&fusion, aten_inputs); + ASSERT_TRUE(Schedule::canSchedule( + SchedulerType::InnerOuterPersistent, &fusion, runtime_info)); + auto scheduler = SchedulerEntry::makeSchedulerInstance( + SchedulerType::InnerOuterPersistent); + auto heuristic_params = scheduler->computeHeuristics(&fusion, runtime_info); + + // disable projection to inputs, so shared memory buffer is using float32 + heuristic_params->as()->project_persistent_buffers = false; + // Set vectorization factor to 8, so the exent of the innermost dimension + // exceed 16 bytes (8 x 4 = 32 bytes). + heuristic_params->as()->unroll_factor_inner_reduction = 8; + // when compute heuristics, the buffer is projected to inputs and the shared + // memory persistent buffer is the input, tv0. Then, we modified the + // heuristics to disable project to inputs, so needs to update the buffer + // being stored in shared memory to the original unprojected buffer, tv1. + heuristic_params->as()->smem_persistent_buffers = + std::vector{tv1}; + scheduler->schedule(&fusion, heuristic_params.get()); + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + + for (auto tv : fusion.allTvs()) { + if (tv->getMemoryType() == MemoryType::Shared) { + for (auto consumer : ir_utils::consumerTvsOf(tv)) { + EXPECT_TRUE(isVectorized(consumer)); + } + } + } + auto cg_outputs = fe.runFusion( + aten_inputs, heuristic_params->as()->lparams); + testValidate(&fusion_copy, cg_outputs, aten_inputs, __LINE__, __FILE__); +} } // namespace nvfuser From 5db18de5d8dafc6a8f1e76909d73799753fe7a5e Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Tue, 29 Oct 2024 09:26:44 -0400 Subject: [PATCH 07/17] Introduce @retry_on_oom_or_skip_test (#3252) Inspired by https://github.com/NVIDIA/Fuser/pull/3174 This is an alternative to #3238. Previously we were manually resetting the cuda cache whenever the usage was above 80%. This is not ideal since we could have 79% usage and a test that requires 25% and that would fail. We also might clear the cache unnecessarily sometimes: e.g. we are using 81% but only need a few percent for the remainder of tests. This PR cleans this up by introducing a new test decorator `@retry_on_oom_or_skip_test`. This decorator must be placed innermost, underneath the other decorators. It will execute the test inside a try block. If the test fails due to `torch.OutOfMemoryError`, we clear the cuda cache and retry the test. If it fails again due to `torch.OutOfMemoryError`, then we skip the test. I updated the python benchmarks to apply this decorator automatically, and to remove the manual `clear_cuda_cache()` calls. --- benchmarks/python/conftest.py | 5 ++ benchmarks/python/normalization.py | 8 +-- benchmarks/python/test_broadcast_add_fwd.py | 5 +- .../python/test_dropout_layernorm_bwd.py | 5 +- .../python/test_dropout_layernorm_fwd.py | 4 +- benchmarks/python/test_dropout_rmsnorm_bwd.py | 5 +- benchmarks/python/test_dropout_rmsnorm_fwd.py | 5 +- benchmarks/python/test_gelu_bwd.py | 5 +- benchmarks/python/test_gelu_bwd_reduction.py | 5 +- benchmarks/python/test_gelu_fwd.py | 5 +- benchmarks/python/test_groupnorm_fwd.py | 6 +- .../python/test_huggingface_attn_bwd.py | 5 +- .../python/test_huggingface_attn_fwd.py | 5 +- benchmarks/python/test_layernorm_bwd.py | 5 +- benchmarks/python/test_layernorm_fwd.py | 5 +- benchmarks/python/test_many_pointwise_ops.py | 4 +- benchmarks/python/test_matmul.py | 63 ++++++++----------- benchmarks/python/test_nanogpt_attn_bwd.py | 4 +- benchmarks/python/test_nanogpt_attn_fwd.py | 4 +- benchmarks/python/test_pointwise_mul.py | 5 +- benchmarks/python/test_reduction.py | 5 +- benchmarks/python/test_reduction_epilogue.py | 4 +- benchmarks/python/test_rmsnorm_bwd.py | 4 +- benchmarks/python/test_rmsnorm_fwd.py | 5 +- benchmarks/python/test_rope.py | 3 - benchmarks/python/test_scale_bias_relu_bwd.py | 5 +- benchmarks/python/test_scale_bias_relu_fwd.py | 5 +- benchmarks/python/test_silu_mul_bwd.py | 4 +- benchmarks/python/test_silu_mul_fwd.py | 4 +- benchmarks/python/test_softmax_bwd.py | 5 +- benchmarks/python/test_softmax_fwd.py | 5 +- benchmarks/python/test_transpose.py | 5 +- nvfuser/pytorch_utils.py | 42 +++++++++---- tests/python/test_ops.py | 8 +-- version.txt | 2 +- 35 files changed, 95 insertions(+), 169 deletions(-) diff --git a/benchmarks/python/conftest.py b/benchmarks/python/conftest.py index 31596fef4cf..8932afbff30 100644 --- a/benchmarks/python/conftest.py +++ b/benchmarks/python/conftest.py @@ -104,6 +104,11 @@ def pytest_collection_modifyitems(session, config, items): run_thunder = config.getoption("--benchmark-thunder") run_torchcompile = config.getoption("--benchmark-torchcompile") + from nvfuser.pytorch_utils import retry_on_oom_or_skip_test + + for item in items: + item.obj = retry_on_oom_or_skip_test(item.obj) + if not run_eager: skip_eager = pytest.mark.skip(reason="need --benchmark-eager option to run") for item in items: diff --git a/benchmarks/python/normalization.py b/benchmarks/python/normalization.py index 5137fe5c8ab..8ec19ebf71d 100644 --- a/benchmarks/python/normalization.py +++ b/benchmarks/python/normalization.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause from nvfuser import FusionDefinition, DataType from .global_params import PROMOTE_DTYPES -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype import torch from .core import run_benchmark, unary_bwd_torch, clear_dynamo_cache import numpy as np @@ -206,8 +206,6 @@ def norm_fwd_nvf_benchmark( Common benchmark setup for batchnorm/instance forward call in training mode. """ - clear_cuda_cache() - assert norm in ["batch_norm", "instance_norm"], NotImplementedError # Size is assumed to be in the order N, C, ... @@ -293,8 +291,6 @@ def norm_bwd_nvf_benchmark( Common benchmark setup for batchnorm/instance forward call in training mode. """ - clear_cuda_cache() - assert norm in ["batch_norm", "instance_norm"], NotImplementedError # Size is assumed to be in the order N, C, ... @@ -440,7 +436,6 @@ def norm_fwd_baseline_benchmark( compile: bool, norm: str, ): - clear_cuda_cache() if compile: clear_dynamo_cache() @@ -475,7 +470,6 @@ def norm_bwd_baseline_benchmark( compile: bool, norm: str, ): - clear_cuda_cache() if compile: clear_dynamo_cache() diff --git a/benchmarks/python/test_broadcast_add_fwd.py b/benchmarks/python/test_broadcast_add_fwd.py index 6a6debe1f58..abb320ef2a3 100644 --- a/benchmarks/python/test_broadcast_add_fwd.py +++ b/benchmarks/python/test_broadcast_add_fwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -65,8 +65,6 @@ def test_bcast_add_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - bias = torch.randn(size[1 - bcast_axis], dtype=dtype, device="cuda") input_shape = size if contiguous else (size[1], size[0]) @@ -105,7 +103,6 @@ def test_bcast_add_baseline_benchmark( contiguous: bool, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() bias = torch.randn(size[1 - bcast_axis], dtype=dtype, device="cuda") diff --git a/benchmarks/python/test_dropout_layernorm_bwd.py b/benchmarks/python/test_dropout_layernorm_bwd.py index a9d5dc24cd9..dcff2abb5ba 100644 --- a/benchmarks/python/test_dropout_layernorm_bwd.py +++ b/benchmarks/python/test_dropout_layernorm_bwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import ( run_benchmark, clear_dynamo_cache, @@ -149,8 +149,6 @@ def test_dropout_layernorm_bwd_nvf_benchmark( disable_benchmarking: bool, eps: float = 1e-5, ): - clear_cuda_cache() - input1 = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True) input2 = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True) grads = torch.randn(size, device="cuda", dtype=dtype) @@ -200,7 +198,6 @@ def test_dropout_layernorm_bwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() diff --git a/benchmarks/python/test_dropout_layernorm_fwd.py b/benchmarks/python/test_dropout_layernorm_fwd.py index cb2d7ceb7eb..47854fcd2d7 100644 --- a/benchmarks/python/test_dropout_layernorm_fwd.py +++ b/benchmarks/python/test_dropout_layernorm_fwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import ( run_benchmark, clear_dynamo_cache, @@ -111,7 +111,6 @@ def test_dropout_layernorm_fwd_nvf_benchmark( disable_benchmarking: bool, eps: float = 1e-5, ): - clear_cuda_cache() inputs = [ torch.randn(size, device="cuda", dtype=dtype), torch.randn(size, device="cuda", dtype=dtype), @@ -170,7 +169,6 @@ def test_dropout_layernorm_fwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() diff --git a/benchmarks/python/test_dropout_rmsnorm_bwd.py b/benchmarks/python/test_dropout_rmsnorm_bwd.py index 77a5ff091f4..d103c17dcfa 100644 --- a/benchmarks/python/test_dropout_rmsnorm_bwd.py +++ b/benchmarks/python/test_dropout_rmsnorm_bwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import ( run_benchmark, clear_dynamo_cache, @@ -135,8 +135,6 @@ def test_dropout_rmsnorm_bwd_nvf_benchmark( disable_benchmarking: bool, eps: float = 1e-5, ): - clear_cuda_cache() - input1 = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True) input2 = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True) grads = torch.randn(size, device="cuda", dtype=dtype) @@ -180,7 +178,6 @@ def test_dropout_rmsnorm_bwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() dropout_p = 0.2 diff --git a/benchmarks/python/test_dropout_rmsnorm_fwd.py b/benchmarks/python/test_dropout_rmsnorm_fwd.py index 3cb677aadb5..a93a8caf547 100644 --- a/benchmarks/python/test_dropout_rmsnorm_fwd.py +++ b/benchmarks/python/test_dropout_rmsnorm_fwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import ( run_benchmark, clear_dynamo_cache, @@ -111,8 +111,6 @@ def test_dropout_rmsnorm_fwd_nvf_benchmark( disable_benchmarking: bool, eps: float = 1e-5, ): - clear_cuda_cache() - input1 = torch.randn(size, device="cuda", dtype=dtype) input2 = torch.randn(size, device="cuda", dtype=dtype) weights = torch.randn(size[1], device="cuda", dtype=dtype) @@ -156,7 +154,6 @@ def test_dropout_rmsnorm_fwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() dropout_p = 0.2 diff --git a/benchmarks/python/test_gelu_bwd.py b/benchmarks/python/test_gelu_bwd.py index a43d484c46e..a876f00d748 100644 --- a/benchmarks/python/test_gelu_bwd.py +++ b/benchmarks/python/test_gelu_bwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache, unary_bwd_torch import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -71,8 +71,6 @@ def test_gelu_bwd_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - inputs = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True) grads = torch.randn(size, device="cuda", dtype=dtype) bias = torch.ones(size[-1], device="cuda", dtype=dtype) @@ -99,7 +97,6 @@ def test_gelu_bwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() inputs = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True) diff --git a/benchmarks/python/test_gelu_bwd_reduction.py b/benchmarks/python/test_gelu_bwd_reduction.py index 8b8aba0c59a..09dfd53d88a 100644 --- a/benchmarks/python/test_gelu_bwd_reduction.py +++ b/benchmarks/python/test_gelu_bwd_reduction.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -83,8 +83,6 @@ def test_gelu_bwd_reduction_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - inputs = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True) grads = torch.randn(*size, device="cuda", dtype=dtype) bias = torch.ones(size[-1], device="cuda", dtype=dtype) @@ -116,7 +114,6 @@ def test_gelu_bwd_reduction_baseline_benchmark( reduction_axis: int, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() inputs = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True) diff --git a/benchmarks/python/test_gelu_fwd.py b/benchmarks/python/test_gelu_fwd.py index c5bc5c24823..fa5f891ef8a 100644 --- a/benchmarks/python/test_gelu_fwd.py +++ b/benchmarks/python/test_gelu_fwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -54,8 +54,6 @@ def test_gelu_fwd_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - inputs = [ torch.randn(size, device="cuda", dtype=dtype, requires_grad=True), # in_tensor torch.ones(size[-1], device="cuda", dtype=dtype), # bias @@ -78,7 +76,6 @@ def test_gelu_fwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() inputs = [ diff --git a/benchmarks/python/test_groupnorm_fwd.py b/benchmarks/python/test_groupnorm_fwd.py index ac9f21ded07..af4c023d7d7 100644 --- a/benchmarks/python/test_groupnorm_fwd.py +++ b/benchmarks/python/test_groupnorm_fwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch import thunder @@ -111,8 +111,6 @@ def test_groupnorm_fwd_nvf_benchmark( disable_benchmarking: bool, eps: float = 1e-5, ): - clear_cuda_cache() - N, C, H, W = size x = torch.randn(size, device="cuda", dtype=dtype) weight = torch.randn(C, device="cuda", dtype=dtype) @@ -137,7 +135,6 @@ def test_groupnorm_fwd_thunder_benchmark( size: tuple, dtype: torch.dtype, ): - clear_cuda_cache() N, C, H, W = size x = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True) weight = torch.randn(C, device="cuda", dtype=dtype, requires_grad=True) @@ -159,7 +156,6 @@ def test_groupnorm_fwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() N, C, H, W = size diff --git a/benchmarks/python/test_huggingface_attn_bwd.py b/benchmarks/python/test_huggingface_attn_bwd.py index 30aac059567..b94c6a471c3 100644 --- a/benchmarks/python/test_huggingface_attn_bwd.py +++ b/benchmarks/python/test_huggingface_attn_bwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache, unary_bwd_torch import torch from .global_params import generate_attn_inputs, FLOAT_DTYPES, PROMOTE_DTYPES @@ -75,8 +75,6 @@ def test_huggingface_attn_bwd_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - batch_size, seq_len, nh, n_embd = size dropout_p = 0.2 @@ -118,7 +116,6 @@ def test_huggingface_attn_bwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() batch_size, seq_len, nh, n_embd = size diff --git a/benchmarks/python/test_huggingface_attn_fwd.py b/benchmarks/python/test_huggingface_attn_fwd.py index ba266798098..27a013a8481 100644 --- a/benchmarks/python/test_huggingface_attn_fwd.py +++ b/benchmarks/python/test_huggingface_attn_fwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_attn_inputs, FLOAT_DTYPES, PROMOTE_DTYPES @@ -104,8 +104,6 @@ def test_huggingface_attn_fwd_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - batch_size, seq_len, nh, n_embd = size dropout_p = 0.2 inputs = torch.randn(batch_size, nh, seq_len, seq_len, device="cuda", dtype=dtype) @@ -146,7 +144,6 @@ def test_huggingface_attn_fwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() batch_size, seq_len, nh, n_embd = size diff --git a/benchmarks/python/test_layernorm_bwd.py b/benchmarks/python/test_layernorm_bwd.py index 6e498198faf..154dc74d8b8 100644 --- a/benchmarks/python/test_layernorm_bwd.py +++ b/benchmarks/python/test_layernorm_bwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache, unary_bwd_torch import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -117,8 +117,6 @@ def test_layernorm_bwd_nvf_benchmark( disable_benchmarking: bool, eps: float = 1e-5, ): - clear_cuda_cache() - inputs = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True) grads = torch.randn(*size, device="cuda", dtype=dtype) weights = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True) @@ -157,7 +155,6 @@ def test_layernorm_bwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() diff --git a/benchmarks/python/test_layernorm_fwd.py b/benchmarks/python/test_layernorm_fwd.py index 3210cf54ed9..c6a5f24c8dc 100644 --- a/benchmarks/python/test_layernorm_fwd.py +++ b/benchmarks/python/test_layernorm_fwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -84,8 +84,6 @@ def test_layernorm_fwd_nvf_benchmark( disable_benchmarking: bool, eps: float = 1e-5, ): - clear_cuda_cache() - batch_size, hidden_size = size inputs = [ torch.randn(size, device="cuda", dtype=dtype), @@ -117,7 +115,6 @@ def test_layernorm_fwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() batch_size, hidden_size = size diff --git a/benchmarks/python/test_many_pointwise_ops.py b/benchmarks/python/test_many_pointwise_ops.py index 7dc600858a9..d62fb67b26a 100644 --- a/benchmarks/python/test_many_pointwise_ops.py +++ b/benchmarks/python/test_many_pointwise_ops.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark import torch from .global_params import PROMOTE_DTYPES @@ -39,8 +39,6 @@ def test_pointwise_ops_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - inputs = [torch.randn(13, device="cuda", dtype=torch.float16) for _ in range(2)] # Generate multiple inputs to measure dynamic shape overhead. diff --git a/benchmarks/python/test_matmul.py b/benchmarks/python/test_matmul.py index 9c003343e4d..865caba2e31 100644 --- a/benchmarks/python/test_matmul.py +++ b/benchmarks/python/test_matmul.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition -from nvfuser.pytorch_utils import clear_cuda_cache from .core import run_benchmark import torch @@ -42,29 +41,23 @@ def test_matmul_baseline_benchmark( ): m, n, k, layout = config - clear_cuda_cache() - torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = half_reduction torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = half_reduction - try: - a = torch.randn(m, k, device="cuda", dtype=dtype) - b = torch.randn(k, n, device="cuda", dtype=dtype) - - if layout == "NT" or layout == "NN": - a = a.as_strided(size=[m, k], stride=[1, m]) - if layout == "TN" or layout == "NN": - b = b.as_strided(size=[k, n], stride=[1, k]) + a = torch.randn(m, k, device="cuda", dtype=dtype) + b = torch.randn(k, n, device="cuda", dtype=dtype) - # NOTE: we never need to validate eager, as it is our baseline - run_benchmark( - benchmark, - lambda ab: torch.matmul(*ab), - [a, b], - ) + if layout == "NT" or layout == "NN": + a = a.as_strided(size=[m, k], stride=[1, m]) + if layout == "TN" or layout == "NN": + b = b.as_strided(size=[k, n], stride=[1, k]) - except torch.OutOfMemoryError: - pytest.skip("Test failed due to OutOfMemoryError") + # NOTE: we never need to validate eager, as it is our baseline + run_benchmark( + benchmark, + lambda ab: torch.matmul(*ab), + [a, b], + ) @pytest.mark.parametrize("half_reduction", [False, True], ids=["fullred", "halfred"]) @@ -82,8 +75,6 @@ def test_matmul_nvf_benchmark( ): m, n, k, layout = config - clear_cuda_cache() - torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = half_reduction torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = half_reduction @@ -91,24 +82,20 @@ def test_matmul_nvf_benchmark( # See https://github.com/NVIDIA/Fuser/pull/1719 pytest.skip("Reduced precision reduction not implemented in nvFuser") - try: - a = torch.randn(m, k, device="cuda", dtype=dtype) - b = torch.randn(k, n, device="cuda", dtype=dtype) - - if layout == "NT" or layout == "NN": - a = a.as_strided(size=[m, k], stride=[1, m]) - if layout == "TN" or layout == "NN": - b = b.as_strided(size=[k, n], stride=[1, k]) + a = torch.randn(m, k, device="cuda", dtype=dtype) + b = torch.randn(k, n, device="cuda", dtype=dtype) - with FusionDefinition() as fd: - matmul_fusion(fd, [a, b]) + if layout == "NT" or layout == "NN": + a = a.as_strided(size=[m, k], stride=[1, m]) + if layout == "TN" or layout == "NN": + b = b.as_strided(size=[k, n], stride=[1, k]) - if not disable_validation: - eager_output = torch.matmul(a, b) - fd.validate([a, b], [eager_output]) + with FusionDefinition() as fd: + matmul_fusion(fd, [a, b]) - if not disable_benchmarking: - run_benchmark(benchmark, fd.execute, [a, b]) + if not disable_validation: + eager_output = torch.matmul(a, b) + fd.validate([a, b], [eager_output]) - except torch.OutOfMemoryError: - pytest.skip("Test failed due to OutOfMemoryError") + if not disable_benchmarking: + run_benchmark(benchmark, fd.execute, [a, b]) diff --git a/benchmarks/python/test_nanogpt_attn_bwd.py b/benchmarks/python/test_nanogpt_attn_bwd.py index 7a23c6681fa..136429d475e 100644 --- a/benchmarks/python/test_nanogpt_attn_bwd.py +++ b/benchmarks/python/test_nanogpt_attn_bwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache, unary_bwd_torch import torch from .global_params import generate_attn_inputs, FLOAT_DTYPES, PROMOTE_DTYPES @@ -91,7 +91,6 @@ def test_nanogpt_attn_bwd_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() batch_size, seq_len, nh, n_embd = size hs = n_embd // nh dropout_p = 0.2 @@ -134,7 +133,6 @@ def test_nanogpt_attn_bwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() batch_size, seq_len, nh, n_embd = size diff --git a/benchmarks/python/test_nanogpt_attn_fwd.py b/benchmarks/python/test_nanogpt_attn_fwd.py index c8d8ad06dbe..4dbd5821c59 100644 --- a/benchmarks/python/test_nanogpt_attn_fwd.py +++ b/benchmarks/python/test_nanogpt_attn_fwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_attn_inputs, FLOAT_DTYPES, PROMOTE_DTYPES @@ -102,7 +102,6 @@ def test_nanogpt_attn_fwd_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() batch_size, seq_len, nh, n_embd = size hs = n_embd // nh dropout_p = 0.2 @@ -147,7 +146,6 @@ def test_nanogpt_attn_fwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() batch_size, seq_len, nh, n_embd = size diff --git a/benchmarks/python/test_pointwise_mul.py b/benchmarks/python/test_pointwise_mul.py index cba7c54c433..0162950cc47 100644 --- a/benchmarks/python/test_pointwise_mul.py +++ b/benchmarks/python/test_pointwise_mul.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -37,8 +37,6 @@ def test_pointwise_mul_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - inputs = [torch.randn(size, device="cuda", dtype=dtype)] with FusionDefinition() as fd: @@ -61,7 +59,6 @@ def test_pointwise_mul_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() input = torch.randn(size, device="cuda", dtype=dtype) diff --git a/benchmarks/python/test_reduction.py b/benchmarks/python/test_reduction.py index 4f4cbe023b3..f734769a1e5 100644 --- a/benchmarks/python/test_reduction.py +++ b/benchmarks/python/test_reduction.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -40,8 +40,6 @@ def test_reduction_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - inputs = [torch.randn(*size, device="cuda", dtype=dtype)] with FusionDefinition() as fd: @@ -66,7 +64,6 @@ def test_reduction_baseline_benchmark( reduction_axis: int, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() input = torch.randn(size, device="cuda", dtype=dtype) diff --git a/benchmarks/python/test_reduction_epilogue.py b/benchmarks/python/test_reduction_epilogue.py index 9e73bf67ac6..231090e4135 100644 --- a/benchmarks/python/test_reduction_epilogue.py +++ b/benchmarks/python/test_reduction_epilogue.py @@ -4,7 +4,7 @@ import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -50,7 +50,6 @@ def test_reduction_epilogue_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() x = torch.randn(size, device="cuda", dtype=dtype) epilogue = torch.randn(size[reduction_axis - 1], device="cuda", dtype=dtype) with FusionDefinition() as fd: @@ -79,7 +78,6 @@ def test_reduction_epilogue_baseline_benchmark( reduction_axis: int, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() x = torch.randn(size, device="cuda", dtype=dtype) diff --git a/benchmarks/python/test_rmsnorm_bwd.py b/benchmarks/python/test_rmsnorm_bwd.py index 3b583fb4fd7..3076dd826bb 100644 --- a/benchmarks/python/test_rmsnorm_bwd.py +++ b/benchmarks/python/test_rmsnorm_bwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache, unary_bwd_torch import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -91,7 +91,6 @@ def test_rmsnorm_bwd_nvf_benchmark( disable_benchmarking: bool, eps: float = 1e-5, ): - clear_cuda_cache() inputs = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True) grads = torch.randn(size, device="cuda", dtype=dtype) weights = torch.randn(size[1], device="cuda", dtype=dtype, requires_grad=True) @@ -122,7 +121,6 @@ def test_rmsnorm_bwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() inputs = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True) diff --git a/benchmarks/python/test_rmsnorm_fwd.py b/benchmarks/python/test_rmsnorm_fwd.py index 24873bd7787..b7839b631de 100644 --- a/benchmarks/python/test_rmsnorm_fwd.py +++ b/benchmarks/python/test_rmsnorm_fwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -70,8 +70,6 @@ def test_rmsnorm_fwd_nvf_benchmark( disable_benchmarking: bool, eps: float = 1e-5, ): - clear_cuda_cache() - inputs = torch.randn(size, device="cuda", dtype=dtype) weights = torch.randn(size[1], device="cuda", dtype=dtype) @@ -97,7 +95,6 @@ def test_rmsnorm_fwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() inputs = torch.randn(size, device="cuda", dtype=dtype) diff --git a/benchmarks/python/test_rope.py b/benchmarks/python/test_rope.py index a29d4cceeb0..5df43bdc6ea 100644 --- a/benchmarks/python/test_rope.py +++ b/benchmarks/python/test_rope.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import clear_cuda_cache from .core import run_benchmark import torch @@ -128,8 +127,6 @@ def rope_without_cat_fusion( def test_rope_benchmark( benchmark, use_cat: bool, disable_validation: bool, disable_benchmarking: bool ): - clear_cuda_cache() - batch_size = 32 seq_len = 4096 num_heads = 32 diff --git a/benchmarks/python/test_scale_bias_relu_bwd.py b/benchmarks/python/test_scale_bias_relu_bwd.py index 4197a6fbc4e..f2c75ef3971 100644 --- a/benchmarks/python/test_scale_bias_relu_bwd.py +++ b/benchmarks/python/test_scale_bias_relu_bwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache, unary_bwd_torch import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -61,8 +61,6 @@ def test_sbr_bwd_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - inputs = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True) grads = torch.randn(*size, device="cuda", dtype=dtype) scale = torch.ones(size[-1], device="cuda", dtype=dtype) @@ -90,7 +88,6 @@ def test_sbr_bwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() inputs = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True) diff --git a/benchmarks/python/test_scale_bias_relu_fwd.py b/benchmarks/python/test_scale_bias_relu_fwd.py index 6a824650fa3..ede13dbb767 100644 --- a/benchmarks/python/test_scale_bias_relu_fwd.py +++ b/benchmarks/python/test_scale_bias_relu_fwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -67,8 +67,6 @@ def test_sbr_fwd_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - inputs = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True) bias = torch.ones(size[-1], device="cuda", dtype=dtype) scale = torch.ones(size[-1], device="cuda", dtype=dtype) @@ -93,7 +91,6 @@ def test_sbr_fwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() inputs = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True) diff --git a/benchmarks/python/test_silu_mul_bwd.py b/benchmarks/python/test_silu_mul_bwd.py index 53ce2d2528b..17fc57587cd 100644 --- a/benchmarks/python/test_silu_mul_bwd.py +++ b/benchmarks/python/test_silu_mul_bwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache, unary_bwd_torch import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -65,7 +65,6 @@ def test_silu_mul_bwd_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() x = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True) y = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True) grads = torch.randn(*size, device="cuda", dtype=dtype) @@ -89,7 +88,6 @@ def test_silu_mul_bwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() x = torch.randn(*size, device="cuda", dtype=dtype, requires_grad=True) diff --git a/benchmarks/python/test_silu_mul_fwd.py b/benchmarks/python/test_silu_mul_fwd.py index 667174588e5..0f1e86d0d56 100644 --- a/benchmarks/python/test_silu_mul_fwd.py +++ b/benchmarks/python/test_silu_mul_fwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -44,7 +44,6 @@ def test_silu_mul_fwd_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() inputs = [torch.randn(*size, device="cuda", dtype=dtype) for _ in range(2)] with FusionDefinition() as fd: @@ -66,7 +65,6 @@ def test_silu_mul_fwd_baseline_benchmark( dtype: torch.dtype, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() inputs = [torch.randn(*size, device="cuda", dtype=dtype) for _ in range(2)] diff --git a/benchmarks/python/test_softmax_bwd.py b/benchmarks/python/test_softmax_bwd.py index dfa9d2907d5..085268a405d 100644 --- a/benchmarks/python/test_softmax_bwd.py +++ b/benchmarks/python/test_softmax_bwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_input_sizes, FLOAT_DTYPES @@ -79,8 +79,6 @@ def test_softmax_bwd_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - inputs = [ torch.randn(size, device="cuda", dtype=dtype, requires_grad=True), torch.randn(size, device="cuda", dtype=dtype), @@ -109,7 +107,6 @@ def test_softmax_bwd_baseline_benchmark( reduction_axis: int, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() input = torch.randn(size, device="cuda", dtype=dtype, requires_grad=True) diff --git a/benchmarks/python/test_softmax_fwd.py b/benchmarks/python/test_softmax_fwd.py index a74f1e5cf7b..2e672eb2e30 100644 --- a/benchmarks/python/test_softmax_fwd.py +++ b/benchmarks/python/test_softmax_fwd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -68,8 +68,6 @@ def test_softmax_fwd_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - inputs = [torch.randn(size, device="cuda", dtype=dtype)] with FusionDefinition() as fd: @@ -94,7 +92,6 @@ def test_softmax_fwd_baseline_benchmark( reduction_axis: int, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() input = torch.randn(size, device="cuda", dtype=dtype) diff --git a/benchmarks/python/test_transpose.py b/benchmarks/python/test_transpose.py index d72a97d6177..cf290f278a5 100644 --- a/benchmarks/python/test_transpose.py +++ b/benchmarks/python/test_transpose.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype, clear_cuda_cache +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype from .core import run_benchmark, clear_dynamo_cache import torch from .global_params import generate_input_sizes, FLOAT_DTYPES, PROMOTE_DTYPES @@ -55,8 +55,6 @@ def test_transpose_nvf_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - input1 = torch.randn(size, device="cuda", dtype=dtype) input2 = torch.randn(size, device="cuda", dtype=dtype) permute_axes = list(range(len(size))) @@ -87,7 +85,6 @@ def test_transpose_baseline_benchmark( axes: list, compile: bool, ): - clear_cuda_cache() if compile: clear_dynamo_cache() input1 = torch.randn(size, device="cuda", dtype=dtype) diff --git a/nvfuser/pytorch_utils.py b/nvfuser/pytorch_utils.py index 75fb4a6acc8..7ad7c0c3e26 100644 --- a/nvfuser/pytorch_utils.py +++ b/nvfuser/pytorch_utils.py @@ -2,12 +2,14 @@ # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause import torch -from typing import Type, Union, Tuple -import ctypes -import gc from ._C import DataType +import ctypes +import functools +import gc +from typing import Type, Union, Tuple + NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]] _torch_dtype_to_nvfuser_dtype_map = { @@ -169,13 +171,31 @@ def get_device_properties() -> Tuple[int, float]: DEVICE_PROPERTIES = get_device_properties() -def clear_cuda_cache() -> None: - """ - Utility function to clear CUDA cache before running a test. - """ - if ( - torch.cuda.memory_allocated() - or torch.cuda.memory_reserved() > 0.8 * DEVICE_PROPERTIES["gpu_gmem_bytes"] - ): +def retry_on_oom_or_skip_test(func): + """Decorator: upon torch.OutOfMemoryError clear the cache and retry test""" + + @functools.wraps(func) + def retried_func(*args, **kwargs): + try: + output = func(*args, **kwargs) + except torch.OutOfMemoryError: + pass + else: + return output + + # We have hit an OOM error, so clear the cache and retry gc.collect() torch.cuda.empty_cache() + + try: + output = func(*args, **kwargs) + except torch.OutOfMemoryError as e: + # If we hit an OOM this time, then skip the test + import pytest + + pytest.skip(f"Test failed due to OutOfMemoryError: {e}") + return + + return output + + return retried_func diff --git a/tests/python/test_ops.py b/tests/python/test_ops.py index c4607c107b5..f5c8a57dcab 100644 --- a/tests/python/test_ops.py +++ b/tests/python/test_ops.py @@ -16,7 +16,7 @@ from typing import Callable from nvfuser import FusionCache, FusionDefinition -from nvfuser.pytorch_utils import clear_cuda_cache +from nvfuser.pytorch_utils import retry_on_oom_or_skip_test from utils import ( check_captured_python_definition, @@ -200,8 +200,8 @@ def correctness_test_fn( # Run serde check for each operation and dtype but not for each sample input. # NOTE: Disabled serde_check_ops decorator to avoid CI timeout. +@retry_on_oom_or_skip_test def serde_test_fn(op: OpInfo, dtype: torch.dtype): - clear_cuda_cache() for sample in op.sample_input_generator(op, dtype): result = correctness_test_fn(op.reference_type, op, sample) if result is not None: @@ -241,11 +241,11 @@ def schedule(self): # TODO Maybe only test a single dtype @create_op_test(tuple(op for op in opinfos if op.sample_input_generator is not None)) +@retry_on_oom_or_skip_test def test_definition_op_in_schedule_error(op: OpInfo, dtype: torch.dtype): for sample in op.sample_input_generator(op, dtype): # clear cache for each sample FusionCache.reset() - clear_cuda_cache() with pytest.raises( RuntimeError, match=r"Attempting to add to a completed definition" ): @@ -277,8 +277,8 @@ def _regex_escape_parenthesis(a: str) -> str: @create_op_test(tuple(op for op in opinfos if op.error_input_generator is not None)) +@retry_on_oom_or_skip_test def test_errors(op: OpInfo, dtype: torch.dtype): - clear_cuda_cache() for sample, exception_type, exception_regex in op.error_input_generator(op, dtype): with pytest.raises( exception_type, match=_regex_escape_parenthesis(exception_regex) diff --git a/version.txt b/version.txt index 599028f5f5f..109a20f1d13 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.2.21 +0.2.22 From 7a3b1a4725adf095a301b54dec93f01f01173e9c Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Tue, 29 Oct 2024 08:27:08 -0700 Subject: [PATCH 08/17] Refactor MultiMatmulSchedulers (#3277) The PR fixes #3266. * Future consolidation is possible but I kept some duplicate functions in `HopperMultiMatrixScheduler` and `AmpereMultiMatrixScheduler` for flexibility. ### Changes: - Move `isPowOf2` to `csrc/utils.h` - Move `representativeId` to `scheduler/tools/abstract_tensor.h` - Move `checkConcreteStaticDim to mma_utils.cpp` - Add TODO to remove `swizzleSharedMemory` from `HopperMultiMatrixScheduler` - Create base class `MultiMatrixScheduler` to hold common functions like `findPatterns` ### Details: - Create base class `MultiMatrixScheduler` - `HopperMultiMatrixScheduler` and `AmpereMultiMatrixScheduler` inherit from `MultiMatrixScheduler` and overwrite the `run` function. - `MultiMatrixScheduler` implements `findPatterns`, `translatePatterns`, `findRoles`, `countDims`, and `updateIdModel`. It also holds the necessary data members for those functions. --- csrc/scheduler/ampere_multi_matmul.cpp | 189 +------------------------ csrc/scheduler/ampere_multi_matmul.h | 49 +------ csrc/scheduler/hopper_multi_matmul.cpp | 188 +----------------------- csrc/scheduler/hopper_multi_matmul.h | 49 +------ csrc/scheduler/mma_utils.h | 13 ++ csrc/scheduler/multi_matmul.cpp | 87 ++++++++++++ csrc/scheduler/multi_matmul.h | 54 +++++++ csrc/scheduler/tools/abstract_tensor.h | 10 +- csrc/utils.h | 5 + 9 files changed, 194 insertions(+), 450 deletions(-) diff --git a/csrc/scheduler/ampere_multi_matmul.cpp b/csrc/scheduler/ampere_multi_matmul.cpp index 622b98519bd..208c6077c8e 100644 --- a/csrc/scheduler/ampere_multi_matmul.cpp +++ b/csrc/scheduler/ampere_multi_matmul.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include @@ -19,6 +18,7 @@ #include #include #include +#include #include #include @@ -31,32 +31,6 @@ namespace nvfuser { namespace { -// Returns true if given number is power of 2 -constexpr bool isPowOf2(int64_t x) { - return x > 1 && (x & (x - 1)) == 0; -} - -inline IterDomain* representativeId(const AbstractId& abs_id) { - if (abs_id.is()) { - return abs_id.as(); - } - NVF_ERROR(abs_id.is()); - return representativeId(abs_id.as().group); -} - -// Utility to check concrete static size -inline void checkConcreteStaticDim(const AbstractId& abs_id) { - IterDomain* id = representativeId(abs_id); - NVF_ERROR( - !id->isBroadcast() && !id->isReduction(), - "no support for reduction or broadcast domains, but got ", - id->toString()); - NVF_ERROR( - id->extent()->isConstInt(), - "swizzled dimension's extend must be known during scheduling, got ", - id->toString()); -} - //! Automatically generates the shared memory swizzled data layout //! for matmul mainloop and epilogue. //! The shared mem data layout is always 2D currently, and this utility @@ -76,8 +50,8 @@ AbstractTensor swizzleSharedMemory(TensorView* shared_mem_tv) { (int64_t)swizzle_domain.size() >= 2, "At least 2D input (excluding consecutive reduction domains starting from the innermost dim) needed for swizzling, but get ", shared_mem_tv->toString()); - checkConcreteStaticDim(swizzle_domain[-2]); - checkConcreteStaticDim(swizzle_domain[-1]); + mma_utils::checkConcreteStaticDim(swizzle_domain[-2]); + mma_utils::checkConcreteStaticDim(swizzle_domain[-1]); // Extract the constant sizes of the swizzled tile const int64_t tile_size_x = @@ -522,98 +496,6 @@ void AmpereMultipleMatmulScheduler::cacheInputsAndOutputs() { scheduler_utils::cacheAndForkOutputs(fusion_, /*unroll=*/true); } -void AmpereMultipleMatmulScheduler::findPatterns() { - patterns_ = mma_utils::findMatmulPatterns(fusion_); - NVF_ERROR(!patterns_.empty(), "No matmul patterns were found"); -} - -void AmpereMultipleMatmulScheduler::countDims() { - NVF_ERROR(!patterns_.empty()); - TensorView* mma_result = patterns_.front().output; - num_device_dims_ = numDeviceDims(mma_result); - for (const auto& it : id_roles_) { - if (it.second == MatmulDimRole::Batch && - // Skip device dims - !std::any_of(it.first->begin(), it.first->end(), [](Val* v) { - return v->as()->isDeviceDim(); - })) { - // All batch dims will be merged into one, if any exist - num_local_batch_dims_ = 1; - } - } - num_splitk_dims_ = params_->splitk_factor > 1 ? 1 : 0; - // Subtract 6 for the [Mo, No, Ko, Mi, Ni, Ki] - num_device_and_batch_dims_ = num_device_dims_ + num_local_batch_dims_; -} - -void AmpereMultipleMatmulScheduler::translatePatterns() { - mma_results_.reserve(patterns_.size()); - for (mma_utils::MatmulPattern& pattern : patterns_) { - MmaOp* mma = pattern.translateToMmaOp(); - mma_results_.push_back(mma->out()->as()); - } - - // Build IdModel graphs now since translateToMmaOp creates new TVs. Before - // this point the graphs are not yet built. - updateIdModel(); -} - -// Get tensor roles and id roles -// When there are multiple matmul patterns, we can have conflicting roles. -// For now we throw an error if this is the case. -// TODO: This should be checked in canScheduleCompileTime -void AmpereMultipleMatmulScheduler::findRoles() { - const auto roles_opt = mma_utils::allPatternRoles(id_model_, patterns_); - NVF_ERROR( - roles_opt.has_value(), - "Incompatible roles found between matmul patterns"); - std::tie(id_roles_, tensor_roles_) = roles_opt.value(); - - mma_utils::MatmulOperandInnerDimsOpt inner_dims_opt = - mma_utils::getOperandInnerDims(id_model_, id_roles_, tensor_roles_); - NVF_ERROR(inner_dims_opt.isValid(), inner_dims_opt.getErrorMsg()); - inner_dims_ = inner_dims_opt.getData(); - - as_ = tensor_roles_.at(MatmulTensorRole::OPERAND_A); - bs_ = tensor_roles_.at(MatmulTensorRole::OPERAND_B); - - countDims(); -} - -// Including current tensor naming convention for reference, -// this is very temporary and will change over time and -// in fact the whole body of this function will -// eventually be a set of utility functions for different -// sections of matmul(fusion) kernels, with -// each having its own build out to do. -// -// Current naming convention is based on the following formula: -// -// d = alpha * (a x b) + beta * c -// -// and is defined in the following way: -// -// operands assumed in global memory : a, b, c -// -// registers staging global load : ar, br (short for a/b read) -// -// shared mem cache of operands : acw_smem, bcw_smem (short for a/b -// cache_write smem) -// -// registers at shared memory load output : acr, bcr (short for a/b cache -// read) -// -// register tensor input to the actual mma op: ab, bb (short for a/b -// broadcasted) -// -// accumulator register: mma_result -// - mma_result is MmaOp output if there is epilogue -// - mma_result is dc (short for d cache) if there is no epilogue -// -// result in global memory: d - -// Currently the support is for a, b, c and d as fusion inputs/outputs -// aka. no prolog fusion yet. void AmpereMultipleMatmulScheduler::defineOperandCaches() { cacheOperandsToSmem(as_, acw_smems_, params_->supported_vec_size.a); addSetsForCacheReads(acw_smems_, acrs_); @@ -669,12 +551,6 @@ void AmpereMultipleMatmulScheduler::cacheOperandsToSmem( } } -// We add two LoadStore operators to the inputs of our fusions. The first -// one is for a read from global memory and the second one (below) is for a -// cache read. As an optimizaton, we avoid adding an operator if there's an -// existing LoadStoreOp present. Please note that for the second LoadStore -// we don't propagate the allocation domain, since the scheduler sets the -// allocation domain in the registers. void AmpereMultipleMatmulScheduler::addSetsForCacheReads( const std::vector& tv_smems, std::vector& tv_rs) { @@ -702,38 +578,6 @@ void AmpereMultipleMatmulScheduler::addSetsForCacheReads( } } -//! Rebuilds IdModel, then updates all ValGroups in abstract tensors to refer -//! to the new IdModel. This is necessary whenever we perform an operation -//! that creates a new TensorView, such as caching or rFactor -void AmpereMultipleMatmulScheduler::updateIdModel() { - // Build new IdModel - IdModel new_id_model(fusion_, /*build_graphs=*/false); - new_id_model.buildPermissiveGraph(); - - // Get new permissive graph - ValGraph& new_graph = new_id_model.idGraph(IdMappingMode::PERMISSIVE); - - if (!id_roles_.empty()) { - // Update id_roles_ to have keys corresponding to ValGroups in the new - // IdModel - std::unordered_map new_id_roles; - for (auto& [k, v] : id_roles_) { - const ValGroup& new_group = new_graph.toGroup(k->front()); - new_id_roles.emplace(new_group, v); - } - id_roles_ = new_id_roles; - } - - graph_ = &new_id_model.idGraph(IdMappingMode::PERMISSIVE); - - // Set id_model_ after we are done using the old one - id_model_ = std::move(new_id_model); -} - -//! Swizzle the M and N outer dimensions after makeTile has been called. -//! This updates outer_dim_roles if we introduce a new dimension, which can -//! happen if tv is missing a merged axis, in which case we skip merging after -//! the split. This is analogous to forwarding during transform propagation. void AmpereMultipleMatmulScheduler::swizzleBlockTiles( TensorView* tv, std::vector& outer_dim_roles) { @@ -798,8 +642,6 @@ void AmpereMultipleMatmulScheduler::swizzleBlockTiles( } } -//! This calls orig->cacheAfter() and also updates the permissive graph to -//! reflect the new IterDomain mappings TensorView* AmpereMultipleMatmulScheduler::cacheAfter( TensorView* orig, LoadStoreOpType op_type, @@ -835,16 +677,6 @@ TensorView* AmpereMultipleMatmulScheduler::cacheAfter( return c; } -//! Do block tiling for a collection of TensorViews. The tensors should be -//! unscheduled before this method is called. -//! 1) Axes will be ordered according to canonicalDimOrdering, and then axes -//! with the same role will be merged. -//! 2) After that, we perform splits according to -//! params_->tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki]. -//! 3) Depending on the value of params_->grid_swizzle_factor, if the TV has -//! both M and N dimensions, we perform a 2D swizzle of the outer dimensions -//! Mo and No. -//! 4) Finally, we do a split-K split if the splitk_factor is not 1 std::vector> AmpereMultipleMatmulScheduler:: blockTileTensors(const std::vector& tvs) { if (canonical_dim_ordering_.empty()) { @@ -920,10 +752,6 @@ std::vector> AmpereMultipleMatmulScheduler:: return all_merged_roles; } -//! Schedule the loads of all operands from global memory to shared memory. -//! Starting from the basic tiled schedule, we swizzle the operand memory. -//! Note that the cache op and LoadStoreOpType are already set during -//! defineOperandCaches(). void AmpereMultipleMatmulScheduler::scheduleOperandSmemStores() { auto scheduleBranch = [&](const std::vector& gmem_operands, const std::vector& smem_operands, @@ -989,8 +817,6 @@ void AmpereMultipleMatmulScheduler::scheduleMmaOperands( } } -// MmaOperand contains only A and B. If tvs are outputs (i.e. not operands), -// then operand_type should be std::nullopt. void AmpereMultipleMatmulScheduler::scheduleMmaResults() { auto all_merged_roles = blockTileTensors(mma_results_); for (size_t i : c10::irange(mma_results_.size())) { @@ -1238,8 +1064,8 @@ void AmpereMultipleMatmulScheduler::scheduleOutputTensor(TensorView* c) { const MatMulTileOptions& gemm_tile = params_->tile_sizes; const int64_t vectorization_factor = params_->supported_vec_size.epilogue; // input tensor is in the form of [Mo,No,cta_tile_m,cta_tile_n] - checkConcreteStaticDim(c->axis(-2)); - checkConcreteStaticDim(c->axis(-1)); + mma_utils::checkConcreteStaticDim(c->axis(-2)); + mma_utils::checkConcreteStaticDim(c->axis(-1)); const int64_t tile_size_m = c->axis(-2)->extent()->evaluate().as(); const int64_t tile_size_n = c->axis(-1)->extent()->evaluate().as(); NVF_ERROR( @@ -1360,9 +1186,6 @@ void AmpereMultipleMatmulScheduler::scheduleEpilogue() { scheduleFusionInputsForEpilogue(); } -//! Propagates transformations from fusion output to fusion tv inputs that are -//! producers in the epilogue. Transformations' propagation aims at input tvs -//! which are not assigned to core roles, that is, are not MMA inputs. void AmpereMultipleMatmulScheduler::scheduleFusionInputsForEpilogue() { std::vector cached_tvs; @@ -1463,8 +1286,6 @@ void AmpereMultipleMatmulScheduler::setUpInlining() { } } -// NOTE: this should be called after acw_smem, acr, ..., ab, and mma_result -// transforms have been applied and inlining void AmpereMultipleMatmulScheduler::setUpCircularBuffering() { // Propagate mma output swizzle and parallelization down the DAG if (params_->circular_buffer_options.circular_buffer_smem_write) { diff --git a/csrc/scheduler/ampere_multi_matmul.h b/csrc/scheduler/ampere_multi_matmul.h index 0c4f3264ba1..461a19ae302 100644 --- a/csrc/scheduler/ampere_multi_matmul.h +++ b/csrc/scheduler/ampere_multi_matmul.h @@ -8,9 +8,7 @@ #pragma once #include -#include -#include -#include +#include namespace nvfuser { @@ -66,13 +64,11 @@ namespace nvfuser { // Each of the named tensors above is scheduled differently. We schedule them // by building AbstractTensors for each tensor category; these are held in // AmpereMultipleMatmulScheduler::schedules_. -// TODO: Inheret from SchedulerEntry -class AmpereMultipleMatmulScheduler { +// TODO: Inherit from SchedulerEntry +class AmpereMultipleMatmulScheduler : public MultipleMatmulScheduler { public: AmpereMultipleMatmulScheduler(Fusion* fusion, const MatmulParams* params) - : fusion_(fusion), - params_(params), - id_model_(fusion, /*build_graphs=*/false) { + : MultipleMatmulScheduler(fusion, params) { const auto device_prop = at::cuda::getCurrentDeviceProperties(); const int cc = device_prop->major * 10 + device_prop->minor; NVF_ERROR( @@ -80,23 +76,11 @@ class AmpereMultipleMatmulScheduler { "This matmul scheduler is restricted to Ampere and Turing."); } - void run(); + void run() final; private: void cacheInputsAndOutputs(); - void findPatterns(); - - void countDims(); - - void translatePatterns(); - - // Get tensor roles and id roles - // When there are multiple matmul patterns, we can have conflicting roles. - // For now we throw an error if this is the case. - // TODO: This should be checked in canScheduleCompileTime - void findRoles(); - // Including current tensor naming convention for reference, // this is very temporary and will change over time and // in fact the whole body of this function will @@ -148,11 +132,6 @@ class AmpereMultipleMatmulScheduler { const std::vector& tv_smems, std::vector& tv_rs); - //! Rebuilds IdModel, then updates all ValGroups in abstract tensors to refer - //! to the new IdModel. This is necessary whenever we perform an operation - //! that creates a new TensorView, such as caching or rFactor - void updateIdModel(); - //! Swizzle the M and N outer dimensions after makeTile has been called. //! This updates outer_dim_roles if we introduce a new dimension, which can //! happen if tv is missing a merged axis, in which case we skip merging after @@ -216,26 +195,12 @@ class AmpereMultipleMatmulScheduler { void setUpCircularBuffering(); private: - Fusion* fusion_; - const MatmulParams* params_; - IdModel id_model_; - // Permissive graph of id_model_, which we modify at times using e.g. - // AbstractTensor.split or by mapping vals in cacheAfter and rFactor - ValGraph* graph_ = nullptr; - std::vector patterns_; - mma_utils::DimRolesMap id_roles_; - mma_utils::TensorRolesMap tensor_roles_; - mma_utils::MatmulOperandInnerDims inner_dims_; - - int64_t num_splitk_dims_ = 0, num_device_dims_ = 0, num_local_batch_dims_ = 0, - num_device_and_batch_dims_ = 0; - std::vector> cached_outputs_; std::vector canonical_dim_ordering_; - std::vector as_, bs_, acw_smems_, bcw_smems_, acrs_, bcrs_, abs_, - bbs_, mma_results_, splitk_sums_, smem_epilogues_; + std::vector acw_smems_, bcw_smems_, acrs_, bcrs_, abs_, bbs_, + splitk_sums_, smem_epilogues_; }; } // namespace nvfuser diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index 0b44425b10e..6f02148ce93 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include @@ -19,6 +18,7 @@ #include #include #include +#include #include #include @@ -31,32 +31,6 @@ namespace nvfuser { namespace { -// Returns true if given number is power of 2 -constexpr bool isPowOf2(int64_t x) { - return x > 1 && (x & (x - 1)) == 0; -} - -inline IterDomain* representativeId(const AbstractId& abs_id) { - if (abs_id.is()) { - return abs_id.as(); - } - NVF_ERROR(abs_id.is()); - return representativeId(abs_id.as().group); -} - -// Utility to check concrete static size -inline void checkConcreteStaticDim(const AbstractId& abs_id) { - IterDomain* id = representativeId(abs_id); - NVF_ERROR( - !id->isBroadcast() && !id->isReduction(), - "no support for reduction or broadcast domains, but got ", - id->toString()); - NVF_ERROR( - id->extent()->isConstInt(), - "swizzled dimension's extend must be known during scheduling, got ", - id->toString()); -} - //! Automatically generates the shared memory swizzled data layout //! for matmul mainloop and epilogue. //! The shared mem data layout is always 2D currently, and this utility @@ -65,6 +39,8 @@ inline void checkConcreteStaticDim(const AbstractId& abs_id) { //! Returns the domain with swizzle. For the case of legacy swizzle, this //! domain must be set as loop domain. For the case of new swizzle, this domain //! must be set as allocation domain. +//! +//! TODO: Refactor this for TMA loads template AbstractTensor swizzleSharedMemory(TensorView* shared_mem_tv) { NVF_ERROR(shared_mem_tv->getMemoryType() == MemoryType::Shared); @@ -76,8 +52,8 @@ AbstractTensor swizzleSharedMemory(TensorView* shared_mem_tv) { (int64_t)swizzle_domain.size() >= 2, "At least 2D input (excluding consecutive reduction domains starting from the innermost dim) needed for swizzling, but get ", shared_mem_tv->toString()); - checkConcreteStaticDim(swizzle_domain[-2]); - checkConcreteStaticDim(swizzle_domain[-1]); + mma_utils::checkConcreteStaticDim(swizzle_domain[-2]); + mma_utils::checkConcreteStaticDim(swizzle_domain[-1]); // Extract the constant sizes of the swizzled tile const int64_t tile_size_x = @@ -522,98 +498,6 @@ void HopperMultipleMatmulScheduler::cacheInputsAndOutputs() { scheduler_utils::cacheAndForkOutputs(fusion_, /*unroll=*/true); } -void HopperMultipleMatmulScheduler::findPatterns() { - patterns_ = mma_utils::findMatmulPatterns(fusion_); - NVF_ERROR(!patterns_.empty(), "No matmul patterns were found"); -} - -void HopperMultipleMatmulScheduler::countDims() { - NVF_ERROR(!patterns_.empty()); - TensorView* mma_result = patterns_.front().output; - num_device_dims_ = numDeviceDims(mma_result); - for (const auto& it : id_roles_) { - if (it.second == MatmulDimRole::Batch && - // Skip device dims - !std::any_of(it.first->begin(), it.first->end(), [](Val* v) { - return v->as()->isDeviceDim(); - })) { - // All batch dims will be merged into one, if any exist - num_local_batch_dims_ = 1; - } - } - num_splitk_dims_ = params_->splitk_factor > 1 ? 1 : 0; - // Subtract 6 for the [Mo, No, Ko, Mi, Ni, Ki] - num_device_and_batch_dims_ = num_device_dims_ + num_local_batch_dims_; -} - -void HopperMultipleMatmulScheduler::translatePatterns() { - mma_results_.reserve(patterns_.size()); - for (mma_utils::MatmulPattern& pattern : patterns_) { - MmaOp* mma = pattern.translateToMmaOp(); - mma_results_.push_back(mma->out()->as()); - } - - // Build IdModel graphs now since translateToMmaOp creates new TVs. Before - // this point the graphs are not yet built. - updateIdModel(); -} - -// Get tensor roles and id roles -// When there are multiple matmul patterns, we can have conflicting roles. -// For now we throw an error if this is the case. -// TODO: This should be checked in canScheduleCompileTime -void HopperMultipleMatmulScheduler::findRoles() { - const auto roles_opt = mma_utils::allPatternRoles(id_model_, patterns_); - NVF_ERROR( - roles_opt.has_value(), - "Incompatible roles found between matmul patterns"); - std::tie(id_roles_, tensor_roles_) = roles_opt.value(); - - mma_utils::MatmulOperandInnerDimsOpt inner_dims_opt = - mma_utils::getOperandInnerDims(id_model_, id_roles_, tensor_roles_); - NVF_ERROR(inner_dims_opt.isValid(), inner_dims_opt.getErrorMsg()); - inner_dims_ = inner_dims_opt.getData(); - - as_ = tensor_roles_.at(MatmulTensorRole::OPERAND_A); - bs_ = tensor_roles_.at(MatmulTensorRole::OPERAND_B); - - countDims(); -} - -// Including current tensor naming convention for reference, -// this is very temporary and will change over time and -// in fact the whole body of this function will -// eventually be a set of utility functions for different -// sections of matmul(fusion) kernels, with -// each having its own build out to do. -// -// Current naming convention is based on the following formula: -// -// d = alpha * (a x b) + beta * c -// -// and is defined in the following way: -// -// operands assumed in global memory : a, b, c -// -// registers staging global load : ar, br (short for a/b read) -// -// shared mem cache of operands : acw_smem, bcw_smem (short for a/b -// cache_write smem) -// -// registers at shared memory load output : acr, bcr (short for a/b cache -// read) -// -// register tensor input to the actual mma op: ab, bb (short for a/b -// broadcasted) -// -// accumulator register: mma_result -// - mma_result is MmaOp output if there is epilogue -// - mma_result is dc (short for d cache) if there is no epilogue -// -// result in global memory: d - -// Currently the support is for a, b, c and d as fusion inputs/outputs -// aka. no prolog fusion yet. void HopperMultipleMatmulScheduler::defineOperandCaches() { cacheOperandsToSmem(as_, acw_smems_, params_->supported_vec_size.a); addSetsForCacheReads(acw_smems_, acrs_); @@ -669,12 +553,6 @@ void HopperMultipleMatmulScheduler::cacheOperandsToSmem( } } -// We add two LoadStore operators to the inputs of our fusions. The first -// one is for a read from global memory and the second one (below) is for a -// cache read. As an optimizaton, we avoid adding an operator if there's an -// existing LoadStoreOp present. Please note that for the second LoadStore -// we don't propagate the allocation domain, since the scheduler sets the -// allocation domain in the registers. void HopperMultipleMatmulScheduler::addSetsForCacheReads( const std::vector& tv_smems, std::vector& tv_rs) { @@ -702,38 +580,6 @@ void HopperMultipleMatmulScheduler::addSetsForCacheReads( } } -//! Rebuilds IdModel, then updates all ValGroups in abstract tensors to refer -//! to the new IdModel. This is necessary whenever we perform an operation -//! that creates a new TensorView, such as caching or rFactor -void HopperMultipleMatmulScheduler::updateIdModel() { - // Build new IdModel - IdModel new_id_model(fusion_, /*build_graphs=*/false); - new_id_model.buildPermissiveGraph(); - - // Get new permissive graph - ValGraph& new_graph = new_id_model.idGraph(IdMappingMode::PERMISSIVE); - - if (!id_roles_.empty()) { - // Update id_roles_ to have keys corresponding to ValGroups in the new - // IdModel - std::unordered_map new_id_roles; - for (auto& [k, v] : id_roles_) { - const ValGroup& new_group = new_graph.toGroup(k->front()); - new_id_roles.emplace(new_group, v); - } - id_roles_ = new_id_roles; - } - - graph_ = &new_id_model.idGraph(IdMappingMode::PERMISSIVE); - - // Set id_model_ after we are done using the old one - id_model_ = std::move(new_id_model); -} - -//! Swizzle the M and N outer dimensions after makeTile has been called. -//! This updates outer_dim_roles if we introduce a new dimension, which can -//! happen if tv is missing a merged axis, in which case we skip merging after -//! the split. This is analogous to forwarding during transform propagation. void HopperMultipleMatmulScheduler::swizzleBlockTiles( TensorView* tv, std::vector& outer_dim_roles) { @@ -798,8 +644,6 @@ void HopperMultipleMatmulScheduler::swizzleBlockTiles( } } -//! This calls orig->cacheAfter() and also updates the permissive graph to -//! reflect the new IterDomain mappings TensorView* HopperMultipleMatmulScheduler::cacheAfter( TensorView* orig, LoadStoreOpType op_type, @@ -835,16 +679,6 @@ TensorView* HopperMultipleMatmulScheduler::cacheAfter( return c; } -//! Do block tiling for a collection of TensorViews. The tensors should be -//! unscheduled before this method is called. -//! 1) Axes will be ordered according to canonicalDimOrdering, and then axes -//! with the same role will be merged. -//! 2) After that, we perform splits according to -//! params_->tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki]. -//! 3) Depending on the value of params_->grid_swizzle_factor, if the TV has -//! both M and N dimensions, we perform a 2D swizzle of the outer dimensions -//! Mo and No. -//! 4) Finally, we do a split-K split if the splitk_factor is not 1 std::vector> HopperMultipleMatmulScheduler:: blockTileTensors(const std::vector& tvs) { if (canonical_dim_ordering_.empty()) { @@ -920,10 +754,6 @@ std::vector> HopperMultipleMatmulScheduler:: return all_merged_roles; } -//! Schedule the loads of all operands from global memory to shared memory. -//! Starting from the basic tiled schedule, we swizzle the operand memory. -//! Note that the cache op and LoadStoreOpType are already set during -//! defineOperandCaches(). void HopperMultipleMatmulScheduler::scheduleOperandSmemStores() { auto scheduleBranch = [&](const std::vector& gmem_operands, const std::vector& smem_operands, @@ -989,8 +819,6 @@ void HopperMultipleMatmulScheduler::scheduleMmaOperands( } } -// MmaOperand contains only A and B. If tvs are outputs (i.e. not operands), -// then operand_type should be std::nullopt. void HopperMultipleMatmulScheduler::scheduleMmaResults() { auto all_merged_roles = blockTileTensors(mma_results_); for (size_t i : c10::irange(mma_results_.size())) { @@ -1238,8 +1066,8 @@ void HopperMultipleMatmulScheduler::scheduleOutputTensor(TensorView* c) { const MatMulTileOptions& gemm_tile = params_->tile_sizes; const int64_t vectorization_factor = params_->supported_vec_size.epilogue; // input tensor is in the form of [Mo,No,cta_tile_m,cta_tile_n] - checkConcreteStaticDim(c->axis(-2)); - checkConcreteStaticDim(c->axis(-1)); + mma_utils::checkConcreteStaticDim(c->axis(-2)); + mma_utils::checkConcreteStaticDim(c->axis(-1)); const int64_t tile_size_m = c->axis(-2)->extent()->evaluate().as(); const int64_t tile_size_n = c->axis(-1)->extent()->evaluate().as(); NVF_ERROR( @@ -1463,8 +1291,6 @@ void HopperMultipleMatmulScheduler::setUpInlining() { } } -// NOTE: this should be called after acw_smem, acr, ..., ab, and mma_result -// transforms have been applied and inlining void HopperMultipleMatmulScheduler::setUpCircularBuffering() { // Propagate mma output swizzle and parallelization down the DAG if (params_->circular_buffer_options.circular_buffer_smem_write) { diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h index d803766c12e..08e8a110fa6 100644 --- a/csrc/scheduler/hopper_multi_matmul.h +++ b/csrc/scheduler/hopper_multi_matmul.h @@ -8,9 +8,7 @@ #pragma once #include -#include -#include -#include +#include namespace nvfuser { @@ -66,36 +64,22 @@ namespace nvfuser { // Each of the named tensors above is scheduled differently. We schedule them // by building AbstractTensors for each tensor category; these are held in // HopperMultipleMatmulScheduler::schedules_. -// TODO: Inheret from SchedulerEntry -class HopperMultipleMatmulScheduler { +// TODO: Inherit from SchedulerEntry +class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { public: HopperMultipleMatmulScheduler(Fusion* fusion, const MatmulParams* params) - : fusion_(fusion), - params_(params), - id_model_(fusion, /*build_graphs=*/false) { + : MultipleMatmulScheduler(fusion, params) { const auto device_prop = at::cuda::getCurrentDeviceProperties(); const int cc = device_prop->major * 10 + device_prop->minor; NVF_ERROR( cc >= 90 && cc < 100, "This matmul scheduler is restricted to Hopper."); } - void run(); + void run() final; private: void cacheInputsAndOutputs(); - void findPatterns(); - - void countDims(); - - void translatePatterns(); - - // Get tensor roles and id roles - // When there are multiple matmul patterns, we can have conflicting roles. - // For now we throw an error if this is the case. - // TODO: This should be checked in canScheduleCompileTime - void findRoles(); - // Including current tensor naming convention for reference, // this is very temporary and will change over time and // in fact the whole body of this function will @@ -147,11 +131,6 @@ class HopperMultipleMatmulScheduler { const std::vector& tv_smems, std::vector& tv_rs); - //! Rebuilds IdModel, then updates all ValGroups in abstract tensors to refer - //! to the new IdModel. This is necessary whenever we perform an operation - //! that creates a new TensorView, such as caching or rFactor - void updateIdModel(); - //! Swizzle the M and N outer dimensions after makeTile has been called. //! This updates outer_dim_roles if we introduce a new dimension, which can //! happen if tv is missing a merged axis, in which case we skip merging after @@ -215,26 +194,12 @@ class HopperMultipleMatmulScheduler { void setUpCircularBuffering(); private: - Fusion* fusion_; - const MatmulParams* params_; - IdModel id_model_; - // Permissive graph of id_model_, which we modify at times using e.g. - // AbstractTensor.split or by mapping vals in cacheAfter and rFactor - ValGraph* graph_ = nullptr; - std::vector patterns_; - mma_utils::DimRolesMap id_roles_; - mma_utils::TensorRolesMap tensor_roles_; - mma_utils::MatmulOperandInnerDims inner_dims_; - - int64_t num_splitk_dims_ = 0, num_device_dims_ = 0, num_local_batch_dims_ = 0, - num_device_and_batch_dims_ = 0; - std::vector> cached_outputs_; std::vector canonical_dim_ordering_; - std::vector as_, bs_, acw_smems_, bcw_smems_, acrs_, bcrs_, abs_, - bbs_, mma_results_, splitk_sums_, smem_epilogues_; + std::vector acw_smems_, bcw_smems_, acrs_, bcrs_, abs_, bbs_, + splitk_sums_, smem_epilogues_; }; } // namespace nvfuser diff --git a/csrc/scheduler/mma_utils.h b/csrc/scheduler/mma_utils.h index 0f9550189cb..25486c21824 100644 --- a/csrc/scheduler/mma_utils.h +++ b/csrc/scheduler/mma_utils.h @@ -462,6 +462,19 @@ std::optional> allPatternRoles( IdModel& id_model, const std::vector& patterns); +// Utility to check concrete static size +inline void checkConcreteStaticDim(const AbstractId& abs_id) { + IterDomain* id = representativeId(abs_id); + NVF_ERROR( + !id->isBroadcast() && !id->isReduction(), + "no support for reduction or broadcast domains, but got ", + id->toString()); + NVF_ERROR( + id->extent()->isConstInt(), + "swizzled dimension's extend must be known during scheduling, got ", + id->toString()); +} + } // namespace mma_utils std::string toString(const mma_utils::AbstractMatmulTensor& abten); diff --git a/csrc/scheduler/multi_matmul.cpp b/csrc/scheduler/multi_matmul.cpp index 0bd8a3cbd10..973fcf4dfb3 100644 --- a/csrc/scheduler/multi_matmul.cpp +++ b/csrc/scheduler/multi_matmul.cpp @@ -7,11 +7,98 @@ // clang-format on #include +#include #include #include namespace nvfuser { +void MultipleMatmulScheduler::findPatterns() { + patterns_ = mma_utils::findMatmulPatterns(fusion_); + NVF_ERROR(!patterns_.empty(), "No matmul patterns were found"); +} + +void MultipleMatmulScheduler::translatePatterns() { + mma_results_.reserve(patterns_.size()); + for (mma_utils::MatmulPattern& pattern : patterns_) { + MmaOp* mma = pattern.translateToMmaOp(); + mma_results_.push_back(mma->out()->as()); + } + + // Build IdModel graphs now since translateToMmaOp creates new TVs. Before + // this point the graphs are not yet built. + updateIdModel(); +} + +// Get tensor roles and id roles +// When there are multiple matmul patterns, we can have conflicting roles. +// For now we throw an error if this is the case. +// TODO: This should be checked in canScheduleCompileTime +void MultipleMatmulScheduler::findRoles() { + const auto roles_opt = mma_utils::allPatternRoles(id_model_, patterns_); + NVF_ERROR( + roles_opt.has_value(), + "Incompatible roles found between matmul patterns"); + std::tie(id_roles_, tensor_roles_) = roles_opt.value(); + + mma_utils::MatmulOperandInnerDimsOpt inner_dims_opt = + mma_utils::getOperandInnerDims(id_model_, id_roles_, tensor_roles_); + NVF_ERROR(inner_dims_opt.isValid(), inner_dims_opt.getErrorMsg()); + inner_dims_ = inner_dims_opt.getData(); + + as_ = tensor_roles_.at(MatmulTensorRole::OPERAND_A); + bs_ = tensor_roles_.at(MatmulTensorRole::OPERAND_B); + + countDims(); +} + +void MultipleMatmulScheduler::countDims() { + NVF_ERROR(!patterns_.empty()); + TensorView* mma_result = patterns_.front().output; + num_device_dims_ = numDeviceDims(mma_result); + for (const auto& it : id_roles_) { + if (it.second == MatmulDimRole::Batch && + // Skip device dims + !std::any_of(it.first->begin(), it.first->end(), [](Val* v) { + return v->as()->isDeviceDim(); + })) { + // All batch dims will be merged into one, if any exist + num_local_batch_dims_ = 1; + } + } + num_splitk_dims_ = params_->splitk_factor > 1 ? 1 : 0; + // Subtract 6 for the [Mo, No, Ko, Mi, Ni, Ki] + num_device_and_batch_dims_ = num_device_dims_ + num_local_batch_dims_; +} + +//! Rebuilds IdModel, then updates all ValGroups in abstract tensors to refer +//! to the new IdModel. This is necessary whenever we perform an operation +//! that creates a new TensorView, such as caching or rFactor +void MultipleMatmulScheduler::updateIdModel() { + // Build new IdModel + IdModel new_id_model(fusion_, /*build_graphs=*/false); + new_id_model.buildPermissiveGraph(); + + // Get new permissive graph + ValGraph& new_graph = new_id_model.idGraph(IdMappingMode::PERMISSIVE); + + if (!id_roles_.empty()) { + // Update id_roles_ to have keys corresponding to ValGroups in the new + // IdModel + std::unordered_map new_id_roles; + for (auto& [k, v] : id_roles_) { + const ValGroup& new_group = new_graph.toGroup(k->front()); + new_id_roles.emplace(new_group, v); + } + id_roles_ = new_id_roles; + } + + graph_ = &new_id_model.idGraph(IdMappingMode::PERMISSIVE); + + // Set id_model_ after we are done using the old one + id_model_ = std::move(new_id_model); +} + void scheduleMultipleMatmuls(Fusion* fusion, const MatmulParams* params) { FusionGuard fg(fusion); diff --git a/csrc/scheduler/multi_matmul.h b/csrc/scheduler/multi_matmul.h index 2ed7bc15cb7..af54476e9d6 100644 --- a/csrc/scheduler/multi_matmul.h +++ b/csrc/scheduler/multi_matmul.h @@ -9,10 +9,64 @@ #include #include +#include +#include +#include #include namespace nvfuser { +// Base class for AmpereMultipleMatmulScheduler and +// HopperMultipleMatmulScheduler +class MultipleMatmulScheduler { + public: + MultipleMatmulScheduler(Fusion* fusion, const MatmulParams* params) + : fusion_(fusion), + params_(params), + id_model_(fusion, /*build_graphs=*/false) {} + virtual ~MultipleMatmulScheduler() = default; + + virtual void run() = 0; + + protected: + void findPatterns(); + + void translatePatterns(); + + // Get tensor roles and id roles + // When there are multiple matmul patterns, we can have conflicting roles. + // For now we throw an error if this is the case. + // TODO: This should be checked in canScheduleCompileTime + void findRoles(); + + void countDims(); + + //! Rebuilds IdModel, then updates all ValGroups in abstract tensors to refer + //! to the new IdModel. This is necessary whenever we perform an operation + //! that creates a new TensorView, such as caching or rFactor + void updateIdModel(); + + protected: + Fusion* fusion_; + const MatmulParams* params_; + IdModel id_model_; + + // Permissive graph of id_model_, which we modify at times using e.g. + // AbstractTensor.split or by mapping vals in cacheAfter and rFactor + ValGraph* graph_ = nullptr; + std::vector patterns_; + mma_utils::DimRolesMap id_roles_; + mma_utils::TensorRolesMap tensor_roles_; + mma_utils::MatmulOperandInnerDims inner_dims_; + + int64_t num_splitk_dims_ = 0; + int64_t num_device_dims_ = 0; + int64_t num_local_batch_dims_ = 0; + int64_t num_device_and_batch_dims_ = 0; + + std::vector as_, bs_, mma_results_; +}; + NVF_API void scheduleMultipleMatmuls( Fusion* fusion, const MatmulParams* mparams); diff --git a/csrc/scheduler/tools/abstract_tensor.h b/csrc/scheduler/tools/abstract_tensor.h index 087d8a1b8cd..834ed652785 100644 --- a/csrc/scheduler/tools/abstract_tensor.h +++ b/csrc/scheduler/tools/abstract_tensor.h @@ -34,6 +34,14 @@ using AbstractId = dynamic_type::DynamicType< IterDomain*, ValGroupAndItsGraph>; +inline IterDomain* representativeId(const AbstractId& abs_id) { + if (abs_id.is()) { + return abs_id.as(); + } + NVF_ERROR(abs_id.is()); + return representativeId(abs_id.as().group); +} + namespace { struct DispatchSplit { @@ -338,7 +346,7 @@ struct DispatchParallelize { // AbstractTensor is similar to TensorView, it has multiple dimensions, where // each dimension is represented by an Abstract IterDomain. The interface of -// AbstractTensor is also similar to that of TesorViews, that is, it has merge, +// AbstractTensor is also similar to that of TensorViews, that is, it has merge, // split, etc. However, it only has a single "domain", instead of having // multiple domains like "logical domain", "loop domain", etc. // diff --git a/csrc/utils.h b/csrc/utils.h index 56a07cced24..f98d2e357a2 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -589,6 +589,11 @@ T pow(T a, T b) { } } +// Returns true if given number is power of 2 +constexpr bool isPowOf2(int64_t x) { + return x > 1 && (x & (x - 1)) == 0; +} + template using MaybeUniqueOwningPtr = dynamic_type:: DynamicType>; From e33316d9480508b49db788a7472f4df52e53af92 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Tue, 29 Oct 2024 10:14:20 -0700 Subject: [PATCH 09/17] Only the TMA thread arrive (#3294) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously: ```C++ if (elect-sync) { arriveExpectTx TMA } else { arrive } ``` Now: ```C++ if (elect-sync) { arriveExpectTx TMA } ``` I am very surprised that this fixes all the latencies introduced in the elect-sync fix https://github.com/NVIDIA/Fuser/pull/3295, and even better! But in general, we should sync as less as possible, and avoid unnecessary wait, so I think this PR makes sense. Perf: ```C++ Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name -------- --------------- --------- -------- -------- -------- -------- ----------- ---------------------------------------------------------------------------------------------------- 39.0 172735 1 172735.0 172735.0 172735 172735 0.0 ::nvfuser_none_f0_c0_r0_g0(::Tensor<::__half, (int)3, (int)3>, … 20.0 88768 1 88768.0 88768.0 88768 88768 0.0 nvjet_hsh_256x128_64x4_1x2_h_bz_coopA_NTT ``` Perf nvFuser/cuBLAS: `51.4%`. --- csrc/device_lower/pass/allocation.cpp | 23 ++--------------- csrc/device_lower/pass/circular_buffer.cpp | 30 +++------------------- 2 files changed, 6 insertions(+), 47 deletions(-) diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index b2a9ec7b3f0..bd845a27066 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -67,25 +67,6 @@ Expr* initializeMbarrier( kir::TensorIndex* stage_mbarrier = IrBuilder::create(all_mbarriers, loop->index()); - // Get all threads in CTA - Val* bdimx = - GpuLower::current()->parallelDimensionMap().get(ParallelType::TIDx); - Val* bdimy = - GpuLower::current()->parallelDimensionMap().get(ParallelType::TIDy); - Val* bdimz = - GpuLower::current()->parallelDimensionMap().get(ParallelType::TIDz); - Val* all_threads_in_cta = SimplifyingIrBuilder::mulExpr( - bdimx, SimplifyingIrBuilder::mulExpr(bdimy, bdimz)); - if (all_threads_in_cta != nullptr) { - all_threads_in_cta = SimplifyingIrBuilder::maybeCastExpr( - DataType::UInt32, all_threads_in_cta); - } else { - // If all_threads_in_cta is nullptr, then this kernel is not parallelized - // on any of the thread dimensions. - all_threads_in_cta = - GpuLower::current()->kernel()->oneVal(DataType::UInt32); - } - auto circular_buffered_tvs = GpuLower::current()->circularBufferInfo().getCircularBufferTvs( circular_buffer_loop); @@ -95,8 +76,8 @@ Expr* initializeMbarrier( [](const TensorView* tv) { return ir_utils::isCpAsyncBulkLoad(tv->definition()); }); - Val* n = IrBuilder::create(num_of_tvs_loaded_by_tma, DataType::UInt32); - Val* num_of_arrives = SimplifyingIrBuilder::mulExpr(n, all_threads_in_cta); + Val* num_of_arrives = + IrBuilder::create(num_of_tvs_loaded_by_tma, DataType::UInt32); // Initialize mbarrier for each circular buffer stage. Use the thread // count from the MBarrierInit created in the allocation pass. The wait diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index 415d10a02ee..40bb8fc3294 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -261,27 +261,20 @@ class CircularBufferLoopCloner : public kir::IrVisitor { // Detailed Pseudo-Code: // Pre-Prologue loop: // -// - number_of_arrival_threads is the number of threads to call -// mbarrier::arrive or mbarrier::arriveExpectTx and to wait at -// mbarrier:wait. -// // __shared__ __mbarrier_t barriers[num_stages]; // if (warp_id == 0 && electSync()()) { // for (int64_t loop_index : irange(stages)) { -// int64_t number_of_arrive_threads = blockDim.x * blockDim.y * blockDim.z; -// mbarrier_init(mbarrier[loop_index], number_of_arrival_threads); +// mbarrier_init(mbarrier[loop_index], number_of_tma_load_exprs); // } // } // // Prologue loop: // for (int64_t loop_index : irange(prefetch_distance)) { -// if (warp_id == 0 && electSync()()) { +// if (warp_id == 0 && electSync()) { // mbarrier::arriveExpectTx(mbarrier[loop_index], expected_bytes); // for (...) { // cpAsyncBulk(mbarriers[loop_index], ...); // } -// } else { -// mbarrier::arrive(mbarrier[loop_index]); // } // } // @@ -294,8 +287,6 @@ class CircularBufferLoopCloner : public kir::IrVisitor { // for (...) { // cpAsyncBulk(mbarrier[load_stage], ...); // } -// } else { -// mbarrier::arrive(mbarrier[load_stage]); // } // mbarrier::waitParity((loop_index / stage_depth) % 2); // @@ -363,8 +354,8 @@ class CloneTmaCircularBufferLoopAndInsertSync // generate the nested for-loops for the serial IterDomains, but do not add // them to the cloned circular buffer loop immediately. Once the cloned // circular buffer loop is the only loop in the stack, add the arriveExpectTx - // and arrive expressions, then the nested for-loop structure calling the TMA - // load operations, and finally the mbarrier_wait. + // expressions, then the nested for-loop structure calling the TMA load + // operations, and finally the mbarrier_wait. void processForLoop(ForLoop* cloned_loop) final { // Skip if there is not an active for-loop structure if (for_loop_stack_.empty()) { @@ -412,8 +403,6 @@ class CloneTmaCircularBufferLoopAndInsertSync // for (...) { // cpAsyncBulk; // } - // } else { - // arrive; // } NVF_ERROR(for_loop_stack_.front() == cloned_top_level_loop_); addTmaLoadBlock(cloned_loop); @@ -602,8 +591,6 @@ class CloneTmaCircularBufferLoopAndInsertSync // for (...) { // cpAsyncBulk(mbarriers[loop_index], ...); // } - // } else { - // mbarrier::arrive(mbarrier[loop_index]); // } // } void handlePrologueLoop(Expr* expr) { @@ -656,8 +643,6 @@ class CloneTmaCircularBufferLoopAndInsertSync // for (...) { // cpAsyncBulk(mbarrier[load_stage], ...); // } - // } else { - // mbarrier::arrive(mbarrier[load_stage]); // } // mbarrier::wait((loop_index / stage_depth) % 2); // @@ -725,8 +710,6 @@ class CloneTmaCircularBufferLoopAndInsertSync // for (...) { // cpAsyncBulk(mbarrier[next_stage], ...); // } - // } else { - // mbarrier::arrive(mbarrier[next_stage]); // } // // The expr input argument can be a single cpAsyncBulk expression or a nested @@ -745,11 +728,6 @@ class CloneTmaCircularBufferLoopAndInsertSync // launches the TMA load. if_expr->thenBody().push_back(mbarrier_arrive_tx_); if_expr->thenBody().push_back(expr); - - // The other threads issue arriveExpectTx without any expected transactions. - kir::MBarrierArrive* thread_arrive = IrBuilder::create( - /*state=*/nullptr, mbarrier_arrive_tx_->mbarrier()); - if_expr->elseBody().push_back(thread_arrive); for_loop_stack_.back()->body().push_back(if_expr); mbarrier_arrive_tx_ = nullptr; From e4c98249f4606201eddda7a7106d1cd372ea9aa0 Mon Sep 17 00:00:00 2001 From: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:37:41 -0400 Subject: [PATCH 10/17] fix padded bdimx to use warp reduction in inner reduction scheduler (#3288) Simple fix to padded bdimx in inner reduction scheduler. **Performance changes:** **(1) H100** 5 lines corresponds to batch size of 16, 512, 2048, 8192, 16384 ![image](https://github.com/user-attachments/assets/225e2ed1-0c99-402d-acce-894f728e083b) **(2) A100** ![image](https://github.com/user-attachments/assets/0658a933-dc5e-4cf3-af48-4b95289b0977) --- csrc/scheduler/reduction.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index 87f9d2bffad..4032c9d159e 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -388,9 +388,6 @@ std::unique_ptr innerReductionHeuristic( bool pad_bdimx = bdimx > 16 && bdimx * bdimy < (int64_t)at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock; - // If barely just covering reduction dim, don't pad to the next warp - pad_bdimx = pad_bdimx && - bdimx * inner_reduction_unroll_factor != inner_most_dimension_numel; rparams->pad_inner_reduction_to_warp = pad_bdimx; if (rparams->pad_inner_reduction_to_warp) { From 7220207cae9ea9ab79814d4b5d6d8532fa629875 Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Tue, 29 Oct 2024 21:36:19 -0700 Subject: [PATCH 11/17] Remove deprecated `clear_cuda_cache` (#3306) This benchmark was added recently and did not have the changes added by PR #3252. The benchmark will fail on the CI due to missing import function --- benchmarks/python/test_adaptive_layernorm_host.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/benchmarks/python/test_adaptive_layernorm_host.py b/benchmarks/python/test_adaptive_layernorm_host.py index 7e3c67b6d8b..7e60da7659d 100644 --- a/benchmarks/python/test_adaptive_layernorm_host.py +++ b/benchmarks/python/test_adaptive_layernorm_host.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import clear_cuda_cache from .core import run_benchmark import torch @@ -73,8 +72,6 @@ def test_adaptive_layernorm_fwd_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - B = 1 T = 30 * 1024 D = 1024 From c14d4181ae499c90b5728c06dfac248a5afab8b0 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Tue, 29 Oct 2024 22:07:09 -0700 Subject: [PATCH 12/17] Profile configurations for InnerOuterPersistent scheduler in python frontend (#3118) # Summary This PR explores auto-tuning a `LayerNormBackward` fusion using the `InnerOuterPersistent` scheduler in the python-frontend. - Create `autotune_persistent.py` to test several parameter configurations then apply `DecisionTreeRegressor` - The selected performance metric is `effective_bandwidth_gbs`. The empirical scheduler selects the configuration that has the highest predicted `effective_bandwidth_gbs`. # Key differences from approach for `Pointwise` scheduler - `vectorize_factor`, `thread_per_block_min`, and `thread_per_block_max` are specified before running `computeHeuristics`. These settings are akin to hyper-parameters used to constrain the generated scheduler parameters. - Create `SchedulerHyperParameters` as an entry in `HeuristicDataCache` to specify these constraints when generating scheduler parameters. # Details 1. Create `struct SchedulerHyperParameters` in `csrc/scheduler/utils.h` 2. Create `HeuristicDataCacheEntry` in `csrc/scheduler/compile_time_info.h` 3. Modify `computeHeuristics` to use hyper-parameter constraints. 4. Expose `SchedulerHyperParameters` in python frontend. 5. Allow user schedulers to define a `HeuristicDataCache` during scheduling. * `ScheduleHyperParameters` contains parameters for `vectorize_factor`, `unroll_factor`, `threads_per_block_min`, and `threads_per_block_max`. --- csrc/python_frontend/fusion_cache.cpp | 3 +- csrc/python_frontend/fusion_cache.h | 4 + csrc/python_frontend/fusion_definition.cpp | 4 + csrc/python_frontend/python_bindings.cpp | 63 ++- csrc/scheduler/compile_time_info.h | 12 +- csrc/scheduler/normalization_inner_outer.cpp | 75 +++- csrc/scheduler/registry.cpp | 2 + csrc/scheduler/utils.h | 28 ++ .../python_scheduling/autotune_persistent.py | 417 ++++++++++++++++++ 9 files changed, 590 insertions(+), 18 deletions(-) create mode 100644 doc/dev/python_scheduling/autotune_persistent.py diff --git a/csrc/python_frontend/fusion_cache.cpp b/csrc/python_frontend/fusion_cache.cpp index 53dc43bdbe8..83ce851dbab 100644 --- a/csrc/python_frontend/fusion_cache.cpp +++ b/csrc/python_frontend/fusion_cache.cpp @@ -227,7 +227,8 @@ HeuristicParams* UserSchedule::computeHeuristics(SchedulerType scheduler_type) { NVF_CHECK( heuristic_params == nullptr, "Heuristic Scheduler is already defined for this UserSchedule"); - heuristic_params = scheduler->computeHeuristics(fusion(), runtime_info_ref); + heuristic_params = scheduler->computeHeuristics( + fusion(), runtime_info_ref, data_cache.get()); return heuristic_params.get(); } diff --git a/csrc/python_frontend/fusion_cache.h b/csrc/python_frontend/fusion_cache.h index b4283b7bdaf..190671b2b82 100644 --- a/csrc/python_frontend/fusion_cache.h +++ b/csrc/python_frontend/fusion_cache.h @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -33,6 +34,9 @@ struct UserSchedule { //! The parameters for scheduler heuristic. std::unique_ptr heuristic_params; + //! The compile-time data cache. + std::unique_ptr data_cache; + //! Concretized, Scheduled Fusion IR std::unique_ptr scheduled_fusion; diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index b512d9d761b..09648a0bf36 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -239,6 +240,9 @@ void FusionDefinition::setupSchedule( user_sched_ = fusionCache()->createUserSchedule( scheds, inputs, device, overwrite_existing_schedule); + // Create scheduler data cache + user_sched_->data_cache = std::make_unique(); + // Building a new Fusion container for scheduling with definition such that // the definition's tensor data members refer to the corresponding IR objects // needed for scheduling. A simple copy of the container would mean the data diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index d7b9e8d1c34..b229107c45b 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -23,9 +23,11 @@ #include #include #include +#include #include #include #include +#include #include #include #include @@ -779,6 +781,44 @@ void initNvFuserPythonBindings(PyObject* module) { defineHeuristicParamBindings(nvfuser); + py::class_ hyperparameters( + nvfuser, "SchedulerHyperParameters"); + hyperparameters.def(py::init()); + hyperparameters.def_property( + "vectorize_factor", + [](scheduler_utils::SchedulerHyperParameters& self) { + return self.vectorize_factor; + }, + [](scheduler_utils::SchedulerHyperParameters& self, + int64_t vectorize_factor_) { + self.vectorize_factor = vectorize_factor_; + }); + hyperparameters.def_property( + "unroll_factor", + [](scheduler_utils::SchedulerHyperParameters& self) { + return self.unroll_factor; + }, + [](scheduler_utils::SchedulerHyperParameters& self, + int64_t unroll_factor_) { self.unroll_factor = unroll_factor_; }); + hyperparameters.def_property( + "threads_per_block_min", + [](scheduler_utils::SchedulerHyperParameters& self) { + return self.threads_per_block_min; + }, + [](scheduler_utils::SchedulerHyperParameters& self, + int64_t threads_per_block_min_) { + self.threads_per_block_min = threads_per_block_min_; + }); + hyperparameters.def_property( + "threads_per_block_max", + [](scheduler_utils::SchedulerHyperParameters& self) { + return self.threads_per_block_max; + }, + [](scheduler_utils::SchedulerHyperParameters& self, + int64_t threads_per_block_max_) { + self.threads_per_block_max = threads_per_block_max_; + }); + //! KernelProfiles are encapsulated in FusionProfiles where each KP //! is associated with a segment. py::class_ kernel_prof(nvfuser, "KernelProfile"); @@ -1401,7 +1441,7 @@ void initNvFuserPythonBindings(PyObject* module) { py::class_ nvf_ops(fusion_def, "Operators"); nvf_ops.def(py::init()); - // ******************** INSERT OP BINDINGS BELOW HERE ******************** +// ******************** INSERT OP BINDINGS BELOW HERE ******************** #define OP_PREFIX "Operators." #define NVFUSER_PYTHON_BINDING_UNARY_OP(op_str, op_name) \ nvf_ops.def( \ @@ -3822,6 +3862,27 @@ void initNvFuserPythonBindings(PyObject* module) { return *parameters->as(); }, py::return_value_policy::reference); + nvf_sched.def( + "schedule_hyperparameters", + [](FusionDefinition::SchedOperators& self) + -> scheduler_utils::SchedulerHyperParameters& { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + auto scheduler_hyperparameters_entry = HeuristicDataCacheEntry< + HeuristicCompileTime::SchedulerHyperParameters>( + sched->data_cache.get(), []() { + return std::make_unique< + scheduler_utils::SchedulerHyperParameters>( + /*vectorize_factor=*/1, + /*unroll_factor=*/1, + /*threads_per_block_min=*/1, + /*threads_per_block_max=*/1); + }); + return scheduler_hyperparameters_entry.get(); + }, + py::return_value_policy::reference); } void cleanup() { diff --git a/csrc/scheduler/compile_time_info.h b/csrc/scheduler/compile_time_info.h index 18b5efb0e8e..d413c99ae81 100644 --- a/csrc/scheduler/compile_time_info.h +++ b/csrc/scheduler/compile_time_info.h @@ -46,7 +46,8 @@ enum class CompileTimeEntryType { CAN_SCHEDULE_TRANSPOSE, CAN_SCHEDULE_MUL_SUM_AS_MMA, LOGICAL_REORDER_MAP, - VECTORIZATION_BREAK_POINT_OF_RED_PROD + VECTORIZATION_BREAK_POINT_OF_RED_PROD, + SCHEDULE_HYPERPARAMETERS }; //! Entry type definition class for `DOMAIN_MAP`, @@ -203,6 +204,15 @@ class VectorizationBreakPointOfReductionProducer { CompileTimeEntryType::VECTORIZATION_BREAK_POINT_OF_RED_PROD; }; +//! Entry type definition class for `SCHEDULE_HYPERPARAMETERS`, +//! stores hyperparameters for SchedulerEntry::computeHeuristics +class SchedulerHyperParameters { + public: + using DataType = scheduler_utils::SchedulerHyperParameters; + static const CompileTimeEntryType EntryType = + CompileTimeEntryType::SCHEDULE_HYPERPARAMETERS; +}; + //! Base abstract class for unified storage in `HeuristicDataCache`, //! each entry in `HeuristicDataCache` will be a subclass. class CompileTimeInfoBase : public PolymorphicBase { diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index 6dd34f4cab9..2ea854f0a88 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -186,7 +186,9 @@ PersistentBufferStorageParams getPersistentBufferStorageParams( SchedulerRuntimeInfo& runtime_info, HeuristicDataCache* data_cache, const std::vector& reduction_tvs, - const int64_t vectorize_factor) { + const int64_t vectorize_factor, + const int64_t threads_per_block_min, + const int64_t threads_per_block_max) { FUSER_PERF_SCOPE( "normalization_inner_outer::getPersistentBufferStorageParams"); @@ -230,9 +232,7 @@ PersistentBufferStorageParams getPersistentBufferStorageParams( const auto dev_prop = at::cuda::getCurrentDeviceProperties(); int64_t smem_overhead = scheduler_utils::getSharedMemoryOverheadPerBlock( - fusion, - reduction_tvs, - InnerOuterPersistentKernelScheduler::threads_per_block_max); + fusion, reduction_tvs, threads_per_block_max); int64_t available_smem = (int64_t)dev_prop->sharedMemPerMultiprocessor - smem_overhead; int64_t available_regs = scheduler_utils::register_file_size_56k; @@ -281,8 +281,8 @@ PersistentBufferStorageParams getPersistentBufferStorageParams( tv_buffer_size_regs, dataTypeSize(current_tv->getDataType().value()), vectorize_factor, - InnerOuterPersistentKernelScheduler::threads_per_block_min, - InnerOuterPersistentKernelScheduler::threads_per_block_max, + threads_per_block_min, + threads_per_block_max, dev_prop->warpSize); buffer_params.smem_buffer_size += tv_buffer_size_smem; @@ -332,6 +332,8 @@ std::pair getBufferBatchSizeAndThreadsPerBlock( const int64_t outer_dim_numel, const int64_t persistent_buffer_size, const int64_t vectorize_factor, + const int64_t threads_per_block_min, + const int64_t threads_per_block_max, const int64_t warp_size) { // if inner_dim_numel <= 1024, we are doing multiple reductions per block // with a constant batch size of 1 if vectorized. See Step 5 of @@ -380,11 +382,8 @@ std::pair getBufferBatchSizeAndThreadsPerBlock( }; const int64_t after_vectorization = inner_dim_numel / vectorize_factor; - const int64_t threads_per_block_min = std::min( - after_vectorization, - InnerOuterPersistentKernelScheduler::threads_per_block_min); - const int64_t threads_per_block_max = - InnerOuterPersistentKernelScheduler::threads_per_block_max; + const int64_t threads_per_block_min_after_vectorization = + std::min(after_vectorization, threads_per_block_min); const int64_t batch_min = getMinimumBatch(); const int64_t batch_max = getMaximumInnerOuterPersistentBufferBatch(); @@ -392,7 +391,7 @@ std::pair getBufferBatchSizeAndThreadsPerBlock( // is larger than batch_max, try increase threads per block by a warp until // the threads_per_block reaches threads_per_block_max or the batch size // reaches batch_min. - int64_t threads_per_block = threads_per_block_min; + int64_t threads_per_block = threads_per_block_min_after_vectorization; int64_t inner_batch = ceilDiv(after_vectorization, threads_per_block); while (inner_batch > batch_max && threads_per_block + warp_size <= threads_per_block_max && @@ -432,6 +431,8 @@ std::unique_ptr innerOuterPersistentHeuristic( const int64_t smem_overhead, const size_t tmp_gmem_dtype_size, const size_t vectorize_factor, + const int64_t threads_per_block_min, + const int64_t threads_per_block_max, const bool project_to_input, const PrimDataType index_type) { auto rparams = std::make_unique( @@ -512,6 +513,8 @@ std::unique_ptr innerOuterPersistentHeuristic( outer_dim_numel, regs_buffer_size, iop.inner_vect, + threads_per_block_min, + threads_per_block_max, dev_prop->warpSize); iop.inner_batch = persistent_batch; @@ -743,12 +746,32 @@ std::unique_ptr getInnerOuterPersistentHeuristics( scheduler_utils::persistentBuffers(fusion)); }); + auto scheduler_hyperparameters_entry = + HeuristicDataCacheEntry( + data_cache, [&]() { + return std::make_unique( + /*vectorize_factor=*/vectorize_factor, + /*unroll_factor=*/1, + /*threads_per_block_min=*/ + InnerOuterPersistentKernelScheduler::threads_per_block_min, + /*threads_per_block_max=*/ + InnerOuterPersistentKernelScheduler::threads_per_block_max); + }); + scheduler_utils::SchedulerHyperParameters& hp = + scheduler_hyperparameters_entry.get(); + auto& persistent_buffer_info = persistent_buffer_info_entry.get(); NVF_ERROR( !persistent_buffer_info.persistent_buffers.empty(), "Persistent scheduler requires persistent buffers."); auto buffer_params = getPersistentBufferStorageParams( - fusion, runtime_info, data_cache, reduction_tvs, vectorize_factor); + fusion, + runtime_info, + data_cache, + reduction_tvs, + hp.vectorize_factor, + hp.threads_per_block_min, + hp.threads_per_block_max); std::unique_ptr rparams = innerOuterPersistentHeuristic( properties.total_iteration_numel, @@ -757,7 +780,9 @@ std::unique_ptr getInnerOuterPersistentHeuristics( buffer_params.smem_buffer_size, buffer_params.smem_overhead, max_outer_reduction_dtype_size, - vectorize_factor, + hp.vectorize_factor, + hp.threads_per_block_min, + hp.threads_per_block_max, buffer_params.project_to_input, runtime_info.getIndexType()); @@ -1244,9 +1269,29 @@ bool InnerOuterPersistentKernelScheduler::canScheduleRunTime( data_cache, (int)(reduced_tv->nDims() - properties.inner_most_dimension_ndims)); + auto scheduler_hyperparameters_entry = + HeuristicDataCacheEntry( + data_cache, [&]() { + return std::make_unique( + /*vectorize_factor=*/vectorize_factor, + /*unroll_factor=*/1, + /*threads_per_block_min=*/ + InnerOuterPersistentKernelScheduler::threads_per_block_min, + /*threads_per_block_max=*/ + InnerOuterPersistentKernelScheduler::threads_per_block_max); + }); + scheduler_utils::SchedulerHyperParameters& hp = + scheduler_hyperparameters_entry.get(); + // check if there is enough register and shared memory for persistence const auto buffer_params = getPersistentBufferStorageParams( - fusion, runtime_info, data_cache, reduction_tvs, vectorize_factor); + fusion, + runtime_info, + data_cache, + reduction_tvs, + hp.vectorize_factor, + hp.threads_per_block_min, + hp.threads_per_block_max); const int64_t device_multiprocessor_count = (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 32cb0aa30f0..3f6b5827db5 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -224,4 +224,6 @@ template class HeuristicDataCacheEntry< template class HeuristicDataCacheEntry; template class HeuristicDataCacheEntry< HeuristicCompileTime::VectorizationBreakPointOfReductionProducer>; +template class HeuristicDataCacheEntry< + HeuristicCompileTime::SchedulerHyperParameters>; } // namespace nvfuser diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index 5dab953dead..77317cde31b 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -173,6 +173,34 @@ inline void parallelizeAllLike( propagate_padding); } +// Common hyperparameters used in heuristic scheduler. These hyperparameters +// are passed to SchedulerEntry::computeHeuristics through the +// HeuristicDataCache. These hyperparameters alter the generation of the +// HeuristicParams for the scheduler. +struct SchedulerHyperParameters { + SchedulerHyperParameters( + int64_t vectorize_factor_, + int64_t unroll_factor_, + int64_t threads_per_block_min_, + int64_t threads_per_block_max_) + : vectorize_factor(vectorize_factor_), + unroll_factor(unroll_factor_), + threads_per_block_min(threads_per_block_min_), + threads_per_block_max(threads_per_block_max_) {} + + //! Number of elements to load per vectorize load. + int64_t vectorize_factor = 1; + + //! Number of iterations to unroll for-loop. + int64_t unroll_factor = 1; + + //! Minimum number of threads per block. + int64_t threads_per_block_min = 1; + + //! Maximum number of threads per block. + int64_t threads_per_block_max = 1; +}; + struct PersistentBufferInfo { std::vector persistent_buffers; std::unordered_set unmappable_dims; diff --git a/doc/dev/python_scheduling/autotune_persistent.py b/doc/dev/python_scheduling/autotune_persistent.py new file mode 100644 index 00000000000..9cb02c6c0e7 --- /dev/null +++ b/doc/dev/python_scheduling/autotune_persistent.py @@ -0,0 +1,417 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# Owner(s): ["module: nvfuser"] + +import torch +import itertools +import random +from nvfuser import FusionCache, FusionDefinition, SchedulerType, DataType +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype +from copy import deepcopy + +# ============================ Description ============================ + +# 1. Define a nvfuser fusion and its pytorch eager mode reference. +# +# 2. Profile the CUDA kernel performance by iterating over a set of input +# arguments and scheduler configurations. +# +# 3. Train a regression model to predict the desired performance metric given +# some input arguments and a scheduler configuration. +# +# 4. Measure the performance of the regression model. +# - Calculate RMSE of predicted and actual performance on test set. +# - Find the configuration with the best performance using regression model. +# Then, compare against the heuristic configuration selected by nvfuser. +# - For a specific batch size, gather performance across a range of hidden +# sizes. Calculate performance for best predicted and nvfuser +# configurations. Plot a chart comparing performance using matplotlib. + +# The selected performance metric is effective_bandwidth_gbs. The empirical +# scheduler selects the configuration that has the highest predicted +# effective_bandwidth_gbs. + +# ============================ Configurations ============================ + +# Settings for input tensor generation +num_dimensions = 2 +outer_shapes = [256, 1024, 4096, 16384] +inner_shapes = [2**i for i in range(10, 15)] + +# For pointwise scheduler, we test the cartesian product of vectorization and +# cta_size factors. +parameter_configurations = [ + vectorize_range := [1, 2, 4, 8], + threads_per_cta_range := list(range(128, 288, 32)), +] + +# We profile a range of input shapes with various configurations. +# This argument determines how much of the profiled data to keep as a test set. +test_data_percentage = 0.1 + +# The selected batch size for empirical and nvfuser comparison. +empirical_batch_size = 512 + +# The range of hidden sizes for empirical and nvfuser comparision. +empirical_hidden_sizes = list(range(1024, 28672, 256)) + +# NOTE For 24gb memory limit +# empirical_hidden_sizes = list(range(256, 22784, 256)) + + +def create_inputs(shape): + """Create input arguments for nvfuser fusion and eager mode""" + a = torch.randn(*shape, dtype=torch.bfloat16, device="cuda", requires_grad=True) + grads = torch.randn(*shape, dtype=torch.bfloat16, device="cuda") + weights = torch.randn( + shape[1], dtype=torch.bfloat16, device="cuda", requires_grad=True + ) + bias = torch.randn( + shape[1], dtype=torch.bfloat16, device="cuda", requires_grad=True + ) + + eps = 1e-5 + mean = a.to(torch.float).mean(dim=-1) + variance = a.to(torch.float).var(dim=-1, unbiased=False) + invstd = (1.0 / torch.sqrt(variance + eps)).unsqueeze(1) + + nvf_inputs = [a, grads, mean, invstd, weights] + eager_inputs = [a, weights, bias, grads] + return nvf_inputs, eager_inputs + + +# A decorator to create a pointwise fusion given some input arguments. +def create_fusion_func(inputs): + PROMOTE_DTYPES = [DataType.BFloat16, DataType.Half] + dtype = torch_dtype_to_nvfuser_dtype(inputs[0].dtype) + + def fusion_func(fd: FusionDefinition): + T0 = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=dtype, is_cpu=False + ) + T1 = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=dtype, is_cpu=False + ) + + T2 = fd.define_tensor( + shape=[-1], contiguity=[True], dtype=DataType.Float, is_cpu=False + ) + T3 = fd.define_tensor( + shape=[-1, 1], contiguity=[True, None], dtype=DataType.Float, is_cpu=False + ) + + T4 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=dtype, is_cpu=False) + + if dtype in PROMOTE_DTYPES: + T0 = fd.ops.cast(T0, dtype=DataType.Float) + T1 = fd.ops.cast(T1, dtype=DataType.Float) + T4 = fd.ops.cast(T4, dtype=DataType.Float) + + V8 = fd.define_vector([T0.size(0), 1], dtype=DataType.Int) + T9 = fd.ops.broadcast_in_dim(T2, shape=V8, broadcast_dims=[0]) + V12 = T0.shape() + T13 = fd.ops.broadcast_in_dim(T9, shape=V12, broadcast_dims=[0, 1]) + T14 = fd.ops.sub(T0, T13) + + T18 = fd.ops.broadcast_in_dim(T3, shape=V12, broadcast_dims=[0, 1]) + T19 = fd.ops.mul(T14, T18) + + T23 = fd.ops.broadcast_in_dim(T4, shape=V12, broadcast_dims=[1]) + T28 = fd.ops.sum(T1, dims=[0], keepdim=False, dtype=DataType.Null) + + T30 = fd.ops.mul(T1, T23) + T31 = fd.ops.mul(T1, T19) + T32 = fd.ops.sum(T31, dims=[0], keepdim=False, dtype=DataType.Null) + + T34 = fd.ops.mul(T30, T18) + T35 = fd.ops.mul(T30, T14) + T36 = fd.ops.sum(T35, dims=[1], keepdim=False, dtype=DataType.Null) + + T40 = fd.ops.broadcast_in_dim(T36, shape=V8, broadcast_dims=[0]) + T41 = fd.ops.neg(T34) + T42 = fd.ops.sum(T41, dims=[1], keepdim=False, dtype=DataType.Null) + T46 = fd.ops.broadcast_in_dim(T42, shape=V8, broadcast_dims=[0]) + S47 = fd.define_scalar(-0.500000, dtype=DataType.Double) + T48 = fd.ops.mul(S47, T40) + S49 = fd.define_scalar(3.00000, dtype=DataType.Double) + T50 = fd.ops.pow(T3, S49) + T51 = fd.ops.mul(T48, T50) + T54 = fd.ops.sum(T46, dims=[1], keepdim=False, dtype=DataType.Null) + T55 = fd.ops.sum(T51, dims=[1], keepdim=False, dtype=DataType.Null) + + T59 = fd.ops.broadcast_in_dim(T55, shape=V8, broadcast_dims=[0]) + T63 = fd.ops.broadcast_in_dim(T59, shape=V12, broadcast_dims=[0, 1]) + T67 = fd.ops.broadcast_in_dim(T2, shape=V8, broadcast_dims=[0]) + T71 = fd.ops.broadcast_in_dim(T67, shape=V12, broadcast_dims=[0, 1]) + + S72 = fd.define_scalar(2.00000, dtype=DataType.Double) + T73 = fd.ops.mul(S72, T63) + T74 = fd.ops.sub(T0, T71) + T75 = fd.ops.mul(T73, T74) + + S77 = fd.ops.reciprocal(T0.size(1)) + T78 = fd.ops.mul(T75, S77) + T82 = fd.ops.broadcast_in_dim(T54, shape=V8, broadcast_dims=[0]) + T86 = fd.ops.broadcast_in_dim(T82, shape=V12, broadcast_dims=[0, 1]) + T88 = fd.ops.mul(S77, T86) + T89 = fd.ops.add(T78, T88) + T90 = fd.ops.add(T34, T89) + + if dtype in PROMOTE_DTYPES: + T28 = fd.ops.cast(T28, dtype=dtype) + T90 = fd.ops.cast(T90, dtype=dtype) + T32 = fd.ops.cast(T32, dtype=dtype) + + fd.add_output(T90) + fd.add_output(T32) + fd.add_output(T28) + + return fusion_func + + +# The pytorch eager mode reference used to validating nvfuser kernel. +def eager_reference(inputs): + inputs_cloned = deepcopy(inputs) + a, weights, bias, grad_output = inputs_cloned + eager_output = torch.nn.functional.layer_norm( + a.to(torch.double), + a.shape[1:], + weight=weights.to(torch.double), + bias=bias.to(torch.double), + ) + grad_output = grad_output.to(torch.double) + eager_output.backward(grad_output) + return [a.grad, weights.grad, bias.grad] + + +# ============================ Function Definitions ============================ + + +# Apply scheduler with custom parameters using decorator +def custom_persistent_scheduler(fd, config): + def inner_fn(): + # Check if compatible with persistent scheduler + status, _ = fd.sched.can_schedule(SchedulerType.inner_outer_persistent) + assert status + + # Modify original parameters + if config is not None: + hyperparameters = fd.sched.schedule_hyperparameters() + vectorize_factor, threads_per_block = config + hyperparameters.vectorize_factor = vectorize_factor + hyperparameters.threads_per_block_min = threads_per_block + hyperparameters.threads_per_block_max = threads_per_block + + # Schedule fusion + fd.sched.schedule(SchedulerType.inner_outer_persistent) + + fd.schedule = inner_fn + return fd + + +# Apply schedule decorator, run fusion, and profile performance +def run_profile(presched_fd, nvf_inputs, eager_inputs, config=None): + scheduled_fd = custom_persistent_scheduler(presched_fd, config) + nvf_outputs = scheduled_fd.execute(nvf_inputs, profile=True) + + # validate correctness + ref_outputs = eager_reference(eager_inputs) + for nvf_out, ref_out in zip(nvf_outputs, ref_outputs): + assert torch.allclose(nvf_out, ref_out, atol=1e-1, rtol=1e-1) + + prof = scheduled_fd.profile() + bandwidth = prof.kernel_profiles[0].effective_bandwidth_gbs + time = prof.kernel_profiles[0].time_ms + return bandwidth, time + + +def argmax(map_config_to_perf): + best_perf = -1 + best_config = None + for config, perf in map_config_to_perf.items(): + if perf > best_perf: + best_perf = perf + best_config = config + return best_config + + +# Given a prediction model, input_shape, and set of parameter configurations, +# find the best parameters +def find_best_parameters(predictor, input_shape, parameter_configurations): + map_config_to_performance = { + config: predictor.predict([[*input_shape, *config]]) + for config in itertools.product(*parameter_configurations) + } + return argmax(map_config_to_performance) + + +# ============================ Run Experiments ================================ + +# Collect data for decision tree +parameters = [] +performance = [] + +for shape in itertools.product(outer_shapes, inner_shapes): + print(shape) + nvf_inputs, eager_inputs = create_inputs(shape) + + with FusionDefinition() as presched_fd: + create_fusion_func(nvf_inputs)(presched_fd) + + # vectorization and threads_per_cta configurations + for config in itertools.product(*parameter_configurations): + perf_metric, _ = run_profile(presched_fd, nvf_inputs, eager_inputs, config) + parameters.append((*shape, *config)) + performance.append(perf_metric) + +# ============================ Separate Data ================================== + +# Separate collected data into training and test sets +train_data = [] +test_data = [] +train_perf = [] +test_perf = [] +test_shapes = set() +all_test_config = {} # key: input_shape, value: (config, perf) + +for data, perf in zip(parameters, performance): + shape = data[:num_dimensions] + config = data[num_dimensions:] + + if shape in all_test_config: + all_test_config[shape][config] = perf + else: + all_test_config[shape] = {config: perf} + + if random.random() < test_data_percentage: + test_data.append(data) + test_perf.append(perf) + else: + test_shapes.add(shape) + train_data.append(data) + train_perf.append(perf) + +# key: input_shape, value: best_config +best_test_config = {shape: argmax(all_test_config[shape]) for shape in test_shapes} + +# ========================= Train Regression Models =========================== + +# Apply decision tree regressor +# Given input shapes and scheduler parameters, predict performance metric. +from sklearn import tree + +clf = tree.DecisionTreeRegressor() +clf = clf.fit(train_data, train_perf) +test_pred = clf.predict(test_data) + +print("===================== measure performance rmse ========================") + +# Estimate prediction error with RMSE +import numpy as np + +test_perf = np.array(test_perf) +print( + "Test prediction error (RMSE)", + np.sqrt(np.mean(np.power(test_perf - test_pred, 2))), +) +print("Test performance", test_perf) +print("Test prediction", test_pred) + +print("======================= compare configurations =======================") +# Find best configuration for test_shapes +print( + "input shape, estimate_config:(vectorization, cta_size), actual_config:(vectorization, cta_size), correct" +) +correctness_count = 0 +mismatch_configs = [] +for shape in test_shapes: + estimate_config = find_best_parameters(clf, shape, parameter_configurations) + + match_config = estimate_config == best_test_config[shape] + if not match_config: + mismatch_configs.append((shape, estimate_config)) + + correctness_count += int(match_config) + print(f"{shape}, {estimate_config}, {best_test_config[shape]}, {match_config}") +print("% of predictions match nvfuser parameters", correctness_count / len(test_shapes)) +print(correctness_count, "out of", len(test_shapes)) + +print("======================= compare performance =========================") + +for shape, estimate_config in mismatch_configs: + nvf_inputs, eager_inputs = create_inputs(shape) + + with FusionDefinition() as presched_fd: + create_fusion_func(nvf_inputs)(presched_fd) + + _, est_perf = run_profile(presched_fd, nvf_inputs, eager_inputs, estimate_config) + _, nvf_perf = run_profile(presched_fd, nvf_inputs, eager_inputs) + est_perf_faster = est_perf < nvf_perf + print( + f"{shape} \t estimate_perf:{est_perf:.5f} \t nvfuser_perf:{nvf_perf:.5f} \t is_estimated_config_faster:\t{est_perf_faster}" + ) + +print("=====================================================================") + +# For a specific batch size, gather performance across a range of hidden sizes. +# Calculate performance for best predicted and nvfuser configurations. Plot a +# chart comparing performance using matplotlib. + +# NOTE: The matplotlib experiment plots the kernel runtime, which could be +# different than the selected performance metric. Currently, the performance +# metric is effective_bandwidth_gbs. + +import matplotlib.pyplot as plt +import numpy as np + +# Avoid reusing any cached, user-scheduled fusions to have a clean run. +FusionCache.reset() +est_perfs = [] +for hidden_shape in empirical_hidden_sizes: + nvf_inputs, eager_inputs = create_inputs((empirical_batch_size, hidden_shape)) + estimate_config = find_best_parameters( + clf, (empirical_batch_size, hidden_shape), parameter_configurations + ) + + with FusionDefinition() as presched_fd: + create_fusion_func(nvf_inputs)(presched_fd) + + _, est_time_ms = run_profile(presched_fd, nvf_inputs, eager_inputs, estimate_config) + est_perfs.append(est_time_ms) + print( + f"decision tree: {empirical_batch_size}, {hidden_shape}, {estimate_config}, {est_time_ms:.3f}" + ) + +FusionCache.reset() +nvf_perfs = [] +for hidden_shape in empirical_hidden_sizes: + nvf_inputs, eager_inputs = create_inputs((empirical_batch_size, hidden_shape)) + estimate_config = find_best_parameters( + clf, (empirical_batch_size, hidden_shape), parameter_configurations + ) + + with FusionDefinition() as presched_fd: + create_fusion_func(nvf_inputs)(presched_fd) + + _, nvf_time_ms = run_profile(presched_fd, nvf_inputs, eager_inputs) + nvf_perfs.append(nvf_time_ms) + print(f"nvfuser: {empirical_batch_size}, {hidden_shape}, {nvf_time_ms:.3f}") + +# Get mean speed-up from nvfuser to empirical configurations across all input shapes. +# Negative value mean empirical configurations are slower than nvfuser. +print("Mean speed-up", np.mean(np.array(nvf_perfs) - np.array(est_perfs))) + +np_hidden_size = np.array(empirical_hidden_sizes) +plt.plot(np_hidden_size, np.array(est_perfs)) +plt.plot(np_hidden_size, np.array(nvf_perfs)) + +plt.xlabel("Hidden Size") +plt.ylabel("Time(ms)") +plt.title( + f"Batch Size = {empirical_batch_size}, Compare Decision Tree Heuristic vs NvFuser" +) +plt.legend(["decision_tree", "nvfuser"], loc="lower right") +plt.savefig(f"persistent_inner_outer_empirical_batchsize{empirical_batch_size}.png") + +# ============================================================================= From d88dcba805c0ea5a72e94c78ae03df32f6b3a2c2 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Wed, 30 Oct 2024 08:03:05 -0400 Subject: [PATCH 13/17] Disable nvfusertest_serde_check if DEBUG_SERDE=disable (#3304) This is another attempt to fix the codediff CI job Fixes #3265. Fixes #3283. --- tests/python/utils.py | 20 +++++++++++++++++--- tools/codediff/compare_codegen.sh | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/tests/python/utils.py b/tests/python/utils.py index 4cb0c0e4cb3..2a7fadc4a14 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -274,6 +274,7 @@ def check_cpp_translation(reference_outputs, fd, inputs, device=None): # This DEBUG_SERDE environment flag is used to debug serialization failures. # +# If DEBUG_SERDE=debug # 1) It disables automatically saving FusionCache upon program exit. Therefore, # it has to be a global flag not per-test. # @@ -283,8 +284,14 @@ def check_cpp_translation(reference_outputs, fd, inputs, device=None): # # 3) It keeps the temporary files that are created during serde_check. # Normally, these files are deleted after each test. -env_var_debug_serde = os.getenv("DEBUG_SERDE") -debug_serde: bool = env_var_debug_serde in ("true", "1") +# +# DEBUG_SERDE=disable +# 1) It disables the @nvfusertest_serde_check decorator. This disables checking +# that serde round-trips preserve the definition during testing. +env_var_debug_serde = os.getenv("DEBUG_SERDE", "").lower() +debug_serde: bool = env_var_debug_serde == "debug" +disable_serde: bool = env_var_debug_serde == "disable" +del env_var_debug_serde # The pytest framework and test_python_frontend.py use different arguments for @@ -314,7 +321,7 @@ def basic_serde_check(): ) else: raise RuntimeError( - "***** Use DEBUG_SERDE=true to debug serialization failure." + "***** Use DEBUG_SERDE=debug to debug serialization failure." ) @@ -323,6 +330,11 @@ def basic_serde_check(): # binary. Call FusionCache.reset() to clear the cache after running an error # test in `test_python_frontend.py'. def atexit_serde_check(): + if disable_serde: + # Ignore FusionCache and automatic serialization if serde check is + # disabled + return + from nvfuser import FusionCache if not debug_serde: @@ -343,6 +355,8 @@ def nvfusertest_serde_check(test_fn: Callable): function. Currently, it uses serialization to rebuild the FusionCache structure. """ + if disable_serde: + return test_fn def inner_fn(*args, **kwargs): self, fusion_func, inputs = args diff --git a/tools/codediff/compare_codegen.sh b/tools/codediff/compare_codegen.sh index 8ae33f8805c..478936a047d 100755 --- a/tools/codediff/compare_codegen.sh +++ b/tools/codediff/compare_codegen.sh @@ -189,7 +189,7 @@ collect_kernels() { # Make tests reproducible export NVFUSER_TEST_RANDOM_SEED=0 export NVFUSER_DISABLE=parallel_compile - export DEBUG_SERDE=true + export DEBUG_SERDE=disable # run tests and benchmarks with cuda_to_file and dump output to files mkdir -p "$outdir/$commit" From f394b4e382c03e9fcec55d31c584d12f13c4e1de Mon Sep 17 00:00:00 2001 From: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com> Date: Wed, 30 Oct 2024 09:46:40 -0400 Subject: [PATCH 14/17] reorder outer reduction tv in inner-outer scheduler when there are view ops in the fusion (#3287) **Issue**: The IDs involved in view transforms are moved to the outer most positions in the loop domain, e.g. `T0[i0, i2, i3] --> T3[i0, i2*i3]`, `propagateReshapeTransforms` moves `{i2*i3}` to the outer most position and the loop domain becomes ` (iS31{( i2 * i3 )}, iS0{i0})`. To maintain the original reduction axis, for reduction tv we should reorder the loop domain back to its original logical domain, this is done for inner reduction in innerOuter scheduler but not for outer reduciton. (1) Inner reduction tv after `propagateReshapeTransforms` ``` T3_l_float[ rS13{( i2 * i3 )}, iS12{i0} ] logical domain : (iS12{i0}, rS13{( i2 * i3 )}) contiguity: t n loop domain : (rS13{( i2 * i3 )}, iS12{i0}) ``` (2) Inner reduction tv after `reorder(domainReorderAsLogicalMap)` ``` T3_l_float[ iS12{i0}, rS13{( i2 * i3 )} ] logical domain : (iS12{i0}, rS13{( i2 * i3 )}) contiguity: t n loop domain : (iS12{i0}, rS13{( i2 * i3 )}) ``` (3) Outer reduction tv after `propagateReshapeTransforms` ``` T6_l_float[ iS19{( i2 * i3 )}, rS18{i0} ] logical domain : (rS18{i0}, iS19{( i2 * i3 )}) contiguity: n t loop domain : (iS19{( i2 * i3 )}, rS18{i0}) ``` `reorder(domainReorderAsLogicalMap)` is not used for Outer reduction tv. This leads to error `Cannot rfactor axes that are not reduction axes.` when the scheduler tries to rFactor outer dim, which is `iS19{( i2 * i3 )` **Fix**: Add `reorder(domainReorderAsLogicalMap)` for outer reduction tv. **Results**: Added a unit test, err is fixed. Outer reduction tv is correctly reordered as: ``` T6_l_float[ rS18{i0}, iS19{( i2 * i3 )} ] logical domain : (rS18{i0}, iS19{( i2 * i3 )}) contiguity: n t loop domain : (rS18{i0}, iS19{( i2 * i3 )}) ``` --- csrc/scheduler/normalization_inner_outer.cpp | 9 ++++ .../test_combined_inner_outer_reduction.cpp | 46 +++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index 2ea854f0a88..e4bbc803033 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -815,6 +815,15 @@ void scheduleReductionCombinedOuter( } }; for (auto& outer_reduction_tv : outer_reduction_tvs) { + // Similar to the inner reduction, we need to reorder the outer reduction tv + // when there are view operations. + if (!ir_utils::getViewOps(fusion).empty()) { + // Reorder reference_tv after propagating the view operation. This will + // reorder for better merging. + outer_reduction_tv->reorder( + scheduler_utils::domainReorderAsLogicalMap(outer_reduction_tv)); + } + // merge tensorview to [reduction, iteraiton] domains mergeReductionOrIterDomains(outer_reduction_tv, true); mergeReductionOrIterDomains(outer_reduction_tv, false); diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp index 2071aeb0e86..95eaadd4ad7 100644 --- a/tests/cpp/test_combined_inner_outer_reduction.cpp +++ b/tests/cpp/test_combined_inner_outer_reduction.cpp @@ -994,4 +994,50 @@ TEST_F(CombinedSchedulerTest, SharedMemoryPersistentVectFactor) { aten_inputs, heuristic_params->as()->lparams); testValidate(&fusion_copy, cg_outputs, aten_inputs, __LINE__, __FILE__); } + +using InnerOuterReshapeTest = NVFuserFixtureParamTest; +INSTANTIATE_TEST_SUITE_P( + , + InnerOuterReshapeTest, + testing::Bool(), + testing::PrintToStringParamName()); +TEST_P(InnerOuterReshapeTest, ReshapeOuterDimTrueOrFalse) { + auto reshape_outer_dim = GetParam(); + Fusion fusion; + FusionGuard fg(&fusion); + // reshape a 3D input tensor to 2D + // [4, 1024, 4096] -> [4096, 4096] + // [4096, 4, 1024] -> [4096, 4096] + const int dim0 = reshape_outer_dim ? 4 : 4096; + const int dim1 = reshape_outer_dim ? 1024 : 4; + const int dim2 = reshape_outer_dim ? 4096 : 1024; + auto dtype = DataType::Half; + auto tv0 = makeContigTensor(3, dtype); + fusion.addInput(tv0); + auto tv1 = castOp(DataType::Float, tv0); + + auto tv4 = reshape(tv1, {dim0, dim1, dim2}, {4096, 4096}); + + auto tv5 = sum(tv4, {1}); + auto tv6 = broadcast(tv5, {false, true}); + auto tv7 = add(tv6, tv4); + auto tv8 = sum(tv4, {0}); + auto tv9 = castOp(DataType::Half, tv7); + auto tv10 = castOp(DataType::Half, tv8); + fusion.addOutput(tv9); + fusion.addOutput(tv10); + + Fusion fusion_copy = fusion; + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({dim0, dim1, dim2}, options); + std::vector aten_inputs = {t0}; + auto cg_results = + scheduleAndRun(&fusion, SchedulerType::InnerOuterPersistent, aten_inputs); + auto persistent_params = cg_results.heuristic_params->as(); + ASSERT_FALSE(persistent_params->project_persistent_buffers); + testValidate( + &fusion_copy, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + } // namespace nvfuser From a4465df112ea6ecdb9dc47cb1cc4e8c2ffa3e162 Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Wed, 30 Oct 2024 12:20:26 -0700 Subject: [PATCH 15/17] Host benchmarking for a fusion with multiple segments (#3307) This benchmark uses matmul + pointwise op to create a fusion with 12 segments instead of using `segment_set` to force segmentation. ![Screenshot 2024-10-29 at 4 41 47 PM](https://github.com/user-attachments/assets/2e65f8b9-489b-431b-8694-ab265f90ce32) For `host_benchmark_mode='compile'`, the profile is shown below. The `Finding valid segment solutions` pass takes 52 ms ![Screenshot 2024-10-29 at 4 52 38 PM](https://github.com/user-attachments/assets/eb02a309-2b17-4c7b-8a89-325a462323cb) --- benchmarks/python/test_many_segments_host.py | 83 ++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 benchmarks/python/test_many_segments_host.py diff --git a/benchmarks/python/test_many_segments_host.py b/benchmarks/python/test_many_segments_host.py new file mode 100644 index 00000000000..9515da141a0 --- /dev/null +++ b/benchmarks/python/test_many_segments_host.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +import pytest +from nvfuser import FusionDefinition, DataType +from .core import run_benchmark +import torch + + +def many_matmul_fusion(fd: FusionDefinition) -> None: + x = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False + ) + y = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False + ) + a = fd.ops.add(x, y) + for _ in range(5): + a_transpose = fd.ops.permute(a, [1, 0]) + matmul_out = fd.ops.matmul(a_transpose, y) + add_out = fd.ops.add(a_transpose, y) + a = fd.ops.add(matmul_out, add_out) + fd.add_output(a) + + +@pytest.mark.parametrize("host_bench_mode", ["compile", "steady", "dynamic"]) +def test_many_segment_benchmark( + benchmark, + host_bench_mode: str, + disable_validation: bool, + disable_benchmarking: bool, +): + inputs = [torch.randn(16, 16, device="cuda", dtype=torch.float) for _ in range(2)] + + # Generate multiple inputs to measure dynamic shape overhead. + if host_bench_mode == "dynamic": + input_sizes = [4, 8, 16, 32, 64, 128] + # Generate matrices of size x size dimensions + inputs = [ + [ + torch.randn(size, size, device="cuda", dtype=torch.float) + for _ in range(2) + ] + for size in input_sizes + ] + + with FusionDefinition() as fd: + many_matmul_fusion(fd) + + def validate(input): + x, y = input + eager_output = x + y + for _ in range(5): + eager_transpose = eager_output.t() + matmul_out = torch.matmul(eager_transpose, y) + add_out = eager_transpose + y + eager_output = matmul_out + add_out + fd.validate(input, [eager_output]) + + # Validate number of segments + _ = fd.execute(input, profile=True) + num_segments = fd.profile().segments + expected_segments = 12 + assert ( + num_segments == expected_segments + ), f"Expected {expected_segments} fusion segments, got {num_segments}." + + if not disable_validation: + if host_bench_mode == "dynamic": + # Run validate for all input sizes. + for input in inputs: + validate(input) + else: + validate(inputs) + + if not disable_benchmarking: + run_benchmark( + benchmark, + None, + inputs, + device=f"host:{host_bench_mode}", + fusion_fn=many_matmul_fusion, + ) From 81dd1d288cf4fda112cc3488b30f33e90fb23ae6 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Wed, 30 Oct 2024 16:49:25 -0400 Subject: [PATCH 16/17] Use deep evaluation of extents in remove_empty pass (#3301) For dynamic fusions, we detect empty tensors and set their extents to immediate constant 0. Later, in the remove_empty preseg pass, we do a shallow check that extents are empty so that we can simplify the fusion. When the fusion is not dynamic there is no concretization step where we would do this extent replacement, so we might have constant 0 extents that are compound scalars. This caused us to miss some empty tensors in #3292, particularly one of the inputs to a `cat`. This PR: - Uses a deep evaluation of each `getMaybeExpandedExtent()` to determine if an axis is empty - Adds an ExpressionEvaluator field to `EmptyTensorRemover` to avoid repeating the deep evaluation when possible. This won't help prevent repeated evaluation of symbolic extents; we could track those in an `unordered_set` potentially instead. Fixes #3292 --------- Co-authored-by: Naoya Maruyama --- csrc/preseg_passes/remove_empty.cpp | 63 +++++++++++++---------- csrc/serde/fusion_record.cpp | 3 +- tests/python/test_python_frontend.py | 74 ++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 27 deletions(-) diff --git a/csrc/preseg_passes/remove_empty.cpp b/csrc/preseg_passes/remove_empty.cpp index 7893d993ded..0be346ee71e 100644 --- a/csrc/preseg_passes/remove_empty.cpp +++ b/csrc/preseg_passes/remove_empty.cpp @@ -7,10 +7,12 @@ // clang-format on #include +#include #include #include #include #include +#include #include #include @@ -21,29 +23,6 @@ namespace nvfuser::preseg_passes { namespace { -//! Get a vector of the integer positions of constant zero extent axes in the -//! input domain. This will typically be used like -//! `emptyAxes(TensorDomain::noReductions(tv->getLogicalDomain()))` -std::vector emptyAxes(const std::vector& domain) { - std::vector empty_axes; - for (auto ax : c10::irange(domain.size())) { - auto id = domain.at(ax); - if (id->getMaybeExpandedExtent()->isConst() && - id->getMaybeExpandedExtent()->evaluate().as() == 0) { - empty_axes.push_back((int64_t)ax); - } - } - return empty_axes; -} - -//! Check whether a TensorView is empty. During concretization, we traverse to -//! find a minimal set of TensorViews that have zero extents, and we then set -//! their extents to a constant 0. Here we check for those constant zero -//! extents. -bool isTVEmpty(TensorView* tv) { - return !emptyAxes(TensorDomain::noReductions(tv->getLogicalDomain())).empty(); -} - //! EmptyTensorRemover performs a backward traversal of the Fusion. When it //! detects a TensorView that has at least one extent that is zero, we do the //! following: @@ -69,9 +48,34 @@ class EmptyTensorRemover : public DeadCodeRemover { public: EmptyTensorRemover(Fusion* fusion) : DeadCodeRemover(fusion) {} - protected: + private: using DeadCodeRemover::handle; + //! Get a vector of the integer positions of constant zero extent axes in the + //! input domain. This will typically be used like + //! `emptyAxes(TensorDomain::noReductions(tv->getLogicalDomain()))` + std::vector emptyAxes(const std::vector& domain) { + std::vector empty_axes; + for (auto ax : c10::irange(domain.size())) { + auto id = domain.at(ax); + PolymorphicValue extent = + expr_eval_.evaluate(id->getMaybeExpandedExtent()); + if (extent.hasValue() && extent.as() == 0) { + empty_axes.push_back((int64_t)ax); + } + } + return empty_axes; + } + + //! Check whether a TensorView is empty. During concretization, we traverse to + //! find a minimal set of TensorViews that have zero extents, and we then set + //! their extents to a constant 0. Here we check for those constant zero + //! extents. + bool isTVEmpty(TensorView* tv) { + return !emptyAxes(TensorDomain::noReductions(tv->getLogicalDomain())) + .empty(); + } + //! If tv is a fusion output, we check whether it is empty and if so, replace //! it with full(). For non-outputs that are not inputs, we simply check that //! the tensor is not provably empty. @@ -257,8 +261,9 @@ class EmptyTensorRemover : public DeadCodeRemover { "Inputs to CatOp must be outputs of PadOps"); auto tv = inp->definition()->as()->in()->as(); auto cat_id = TensorDomain::noReductions(tv->getLogicalDomain()).at(dim); - if (cat_id->getMaybeExpandedExtent()->isConst() && - cat_id->getMaybeExpandedExtent()->evaluate().as() == 0) { + PolymorphicValue extent = + expr_eval_.evaluate(cat_id->getMaybeExpandedExtent()); + if (extent.hasValue() && extent.as() == 0) { continue; } non_empty_inputs.push_back(tv); @@ -312,6 +317,12 @@ class EmptyTensorRemover : public DeadCodeRemover { registerReplacement(out, new_tv); } } + + private: + // We use this ExpressionEvaluator without binding any inputs. This lets us + // quickly repeatedly evaluate compound constant expressions like + // ( fmax(0, ( fmin(( ceilDiv(576, 9) ), 0) )) ) + ExpressionEvaluator expr_eval_; }; } // namespace diff --git a/csrc/serde/fusion_record.cpp b/csrc/serde/fusion_record.cpp index f40edaaea44..5de2cda9873 100644 --- a/csrc/serde/fusion_record.cpp +++ b/csrc/serde/fusion_record.cpp @@ -43,7 +43,8 @@ python_frontend::RecordFunctor* deserializeOpRecord( const RecordFunctor* buffer) { NVF_ERROR( str_to_func_map.find(buffer->name()->str()) != str_to_func_map.end(), - "Missing mapping from operation string to nvfuser function in serde deserialization."); + "Missing mapping from operation string to nvfuser function in serde deserialization: ", + buffer->name()->str()); return new python_frontend::OpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 8080c48278c..874223471eb 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -4600,3 +4600,77 @@ def fusion_func(fd: FusionDefinition) -> None: nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) for out in nvf_out: self.assertTrue(out.allclose(x[:, 1:, 2:])) + + def test_issue_3292(self): + inputs = [ + torch.testing.make_tensor( + (5, 5, 576), dtype=torch.float32, device="cuda:0" + ), + ] + + def fusion_func(fd: FusionDefinition) -> None: + T2 = fd.define_tensor( + shape=[5, 5, 576], + contiguity=[True, True, True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[2, 1, 0], + ) + T30 = fd.ops.reshape(T2, new_shape=[5, 5, 1, 9, 64]) + T31 = fd.ops.permute(T30, dims=[0, 2, 3, 1, 4]) + T50 = fd.ops.slice( + T31, + start_indices=[0, 0, 0, 0, 0], + end_indices=[5, 1, 7, 5, 64], + strides=[1, 1, 1, 1, 1], + manual_normalization=0, + ) + T108 = fd.ops.reshape(T50, new_shape=[5, 7, 5, 64]) + T136 = fd.ops.slice( + T108, + start_indices=[0, 0, 0, 0], + end_indices=[5, 7, 5, 32], + strides=[1, 1, 1, 1], + manual_normalization=0, + ) + T152 = fd.ops.slice( + T108, + start_indices=[0, 0, 0, 32], + end_indices=[5, 7, 5, 64], + strides=[1, 1, 1, 1], + manual_normalization=0, + ) + T153 = fd.ops.neg(T152) + T154 = fd.ops.cat([T153, T136], dim=-1, manual_padding=0) + T161 = fd.ops.mul(T108, T108) + T168 = fd.ops.mul(T154, T154) + T169 = fd.ops.add(T161, T168) + T185 = fd.ops.slice( + T108, + start_indices=[0, 0, 0, 0], + end_indices=[5, 7, 5, 32], + strides=[1, 1, 1, 1], + manual_normalization=0, + ) + T201 = fd.ops.slice( + T108, + start_indices=[0, 0, 0, 32], + end_indices=[5, 7, 5, 64], + strides=[1, 1, 1, 1], + manual_normalization=0, + ) + T202 = fd.ops.neg(T201) + T203 = fd.ops.cat([T202, T185], dim=-1, manual_padding=0) + T205 = fd.ops.mul(T203, T203) + T222 = fd.ops.slice( + T108, + start_indices=[0, 0, 0, 0], + end_indices=[5, 7, 5, 0], + strides=[1, 1, 1, 1], + manual_normalization=0, + ) + T223 = fd.ops.cat([T169, T222], dim=-1, manual_padding=0) + fd.add_output(T223) + + # is_clonable=False is because translation fails with missing ceilDiv + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=False) From bad9e50bc9539054050310f423317b4c1d259c53 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 30 Oct 2024 13:57:59 -0700 Subject: [PATCH 17/17] Factorize ExpressionEvaluator::bind_. (#3305) No functionality changes. --- csrc/expr_evaluator.cpp | 143 +++++++++++++++++++++------------------- csrc/expr_evaluator.h | 18 +++-- 2 files changed, 86 insertions(+), 75 deletions(-) diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index 0fd62022098..d4ca6daa022 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -129,6 +129,79 @@ void validateValWithConcreteValue( } // namespace +void ExpressionEvaluator::bindTensorDomain( + const TensorView* tv, + const at::Tensor& t, + const bool evaluate_validate) { + auto logical_domain = TensorDomain::noReductions(tv->getLogicalDomain()); + NVF_ERROR( + t.dim() == (int64_t)logical_domain.size(), + "Expected ", + getInputPosString(tv), + tv->toString(), + ", to be bound to a tensor of rank ", + logical_domain.size(), + ", but got a tensor of rank ", + t.dim()); + for (auto i : c10::irange(t.dim())) { + auto id = logical_domain[i]; + if (id->isBroadcast()) { + // DIDs are ignored for broadcast. + bind_(logical_domain[i]->extent(), 1, evaluate_validate); + if (id->hasExpandedExtent()) { + // Verify that t is also expanded + NVF_ERROR( + t.size(i) == 1 || t.stride(i) == 0, + "IterDomain ", + id->toString(), + " in ", + getInputPosString(tv), + "TensorView ", + tv->toString(), + " has expanded extent but input tensor has size ", + t.size(i), + " and stride ", + t.stride(i), + " in dimension ", + i); + bind_( + logical_domain[i]->expandedExtent(), t.size(i), evaluate_validate); + } + } else { + if (logical_domain[i]->isDeviceDim()) { + // Currently we have the restrictions: + // (1) Devices parallelized axis extent == DeviceMesh's extent + // (2) Device parallelized axis cannot be split or merged + // Therefore, the device parallelized extents will always be allocated + // with size 1, but the symbolic axis extent is binded with the extent + // of the DeviceMesh + NVF_CHECK( + 1 == t.size(i), + "TensorView ", + tv->toString(), + getInputPosString(tv), + " IterDomain ", + id->toString(), + "is sharded and must have size 1, but input tensor has size ", + t.size(i)); + NVF_CHECK( + tv->hasDeviceMesh(), + "TV ", + tv->toString(), + getInputPosString(tv), + " has an empty DeviceMesh with DID parallelization") + bind_( + logical_domain[i]->extent(), + static_cast( + tv->getDeviceMesh().size(logical_domain[i]->getParallelType())), + evaluate_validate); + } else { + bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate); + } + } + } +} + void ExpressionEvaluator::bind_( const Val* value, PolymorphicValue concrete_value, @@ -162,75 +235,7 @@ void ExpressionEvaluator::bind_( } if (auto tv = dynamic_cast(value)) { const auto& t = concrete_value.as(); - auto logical_domain = TensorDomain::noReductions(tv->getLogicalDomain()); - NVF_ERROR( - t.dim() == (int64_t)logical_domain.size(), - "Expected ", - getInputPosString(tv), - tv->toString(), - ", to be bound to a tensor of rank ", - logical_domain.size(), - ", but got a tensor of rank ", - t.dim()); - for (auto i : c10::irange(t.dim())) { - auto id = logical_domain[i]; - if (id->isBroadcast()) { - // DIDs are ignored for broadcast. - bind_(logical_domain[i]->extent(), 1, evaluate_validate); - if (id->hasExpandedExtent()) { - // Verify that t is also expanded - NVF_ERROR( - t.size(i) == 1 || t.stride(i) == 0, - "IterDomain ", - id->toString(), - " in ", - getInputPosString(tv), - "TensorView ", - tv->toString(), - " has expanded extent but input tensor has size ", - t.size(i), - " and stride ", - t.stride(i), - " in dimension ", - i); - bind_( - logical_domain[i]->expandedExtent(), - t.size(i), - evaluate_validate); - } - } else { - if (logical_domain[i]->isDeviceDim()) { - // Currently we have the restrictions: - // (1) Devices parallelized axis extent == DeviceMesh's extent - // (2) Device parallelized axis cannot be split or merged - // Therefore, the device parallelized extents will always be allocated - // with size 1, but the symbolic axis extent is binded with the extent - // of the DeviceMesh - NVF_CHECK( - 1 == t.size(i), - "TensorView ", - tv->toString(), - getInputPosString(tv), - " IterDomain ", - id->toString(), - "is sharded and must have size 1, but input tensor has size ", - t.size(i)); - NVF_CHECK( - tv->hasDeviceMesh(), - "TV ", - tv->toString(), - getInputPosString(tv), - " has an empty DeviceMesh with DID parallelization") - bind_( - logical_domain[i]->extent(), - static_cast(tv->getDeviceMesh().size( - logical_domain[i]->getParallelType())), - evaluate_validate); - } else { - bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate); - } - } - } + bindTensorDomain(tv, t, evaluate_validate); } if (value->isA()) { known_named_scalars_[value->as()->name()] = diff --git a/csrc/expr_evaluator.h b/csrc/expr_evaluator.h index ef1114bb8ff..b6c8e1857ea 100644 --- a/csrc/expr_evaluator.h +++ b/csrc/expr_evaluator.h @@ -25,12 +25,6 @@ class PrecomputedValues; //! Calculate Fusion IR expressions class ExpressionEvaluator { - NVF_API void bind_( - const Val* value, - PolymorphicValue concrete_value, - bool evaluate_validate); - void bind_(const std::string& name, PolymorphicValue concrete_value); - public: //! Bind a concrete value to an IR variable //! If evaluate_validate is true, and value is evaluatable with the @@ -98,6 +92,18 @@ class ExpressionEvaluator { ExpressionEvaluator clone(IrCloner& ir_cloner) const; private: + void bind_( + const Val* value, + PolymorphicValue concrete_value, + bool evaluate_validate); + + void bind_(const std::string& name, PolymorphicValue concrete_value); + + void bindTensorDomain( + const TensorView* tv, + const at::Tensor& t, + bool evaluate_validate); + const PolymorphicValue& getValue( const Val* value, const std::unordered_map&