add transpose for allgather mm #1275

Open · wants to merge 1 commit into base: develop
109 changes: 91 additions & 18 deletions backends/npu/custom_op/fused_allgather_mm.cc
@@ -19,6 +19,52 @@
#include "kernels/funcs/npu_op_runner.h"
#include "paddle/extension.h"

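// Wraps a phi::DenseTensor in an aclTensor view without copying data.
// When `transpose` is set, the last two dims and strides are swapped so the
// returned tensor is a transposed *view* of the same device buffer, letting
// the fused kernel consume x2 as [k, n] while it is stored as [n, k].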
static inline aclTensor* CreateAclTensor(const phi::DenseTensor& paddletensor,
                                         const bool transpose) {
static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
auto tensor_dtype = paddletensor.dtype();
auto acl_data_type = ConvertToNpuDtype(tensor_dtype);
const auto dimNum =
paddletensor.dims().size() == 0 ? 1 : paddletensor.dims().size();
  aclFormat format = ACL_FORMAT_ND;
  switch (dimNum) {
case 4:
format = ACL_FORMAT_NCHW;
break;
case 5:
format = ACL_FORMAT_NCDHW;
break;
default:
format = ACL_FORMAT_ND;
}
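  // Compute row-major (contiguous) strides from the shape: drop the leading
  // dim, append 1, then take suffix products, e.g. shape [m, k] -> [k, 1].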
auto shape = phi::vectorize(paddletensor.dims());
auto strides = shape;
if (!strides.empty()) {
strides.erase(strides.begin());
}
strides.push_back(1);
for (int i = static_cast<int>(strides.size()) - 2; i >= 0; i--) {
strides[i] = strides[i] * strides[i + 1];
}
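  // Transpose as a zero-copy view: swapping the last two dims and strides
  // reinterprets the same buffer, e.g. shape [n, k] / strides [k, 1]
  // becomes shape [k, n] / strides [1, k].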
if (transpose) {
std::swap(shape[shape.size() - 1], shape[shape.size() - 2]);
std::swap(strides[strides.size() - 1], strides[strides.size() - 2]);
}
  auto acl_tensor = aclCreateTensor(shape.data(),    // view dims
                                    dimNum,
                                    acl_data_type,
                                    strides.data(),  // view strides
                                    0,               // storage offset
                                    format,
                                    shape.data(),    // storage dims
                                    dimNum,
                                    const_cast<void*>(paddletensor.data()));
return acl_tensor;
}

int64_t GetShapeSize(const std::vector<int64_t>& shape) {
int64_t shapeSize = 1;
for (auto i : shape) {
@@ -55,18 +101,32 @@ int CreateZeroDimAclTensor(const std::vector<float>& hostData,
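// Shape of the matmul output: x1's row dim grows by world_size when x1 is
// the gathered input (gather_index == 0); the column dim comes from x2's
// rows when transpose_y is set (x2 stored as [n, k]) or from x2's columns
// otherwise (x2 stored as [k, n]).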
const phi::DDim get_output_size_gather_mm(const paddle::Tensor& x1,
const paddle::Tensor& x2,
int64_t world_size,
int64_t gather_index) {
int64_t gather_index,
const bool transpose_y) {
auto out_x = gather_index == 0 ? x1.dims()[0] * world_size : x1.dims()[0];
  auto out_y = transpose_y ? x2.dims()[0] : x2.dims()[1];
  return {out_x, out_y};
}

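// Shape of the standalone copy of the all-gathered input, concatenated along
// dim 0 across ranks; for x2 with transpose_y the gathered view is [k, n],
// so the result is [k * world_size, n].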
const phi::DDim get_output_size_gather(const paddle::Tensor& x1,
const paddle::Tensor& x2,
int64_t world_size,
int64_t gather_index,
const bool transpose_y) {
if (gather_index == 0) {
return {x1.dims()[0] * world_size, x1.dims()[1]};
} else {
if (transpose_y) {
return {x2.dims()[1] * world_size, x2.dims()[0]};
} else {
return {x2.dims()[0] * world_size, x2.dims()[1]};
}
}
}

@@ -77,27 +137,37 @@ std::vector<paddle::Tensor> npu_allgather_mm(
int64_t world_size,
int64_t gather_index,
bool gather_output,
int64_t comm_turn,
const bool transpose_y) {
PD_CHECK(
x1.dims().size() == 2 && x2.dims().size() == 2,
"Both inputs of mm are required to be 2D, but the actual inputs are ",
x1.dims().size(),
"D and ",
x2.dims().size(),
"D");
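  // With transpose_y, x2 is stored as [n, k], so its K axis is dims()[1];
  // otherwise x2 is stored as [k, n] and its K axis is dims()[0].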
  if (transpose_y) {
    PD_CHECK(x1.dims()[1] == x2.dims()[1],
             "The K axes of the two matmul inputs must match, but the K axis "
             "of x1 is ",
             x1.dims()[1],
             " and the K axis of x2 is ",
             x2.dims()[1]);
  } else {
    PD_CHECK(x1.dims()[1] == x2.dims()[0],
             "The K axes of the two matmul inputs must match, but the K axis "
             "of x1 is ",
             x1.dims()[1],
             " and the K axis of x2 is ",
             x2.dims()[0]);
  }

auto dev_ctx = static_cast<const phi::CustomContext*>(
paddle::experimental::DeviceContextPool::Instance().Get(x1.place()));
auto out_gather_mm_size =
get_output_size_gather_mm(x1, x2, world_size, gather_index, transpose_y);
auto out_gather_size =
get_output_size_gather(x1, x2, world_size, gather_index, transpose_y);

std::shared_ptr<phi::DenseTensor> out_gather_mm =
std::make_shared<phi::DenseTensor>();
@@ -117,6 +187,8 @@ std::vector<paddle::Tensor> npu_allgather_mm(

auto x1_tensor = *(static_cast<const phi::DenseTensor*>(x1.impl().get()));
auto x2_tensor = *(static_cast<const phi::DenseTensor*>(x2.impl().get()));

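  // Build an aclTensor view of x2; with transpose_y this is a zero-copy
  // transposed view, so no materialized transpose is needed before the
  // fused all-gather matmul kernel.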
  aclTensor* y_acltensor = CreateAclTensor(x2_tensor, transpose_y);
char* hcom_ptr = const_cast<char*>(hcom.data());
aclTensor* out_gather_zerotensor = nullptr;
if (gather_output) {
@@ -127,7 +199,7 @@ std::vector<paddle::Tensor> npu_allgather_mm(
EXEC_NPU_CMD(aclnnAllGatherMatmul,
*dev_ctx,
x1_tensor,
y_acltensor,
*bias_real,
hcom_ptr,
gather_index,
@@ -160,7 +232,7 @@ std::vector<paddle::Tensor> npu_allgather_mm(
EXEC_NPU_CMD(aclnnAllGatherMatmul,
*dev_ctx,
x1_tensor,
y_acltensor,
*bias_real,
hcom_ptr,
gather_index,
@@ -188,7 +260,8 @@ PD_BUILD_OP(fused_allgather_mm)
"world_size:int64_t",
"gather_index:int64_t",
"gather_output:bool",
"comm_turn:int64_t"})
"comm_turn:int64_t",
"transpose_y:bool"})
.SetKernelFn(PD_KERNEL(npu_allgather_mm))
.SetInferShapeFn(PD_INFER_SHAPE(
        FusedAllgatherMMInferShape));  // necessary if the op has multiple inputs