From f45b793f9ba78a97ea8da1882f9d31c4038b6ab0 Mon Sep 17 00:00:00 2001
From: zhenwenqi
Date: Fri, 31 May 2024 11:30:29 +0800
Subject: [PATCH] add transpose for allgather mm

---
Review note: Create_Acltensor below was revised by hand after this patch was
exported (documentation added, the unused storageDims and a stray `break;`
removed, rank handling made explicit). Hunk offsets and the diffstat were
recounted; the commit id and the blob ids on the index line still refer to
the original export.

 backends/npu/custom_op/fused_allgather_mm.cc | 125 ++++++++++++++++---
 1 file changed, 107 insertions(+), 18 deletions(-)

diff --git a/backends/npu/custom_op/fused_allgather_mm.cc b/backends/npu/custom_op/fused_allgather_mm.cc
index c1a38c8b4..4e38f39cb 100644
--- a/backends/npu/custom_op/fused_allgather_mm.cc
+++ b/backends/npu/custom_op/fused_allgather_mm.cc
@@ -19,6 +19,68 @@
 #include "kernels/funcs/npu_op_runner.h"
 #include "paddle/extension.h"
 
+// Builds an aclTensor descriptor over the data pointer of a phi::DenseTensor.
+//
+// paddletensor: supplies the dims, dtype and data pointer being described.
+// transpose:    when true, the last two entries of both the shape and the
+//               stride arrays are swapped, so the buffer is described with
+//               its two innermost axes exchanged; requires rank >= 2.
+//
+// Returns the handle produced by aclCreateTensor.
+// NOTE(review): nothing here releases the handle; confirm that the consumer
+// (e.g. EXEC_NPU_CMD) destroys it.
+static inline aclTensor* Create_Acltensor(const phi::DenseTensor& paddletensor,
+                                          const bool transpose) {
+  static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
+  const auto acl_data_type = ConvertToNpuDtype(paddletensor.dtype());
+
+  auto shape = phi::vectorize(paddletensor.dims());
+  if (shape.empty()) {
+    // A 0-D tensor is described as rank 1, so give it one extent to match.
+    shape.push_back(1);
+  }
+  const int dimNum = static_cast<int>(shape.size());
+  PD_CHECK(!transpose || dimNum >= 2,
+           "Create_Acltensor can only transpose inputs with at least 2 "
+           "dimensions, but the actual input is ",
+           dimNum,
+           "D");
+
+  aclFormat format = ACL_FORMAT_ND;
+  switch (dimNum) {
+    case 4:
+      format = ACL_FORMAT_NCHW;
+      break;
+    case 5:
+      format = ACL_FORMAT_NCDHW;
+      break;
+    default:
+      format = ACL_FORMAT_ND;
+  }
+
+  // Contiguous row-major strides, in elements: drop the leading extent,
+  // append 1, then take a running product from the innermost axis outwards.
+  auto strides = shape;
+  strides.erase(strides.begin());
+  strides.push_back(1);
+  for (int i = static_cast<int>(strides.size()) - 2; i >= 0; i--) {
+    strides[i] = strides[i] * strides[i + 1];
+  }
+  if (transpose) {
+    std::swap(shape[shape.size() - 1], shape[shape.size() - 2]);
+    std::swap(strides[strides.size() - 1], strides[strides.size() - 2]);
+  }
+  return aclCreateTensor(shape.data(),
+                         dimNum,
+                         acl_data_type,
+                         strides.data(),
+                         0,
+                         format,
+                         shape.data(),
+                         dimNum,
+                         const_cast<void*>(paddletensor.data()));
+}
+
 int64_t GetShapeSize(const std::vector& shape) {
   int64_t shapeSize = 1;
   for (auto i : shape) {
@@ -55,18 +117,32 @@ int CreateZeroDimAclTensor(const std::vector& hostData,
 const phi::DDim get_output_size_gather_mm(const paddle::Tensor& x1,
                                           const paddle::Tensor& x2,
                                           int64_t world_size,
-                                          int64_t gather_index) {
+                                          int64_t gather_index,
+                                          const bool transpose_y) {
   auto out_x = gather_index == 0 ? x1.dims()[0] * world_size : x1.dims()[0];
-  auto out_y = x2.dims()[1];
-  return {out_x, out_y};
+  if (transpose_y) {
+    auto out_y = x2.dims()[0];
+    return {out_x, out_y};
+  } else {
+    auto out_y = x2.dims()[1];
+    return {out_x, out_y};
+  }
 }
 
 const phi::DDim get_output_size_gather(const paddle::Tensor& x1,
                                        const paddle::Tensor& x2,
                                        int64_t world_size,
-                                       int64_t gather_index) {
-  const paddle::Tensor& gather_out = gather_index == 0 ? x1 : x2;
-  return {gather_out.dims()[0] * world_size, gather_out.dims()[1]};
+                                       int64_t gather_index,
+                                       const bool transpose_y) {
+  if (gather_index == 0) {
+    return {x1.dims()[0] * world_size, x1.dims()[1]};
+  } else {
+    if (transpose_y) {
+      return {x2.dims()[1] * world_size, x2.dims()[0]};
+    } else {
+      return {x2.dims()[0] * world_size, x2.dims()[1]};
+    }
+  }
 }
 
 std::vector npu_allgather_mm(
@@ -77,7 +153,8 @@ std::vector npu_allgather_mm(
     int64_t world_size,
     int64_t gather_index,
     bool gather_output,
-    int64_t comm_turn) {
+    int64_t comm_turn,
+    const bool transpose_y) {
   PD_CHECK(
       x1.dims().size() == 2 && x2.dims().size() == 2,
       "Both inputs of mm are required to be 2D, but the actual inputs are ",
@@ -85,19 +162,28 @@ std::vector npu_allgather_mm(
       "D and ",
       x2.dims().size(),
       "D");
-  PD_CHECK(x1.dims()[1] == x2.dims()[0],
-           "The K-axis in the two inputs of Matmul must be equal, but in "
-           "reality, the K-axis of x1 is ",
-           x1.dims()[1],
-           " and the K-axis of x2 is ",
-           x2.dims()[0]);
+  if (transpose_y) {
+    PD_CHECK(x1.dims()[1] == x2.dims()[1],
+             "The K-axis in the two inputs of Matmul must be equal, but in "
+             "reality, the K-axis of x1 is ",
+             x1.dims()[1],
+             " and the K-axis of x2 is ",
+             x2.dims()[1]);
+  } else {
+    PD_CHECK(x1.dims()[1] == x2.dims()[0],
+             "The K-axis in the two inputs of Matmul must be equal, but in "
+             "reality, the K-axis of x1 is ",
+             x1.dims()[1],
+             " and the K-axis of x2 is ",
+             x2.dims()[0]);
+  }
 
   auto dev_ctx = static_cast(
       paddle::experimental::DeviceContextPool::Instance().Get(x1.place()));
   auto out_gather_mm_size =
-      get_output_size_gather_mm(x1, x2, world_size, gather_index);
+      get_output_size_gather_mm(x1, x2, world_size, gather_index, transpose_y);
   auto out_gather_size =
-      get_output_size_gather(x1, x2, world_size, gather_index);
+      get_output_size_gather(x1, x2, world_size, gather_index, transpose_y);
 
   std::shared_ptr out_gather_mm =
       std::make_shared();
@@ -117,6 +203,8 @@ std::vector npu_allgather_mm(
 
   auto x1_tensor = *(static_cast(x1.impl().get()));
   auto x2_tensor = *(static_cast(x2.impl().get()));
+
+  aclTensor* y_acltensor = Create_Acltensor(x2_tensor, transpose_y);
   char* hcom_ptr = const_cast(hcom.data());
   aclTensor* out_gather_zerotensor = nullptr;
   if (gather_output) {
@@ -127,7 +215,7 @@ std::vector npu_allgather_mm(
     EXEC_NPU_CMD(aclnnAllGatherMatmul,
                  *dev_ctx,
                  x1_tensor,
-                 x2_tensor,
+                 y_acltensor,
                  *bias_real,
                  hcom_ptr,
                  gather_index,
@@ -160,7 +248,7 @@ std::vector npu_allgather_mm(
     EXEC_NPU_CMD(aclnnAllGatherMatmul,
                  *dev_ctx,
                  x1_tensor,
-                 x2_tensor,
+                 y_acltensor,
                  *bias_real,
                  hcom_ptr,
                  gather_index,
@@ -188,7 +276,8 @@ PD_BUILD_OP(fused_allgather_mm)
             "world_size:int64_t",
             "gather_index:int64_t",
             "gather_output:bool",
-            "comm_turn:int64_t"})
+            "comm_turn:int64_t",
+            "transpose_y:bool"})
     .SetKernelFn(PD_KERNEL(npu_allgather_mm))
     .SetInferShapeFn(PD_INFER_SHAPE(
         FusedAllgatherMMInferShape));  // neccessary if the op has muti_inputs