Skip to content

Commit

Permalink
support reduce type: int64 (#844)
Browse files Browse the repository at this point in the history
* support reduce type: int64

* fix build

* apply code-format changes

---------

Co-authored-by: curioyang <[email protected]>
  • Loading branch information
curioyang and curioyang authored Mar 22, 2023
1 parent f1cdb8e commit f89ddb0
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/evaluator/ops/neutral/neutral_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,11 @@ void register_neutral_evaluators()
output.buffer().as_span<int32_t>().data(), input.shape(), to(rnode.axis()), input.strides(), output.strides(), rnode.keep_dims())
.unwrap_or_throw();
break;
case dt_int64:
kernels::reduce(rnode.reduce_op(), static_cast<int64_t>(rnode.init_value()), input.buffer().as_span<int64_t>().data(),
output.buffer().as_span<int64_t>().data(), input.shape(), to(rnode.axis()), input.strides(), output.strides(), rnode.keep_dims())
.unwrap_or_throw();
break;
default:
std::cerr << "unsupported dtype for reduce: " + std::string(datatype_names(input_type));
} });
Expand Down
3 changes: 3 additions & 0 deletions src/kernels/cpu/optimized/riscv64/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ result<void> optimized::reduce<float>(reduce_op_t op, float init_value, const fl
template result<void> optimized::reduce<int32_t>(reduce_op_t op, int32_t init_value, const int32_t *input, int32_t *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, bool keep_dims, kernel_context &context) noexcept;

template result<void> optimized::reduce<int64_t>(reduce_op_t op, int64_t init_value, const int64_t *input, int64_t *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, bool keep_dims, kernel_context &context) noexcept;

template <typename T>
result<void> optimized::reduce(reduce_op_t op, T init_value, const T *input, T *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, bool keep_dims, kernel_context &context) noexcept
Expand Down
3 changes: 3 additions & 0 deletions src/kernels/cpu/optimized/x86_64/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ template result<void> optimized::reduce<float>(reduce_op_t op, float init_value,
template result<void> optimized::reduce<int32_t>(reduce_op_t op, int32_t init_value, const int32_t *input, int32_t *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, bool keep_dims, kernel_context &context) noexcept;

template result<void> optimized::reduce<int64_t>(reduce_op_t op, int64_t init_value, const int64_t *input, int64_t *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, bool keep_dims, kernel_context &context) noexcept;

template <typename T>
result<void> optimized::reduce(reduce_op_t op, T init_value, const T *input, T *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, bool keep_dims, kernel_context &context) noexcept
Expand Down
3 changes: 3 additions & 0 deletions src/kernels/cpu/reference/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ template result<void> reference::reduce<float>(reduce_op_t op, float init_value,
template result<void> reference::reduce<int32_t>(reduce_op_t op, int32_t init_value, const int32_t *input, int32_t *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, bool keep_dims, kernel_context &context) noexcept;

template result<void> reference::reduce<int64_t>(reduce_op_t op, int64_t init_value, const int64_t *input, int64_t *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, bool keep_dims, kernel_context &context) noexcept;

template <typename T>
result<void> reference::reduce(reduce_op_t op, T init_value, const T *input, T *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, bool keep_dims, kernel_context &context) noexcept
Expand Down
3 changes: 3 additions & 0 deletions src/kernels/tensor_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ template result<void> kernels::reduce<float>(reduce_op_t op, float init_value, c
template result<void> kernels::reduce<int32_t>(reduce_op_t op, int32_t init_value, const int32_t *input, int32_t *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, bool keep_dims, kernel_context &context) noexcept;

template result<void> kernels::reduce<int64_t>(reduce_op_t op, int64_t init_value, const int64_t *input, int64_t *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, bool keep_dims, kernel_context &context) noexcept;

template <typename T>
result<void> kernels::reduce(reduce_op_t op, T init_value, const T *input, T *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, bool keep_dims, kernel_context &context) noexcept
Expand Down
1 change: 1 addition & 0 deletions src/runtime/stackvm/evaluate_stack.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class stack_entry
int8_t as_i1() const noexcept { return (int8_t)i_; }
int16_t as_i2() const noexcept { return (int16_t)i_; }
int32_t as_i4() const noexcept { return (int32_t)i_; }
int64_t as_i8() const noexcept { return (int64_t)i_; }
uintptr_t as_u() const noexcept { return (uintptr_t)i_; }
intptr_t as_i() const noexcept { return i_; }

Expand Down
4 changes: 4 additions & 0 deletions src/runtime/stackvm/ops/tensor.reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ result<void> stackvm_runtime_function::visit(const tensor_reduce_op_t &op) noexc
return kernels::reduce(op.reduce_op, init_value.as_i4(), reinterpret_cast<const int32_t *>(input),
reinterpret_cast<int32_t *>(output), in_shape, axis, in_strides, out_strides, op.keep_dims, module().kernel_context());
break;
case dt_int64:
return kernels::reduce(op.reduce_op, init_value.as_i8(), reinterpret_cast<const int64_t *>(input),
reinterpret_cast<int64_t *>(output), in_shape, axis, in_strides, out_strides, op.keep_dims, module().kernel_context());
break;
default:
std::cerr << "unsupported dtype for reduce: " + std::string(datatype_names(op.datatype)) << std::endl;
return err(std::errc::invalid_argument);
Expand Down

0 comments on commit f89ddb0

Please sign in to comment.