
Commit

[Feature](mluOpKernels): access variable in tensor struct through function in kernels
nizhijie committed Dec 26, 2024
1 parent 8044af5 commit 16eea3d
Showing 72 changed files with 890 additions and 626 deletions.
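The change replaces direct reads of mluOpTensorStruct fields (desc->dim, desc->dims[i], desc->strides, ...) in kernel code with accessor functions (desc->getDim(), desc->getDimIndex(i), desc->getStrides(), ...), so the descriptor's internal layout can change without touching every kernel. Below is a minimal sketch of the pattern, assuming "core/tensor.h" is included; the helper function itself is hypothetical and only illustrates the accessor-based style adopted here (the struct already caches this value behind getTensorElementNum()).

    #include "core/tensor.h"

    // Hypothetical helper, not part of the repository: counts elements via
    // the new accessors instead of the now-private dim/dims fields.
    static int64_t elementCount(mluOpTensorDescriptor_t desc) {
      // Old style, no longer possible once the fields are private:
      //   for (int i = 0; i < desc->dim; i++) count *= desc->dims[i];
      int64_t count = 1;
      for (int i = 0; i < desc->getDim(); i++) {
        count *= desc->getDimIndex(i);  // accessor instead of desc->dims[i]
      }
      return count;
    }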
23 changes: 12 additions & 11 deletions core/logging.h
@@ -95,7 +95,7 @@
cnrtGetLastError(); \
kernel; \
cnrtRet_t ret = cnrtPeekAtLastError(); \
if (MLUOP_PREDICT_FALSE(cnrtSuccess != ret)) { \
if (MLUOP_PREDICT_FALSE(cnrtSuccess != ret)) { \
LOG(ERROR) << "Check failed: Found " << cnrtGetErrorStr(ret) \
<< " after invoke kernel " #kernel; \
return MLUOP_STATUS_EXECUTION_FAILED; \
@@ -188,15 +188,15 @@
return MLUOP_STATUS_NOT_SUPPORTED; \
}

#define TENSOR_DIM_SIZE_CHECK(api, desc, max_num, reason, ...) \
for (int i = 0; i < desc->dim; i++) { \
if (!(desc->dims[i] < max_num)) { \
LOG(ERROR) << api << " overflow max supported tensor dim size " \
<< max_num - 1 << ", " \
<< "now tensor's dims[" << i << "] is " << desc->dims[i] \
<< ". " << reason; \
return MLUOP_STATUS_NOT_SUPPORTED; \
} \
#define TENSOR_DIM_SIZE_CHECK(api, desc, max_num, reason, ...) \
for (int i = 0; i < desc->getDim(); i++) { \
if (!(desc->getDimIndex(i) < max_num)) { \
LOG(ERROR) << api << " overflow max supported tensor dim size " \
<< max_num - 1 << ", " \
<< "now tensor's dims[" << i << "] is " \
<< desc->getDimIndex(i) << ". " << reason; \
return MLUOP_STATUS_NOT_SUPPORTED; \
} \
}

extern bool mluop_check_large_tensor_dim_size_;
@@ -222,7 +222,7 @@ extern bool mluop_check_large_tensor_dim_size_;
if (MLUOP_PREDICT_TRUE(desc != NULL)) { \
if (MLUOP_PREDICT_FALSE( \
MLUOP_PREDICT_TRUE(0 != mluOpGetTensorElementNum(desc)) && \
isStrideTensor(desc->dim, desc->dims, desc->strides))) { \
isStrideTensor(desc->getDim(), desc->getDims(), \
desc->getStrides()))) { \
LOG(ERROR) << api << " stride tensor is not supported. " << reason; \
return MLUOP_STATUS_NOT_SUPPORTED; \
} \
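The rewritten TENSOR_DIM_SIZE_CHECK and STRIDE_TENSOR_CHECK macros keep their call signatures; only the member reads inside them move to getDim()/getDimIndex()/getDims()/getStrides(). A hedged usage sketch of TENSOR_DIM_SIZE_CHECK follows, assuming core/logging.h and core/tensor.h are included; the operator name and the 2^31 limit are illustrative, not taken from the repository.

    // Must be used inside a function returning mluOpStatus_t, because the
    // macro returns MLUOP_STATUS_NOT_SUPPORTED when a dim exceeds the limit.
    static mluOpStatus_t exampleDimCheck(mluOpTensorDescriptor_t desc) {
      TENSOR_DIM_SIZE_CHECK("[mluOpExample]", desc, 2147483648L,
                            "Larger dims are not supported by this kernel.");
      return MLUOP_STATUS_SUCCESS;
    }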
55 changes: 45 additions & 10 deletions core/tensor.cpp
@@ -703,20 +703,28 @@ mluOpStatus_t MLUOP_WIN_API mluOpSetTensorDescriptorOnchipDataType(
mluOpStatus_t MLUOP_WIN_API
mluOpSetTensorDescriptorPosition(mluOpTensorDescriptor_t desc, int position) {
PARAM_CHECK("[mluOpSetTensorDescriptorPosition]", desc != NULL);
return desc->setTensorDescriptorPosition(position);

desc->position = position;
return MLUOP_STATUS_SUCCESS;
}

mluOpStatus_t MLUOP_WIN_API mluOpSetTensorDescriptorPositionAndScale(
mluOpTensorDescriptor_t desc, int position, float scale) {
PARAM_CHECK("[mluOpSetTensorDescriptorPositionAndScale]", desc != NULL);
return desc->setTensorDescriptorPositionAndScale(position, scale);

desc->position = position;
desc->scale = scale;
return MLUOP_STATUS_SUCCESS;
}

mluOpStatus_t MLUOP_WIN_API mluOpSetTensorDescriptorPositionScaleAndOffset(
mluOpTensorDescriptor_t desc, int position, float scale, int offset) {
PARAM_CHECK("[mluOpSetTensorDescriptorPositionScaleAndOffset]", desc != NULL);
return desc->setTensorDescriptorPositionScaleAndOffset(position, scale,
offset);

desc->position = position;
desc->scale = scale;
desc->offset = offset;
return MLUOP_STATUS_SUCCESS;
}

mluOpStatus_t MLUOP_WIN_API mluOpSetTensorDescriptorPointerMode(
@@ -763,20 +771,37 @@ mluOpStatus_t MLUOP_WIN_API mluOpGetTensorDescriptorOnchipDataType(
mluOpStatus_t MLUOP_WIN_API
mluOpGetTensorDescriptorPosition(mluOpTensorDescriptor_t desc, int *position) {
PARAM_CHECK("[mluOpGetTensorDescriptorPosition]", desc != NULL);
return desc->getTensorDescriptorPosition(position);
PARAM_CHECK("[mluOpGetTensorDescriptorPosition]", position != NULL);

*position = desc->position;
return MLUOP_STATUS_SUCCESS;
}

mluOpStatus_t MLUOP_WIN_API mluOpGetTensorDescriptorPositionAndScale(
mluOpTensorDescriptor_t desc, int *position, float *scale) {
PARAM_CHECK("[mluOpGetTensorDescriptorPositionAndScale]", desc != NULL);
return desc->getTensorDescriptorPositionAndScale(position, scale);
PARAM_CHECK("[mluOpGetTensorDescriptorPositionAndScale]", position != NULL);
PARAM_CHECK("[mluOpGetTensorDescriptorPositionAndScale]", scale != NULL);

*position = desc->position;
*scale = desc->scale;
return MLUOP_STATUS_SUCCESS;
}

mluOpStatus_t MLUOP_WIN_API mluOpGetTensorDescriptorPositionScaleAndOffset(
mluOpTensorDescriptor_t desc, int *position, float *scale, int *offset) {
PARAM_CHECK("[mluOpGetTensorDescriptorPositionScaleAndOffset]", desc != NULL);
return desc->getTensorDescriptorPositionScaleAndOffset(position, scale,
offset);
PARAM_CHECK("[mluOpGetTensorDescriptorPositionScaleAndOffset]",
position != NULL);
PARAM_CHECK("[mluOpGetTensorDescriptorPositionScaleAndOffset]",
scale != NULL);
PARAM_CHECK("[mluOpGetTensorDescriptorPositionScaleAndOffset]",
offset != NULL);

*position = desc->position;
*scale = desc->scale;
*offset = desc->offset;
return MLUOP_STATUS_SUCCESS;
}

mluOpStatus_t MLUOP_WIN_API mluOpGetTensorDescriptorPointerMode(
@@ -791,7 +816,17 @@ mluOpStatus_t MLUOP_WIN_API mluOpGetTensorDescriptorPointerMode(
mluOpStatus_t MLUOP_WIN_API
mluOpDestroyTensorDescriptor(mluOpTensorDescriptor_t desc) {
PARAM_CHECK("[mluOpDestroyTensorDescriptor]", desc != NULL);
return desc->destroyTensorDescriptor();

#if MLUOP_TENSOR_QUEUE_ENABLE
queue_array.lock();
desc->~mluOpTensorStruct();
queue_array.queue.push_front(desc);
queue_array.unlock();
#else
delete desc;
#endif

return MLUOP_STATUS_SUCCESS;
}

mluOpStatus_t MLUOP_WIN_API mluOpDestroyGroupTensorDescriptors(
@@ -950,4 +985,4 @@ mluOpStatus_t MLUOP_WIN_API mluOpGetSeqDataDescriptorOnchipDataType(

*onchip_dtype = desc->onchip_dtype;
return MLUOP_STATUS_SUCCESS;
}
}
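The tensor.cpp hunks above rewrite the position/scale/offset entry points to validate the descriptor and every output pointer with PARAM_CHECK and then read or write the (to-be-removed) public fields directly, instead of forwarding to member functions. A round-trip usage sketch, assuming the public header mlu_op.h is included and with error-code checking omitted for brevity:

    mluOpTensorDescriptor_t desc = nullptr;
    mluOpCreateTensorDescriptor(&desc);
    mluOpSetTensorDescriptorPositionAndScale(desc, /*position=*/-3, /*scale=*/0.5f);

    int position = 0;
    float scale = 0.0f;
    mluOpGetTensorDescriptorPositionAndScale(desc, &position, &scale);
    // position is now -3 and scale is 0.5f.

    mluOpDestroyTensorDescriptor(desc);  // recycles or frees the descriptor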
68 changes: 38 additions & 30 deletions core/tensor.h
@@ -103,6 +103,15 @@ struct alignas(64) mluOpTensorStruct {
inline bool isCpuScalar() const;

public:
/* Offset - 52 */
/* To be removed*/
int position = 0;
float scale = 1;
int offset = 0;
std::vector<int> positions;
std::vector<float> scales;
std::vector<int> offsets;

inline mluOpTensorLayout_t getLayout() const { return this->layout; }
inline void setLayout(mluOpTensorLayout_t newLayout) {
this->layout = newLayout;
@@ -203,35 +212,35 @@
mluOpPointerMode_t *pointer_mode);

uint64_t getTensorElementNum() { return this->total_element_num; }
// private:
/* Try to pack and align the struct */
/* ------------------- 64 Bytes - 1 -------------------*/
int64_t normal_dims[MLUOP_DIM_MAX];

/* ------------------- 64 Bytes - 2 -------------------*/
int64_t normal_strides[MLUOP_DIM_MAX];

/* ------------------- 64 Bytes - 3 -------------------*/
/* Offset - 0 */
uint64_t total_element_num = 0;
uint64_t total_tensor_size = 0;
int64_t *dims = normal_dims; // point the normal dims as default
int64_t *strides = normal_strides; // point the normal strides as default
/* Offset - 32 */
int dim = 0;
mluOpDataType_t dtype = MLUOP_DTYPE_FLOAT;
mluOpDataType_t onchip_dtype = MLUOP_DTYPE_INVALID;
mluOpTensorLayout_t layout = MLUOP_LAYOUT_ARRAY;
mluOpPointerMode_t pointer_mode = MLUOP_POINTER_MODE_DEVICE;

/* Offset - 52 */
/* To be removed*/
int position = 0;
float scale = 1;
int offset = 0;
std::vector<int> positions;
std::vector<float> scales;
std::vector<int> offsets;
private:
/* Try to pack and align the struct */
/* ------------------- 64 Bytes - 1 -------------------*/
int64_t normal_dims[MLUOP_DIM_MAX];

/* ------------------- 64 Bytes - 2 -------------------*/
int64_t normal_strides[MLUOP_DIM_MAX];

/* ------------------- 64 Bytes - 3 -------------------*/
/* Offset - 0 */
uint64_t total_element_num = 0;
uint64_t total_tensor_size = 0;
int64_t *dims = normal_dims; // point the normal dims as default
int64_t *strides = normal_strides; // point the normal strides as default
/* Offset - 32 */
int dim = 0;
mluOpDataType_t dtype = MLUOP_DTYPE_FLOAT;
mluOpDataType_t onchip_dtype = MLUOP_DTYPE_INVALID;
mluOpTensorLayout_t layout = MLUOP_LAYOUT_ARRAY;
mluOpPointerMode_t pointer_mode = MLUOP_POINTER_MODE_DEVICE;

// /* Offset - 52 */
// /* To be removed*/
// int position = 0;
// float scale = 1;
// int offset = 0;
// std::vector<int> positions;
// std::vector<float> scales;
// std::vector<int> offsets;
};

// dim_set(rnn) [layer_num, direction, cap_of_cell]
@@ -419,7 +428,6 @@ inline int64_t mluOpGetTensordimD(const mluOpTensorDescriptor_t desc) {
}

inline int64_t mluOpGetTensordimC(const mluOpTensorDescriptor_t desc) {
switch (desc->getLayout()) {
switch (desc->getLayout()) {
case MLUOP_LAYOUT_NCHW:
case MLUOP_LAYOUT_NCDHW:
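The tensor.h hunk above also shows helpers such as mluOpGetTensordimC switching on desc->getLayout() rather than reading the raw layout field; the full case list is truncated in this view. The sketch below shows how such a layout-aware lookup maps onto the new accessors, using the common convention that NCHW/NCDHW keep channels at index 1 and NHWC/NDHWC at the last index; the function name is hypothetical.

    inline int64_t channelDim(const mluOpTensorDescriptor_t desc) {
      switch (desc->getLayout()) {
        case MLUOP_LAYOUT_NCHW:
        case MLUOP_LAYOUT_NCDHW:
          return desc->getDimIndex(1);            // channels-first layouts
        case MLUOP_LAYOUT_NHWC:
        case MLUOP_LAYOUT_NDHWC:
        default:
          return desc->getDimIndex(desc->getDim() - 1);  // channels-last
      }
    }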
13 changes: 7 additions & 6 deletions kernels/abs/abs.cpp
@@ -77,8 +77,8 @@ static mluOpStatus_t mluOpAbsParamCheck(mluOpHandle_t handle,
if (x_desc->getDimIndex(i) != y_desc->getDimIndex(i)) {
LOG(ERROR) << op_name << ":The shape of x should be equal to y"
<< ". But now x_desc's shape[" << i << "] is "
<< x_desc->getDimIndex(i) << ", y_desc's shape[" << i << "] is "
<< y_desc->getDimIndex(i) << ".";
<< x_desc->getDimIndex(i) << ", y_desc's shape[" << i
<< "] is " << y_desc->getDimIndex(i) << ".";
return MLUOP_STATUS_BAD_PARAM;
}
}
@@ -160,12 +160,13 @@ mluOpStatus_t MLUOP_WIN_API mluOpAbs(mluOpHandle_t handle,
mluop::getTensorShape(x_desc, &x_shape);
mluop::getTensorShape(y_desc, &y_shape);
CHECK_RETURN(op_name, Kernel3StagePipelineWithStrideAbs(
k_dim, k_type, handle->queue, x_desc->getDtype(), x,
x_shape, y, y_shape, dim_x));
k_dim, k_type, handle->queue, x_desc->getDtype(),
x, x_shape, y, y_shape, dim_x));
} else {
VLOG(5) << "kernel Kernel3StagePipelineAbs";
CHECK_RETURN(op_name, Kernel3StagePipelineAbs(k_dim, k_type, handle->queue,
x_desc->getDtype(), x, y, dim_x));
CHECK_RETURN(op_name,
Kernel3StagePipelineAbs(k_dim, k_type, handle->queue,
x_desc->getDtype(), x, y, dim_x));
}
GEN_CASE_END();
return MLUOP_STATUS_SUCCESS;
37 changes: 23 additions & 14 deletions kernels/active_rotated_filter/active_rotated_filter.cpp
@@ -64,26 +64,35 @@ static mluOpStatus_t activeRotatedFilterForwardParamCheck(
PARAM_CHECK(api_name, output_desc->getDim() == 4);

// check dim
PARAM_CHECK(api_name, input_desc->getDimIndex(2) == indices_desc->getDimIndex(0));
PARAM_CHECK(api_name, input_desc->getDimIndex(3) == input_desc->getDimIndex(4));
PARAM_CHECK(api_name, input_desc->getDimIndex(3) == indices_desc->getDimIndex(1));
PARAM_CHECK(api_name, input_desc->getDimIndex(3) == output_desc->getDimIndex(2));
PARAM_CHECK(api_name, input_desc->getDimIndex(4) == indices_desc->getDimIndex(2));
PARAM_CHECK(api_name, input_desc->getDimIndex(4) == output_desc->getDimIndex(3));
PARAM_CHECK(api_name,
(input_desc->getDimIndex(2) > 0 && input_desc->getDimIndex(2) <= 128));
input_desc->getDimIndex(2) == indices_desc->getDimIndex(0));
PARAM_CHECK(api_name,
input_desc->getDimIndex(3) == input_desc->getDimIndex(4));
PARAM_CHECK(api_name,
input_desc->getDimIndex(3) == indices_desc->getDimIndex(1));
PARAM_CHECK(api_name,
input_desc->getDimIndex(3) == output_desc->getDimIndex(2));
PARAM_CHECK(api_name,
input_desc->getDimIndex(4) == indices_desc->getDimIndex(2));
PARAM_CHECK(api_name,
input_desc->getDimIndex(4) == output_desc->getDimIndex(3));
PARAM_CHECK(api_name, (input_desc->getDimIndex(2) > 0 &&
input_desc->getDimIndex(2) <= 128));
PARAM_CHECK_V2(api_name,
int(log(float(input_desc->getDimIndex(2))) / log(2.0f)) ==
log(float(input_desc->getDimIndex(2))) / log(2.0f),
"input_desc->getDimIndex(2) should be the power of 2.");
PARAM_CHECK(api_name, (input_desc->getDimIndex(3) == 3 || input_desc->getDimIndex(3) == 1));
PARAM_CHECK(api_name, (input_desc->getDimIndex(3) == 3 ||
input_desc->getDimIndex(3) == 1));
PARAM_CHECK(api_name, (indices_desc->getDimIndex(3) == 2 ||
indices_desc->getDimIndex(3) == 4 ||
indices_desc->getDimIndex(3) == 8));
PARAM_CHECK(api_name,
(output_desc->getDimIndex(0) ==
input_desc->getDimIndex(0) * indices_desc->getDimIndex(3)));
PARAM_CHECK(api_name,
(indices_desc->getDimIndex(3) == 2 || indices_desc->getDimIndex(3) == 4 ||
indices_desc->getDimIndex(3) == 8));
PARAM_CHECK(api_name, (output_desc->getDimIndex(0) ==
input_desc->getDimIndex(0) * indices_desc->getDimIndex(3)));
PARAM_CHECK(api_name, (output_desc->getDimIndex(1) ==
input_desc->getDimIndex(1) * input_desc->getDimIndex(2)));
(output_desc->getDimIndex(1) ==
input_desc->getDimIndex(1) * input_desc->getDimIndex(2)));

// check stride
STRIDE_TENSOR_CHECK(api_name + ":", input_desc,
15 changes: 10 additions & 5 deletions kernels/ball_query/ball_query.cpp
@@ -89,11 +89,14 @@ mluOpStatus_t MLUOP_WIN_API mluOpBallQuery(
PARAM_CHECK("[mluOpBallQuery]", idx_desc->getDim() == 3);

// check dim0
PARAM_CHECK("[mluOpBallQuery]", new_xyz_desc->getDimIndex(0) == xyz_desc->getDimIndex(0));
PARAM_CHECK("[mluOpBallQuery]", new_xyz_desc->getDimIndex(0) == idx_desc->getDimIndex(0));
PARAM_CHECK("[mluOpBallQuery]",
new_xyz_desc->getDimIndex(0) == xyz_desc->getDimIndex(0));
PARAM_CHECK("[mluOpBallQuery]",
new_xyz_desc->getDimIndex(0) == idx_desc->getDimIndex(0));

// check dim1
PARAM_CHECK("[mluOpBallQuery]", new_xyz_desc->getDimIndex(1) == idx_desc->getDimIndex(1));
PARAM_CHECK("[mluOpBallQuery]",
new_xyz_desc->getDimIndex(1) == idx_desc->getDimIndex(1));

// check dim2
PARAM_CHECK("[mluOpBallQuery]", new_xyz_desc->getDimIndex(2) == 3);
@@ -115,7 +118,8 @@ mluOpStatus_t MLUOP_WIN_API mluOpBallQuery(
<< mluOpGetNameOfDataType(new_xyz_desc->getDtype()) << ".";
return MLUOP_STATUS_BAD_PARAM;
}
PARAM_CHECK_EQ("[mluOpBallQuery]", new_xyz_desc->getDtype(), xyz_desc->getDtype());
PARAM_CHECK_EQ("[mluOpBallQuery]", new_xyz_desc->getDtype(),
xyz_desc->getDtype());

if (idx_desc->getDtype() != MLUOP_DTYPE_INT32) {
LOG(ERROR) << "[mluOpBallQuery]:Only int32 is supportedin output idx, but "
@@ -155,7 +159,8 @@ mluOpStatus_t MLUOP_WIN_API mluOpBallQuery(
if (mluOpGetTensorElementNum(new_xyz_desc) == 0) {
VLOG(5) << "[mluOpBallQuery] new_xyz tensor is a zero element tensor. The "
"shape of new_xyz tensor is ["
<< new_xyz_desc->getDimIndex(0) << ", " << new_xyz_desc->getDimIndex(1) << ", "
<< new_xyz_desc->getDimIndex(0) << ", "
<< new_xyz_desc->getDimIndex(1) << ", "
<< new_xyz_desc->getDimIndex(2) << "].";
return MLUOP_STATUS_BAD_PARAM;
}