Skip to content

Commit

Permalink
Merge branch 'master' into release/1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnycase committed Jul 1, 2022
2 parents 90d36bb + bdcd341 commit 3a3a658
Show file tree
Hide file tree
Showing 30 changed files with 1,530 additions and 195 deletions.
1 change: 1 addition & 0 deletions docs/onnx_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
| ReverseSequence ||
| RoiAlign ||
| Round ||
| Rsqrt ||
| Selu ||
| Shape ||
| Sign ||
Expand Down
55 changes: 40 additions & 15 deletions include/nncase/codegen/stackvm/op_writer.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 5/31/2022 2:39:39 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 6/30/2022 4:30:43 PM +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -1087,19 +1087,6 @@ struct op_writer<nncase::runtime::stackvm::tensor_gather_nd_op_t>
}
};

template <>
struct op_writer<nncase::runtime::stackvm::tensor_gru_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_gru_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(op.input_shape_src);
writer.write(op.w_shape_src);
writer.write(op.direction);
}
};

template <>
struct op_writer<nncase::runtime::stackvm::tensor_hardmax_op_t>
{
Expand Down Expand Up @@ -1462,6 +1449,43 @@ struct op_writer<nncase::runtime::stackvm::tensor_transpose_op_t>
}
};

template <>
struct op_writer<nncase::runtime::stackvm::tensor_gru_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_gru_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(op.input_shape_src);
writer.write(op.w_shape_src);
writer.write(op.direction);
}
};

template <>
struct op_writer<nncase::runtime::stackvm::tensor_tflite_detection_postprocess_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_tflite_detection_postprocess_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(op.box_shape_src);
writer.write(op.score_shape_src);
writer.write(op.anchor_shape_src);
writer.write(op.max_detections);
writer.write(op.max_classes_per_detection);
writer.write(op.detections_per_class);
writer.write(op.use_regular_non_max_suppression);
writer.write(op.nms_score_threshold);
writer.write(op.nms_iou_threshold);
writer.write(op.num_classes);
writer.write(op.y_scale);
writer.write(op.x_scale);
writer.write(op.h_scale);
writer.write(op.w_scale);
}
};

class NNCASE_API op_builder
{
public:
Expand Down Expand Up @@ -1575,7 +1599,6 @@ class NNCASE_API op_builder
void tensor_dequantize_(datatype_t in_datatype, datatype_t dst_datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest);
void tensor_gather_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_indices, uint8_t axis);
void tensor_gather_nd_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_indices, uint8_t batch_dims);
void tensor_gru_(uint8_t input_shape_src, uint8_t w_shape_src, uint8_t direction);
void tensor_hardmax_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, int32_t axis);
void tensor_lut1d_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint16_t table_len);
void tensor_matmul_(uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rshape_dest, uint8_t rstride_dest, float fused_clamp_low, float fused_clamp_high);
Expand All @@ -1598,6 +1621,8 @@ class NNCASE_API op_builder
void tensor_trilu_(datatype_t datatype, uint8_t rshape_src, bool upper, int64_t k);
void tensor_unary_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, unary_op_t unary_op);
void tensor_transpose_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_perm);
void tensor_gru_(uint8_t input_shape_src, uint8_t w_shape_src, uint8_t direction);
void tensor_tflite_detection_postprocess_(uint8_t box_shape_src, uint8_t score_shape_src, uint8_t anchor_shape_src, int32_t max_detections, int32_t max_classes_per_detection, int32_t detections_per_class, bool use_regular_non_max_suppression, float nms_score_threshold, float nms_iou_threshold, int32_t num_classes, float y_scale, float x_scale, float h_scale, float w_scale);

private:
section_writer &writer_;
Expand Down
1 change: 1 addition & 0 deletions include/nncase/ir/opcode.def
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ DEFINE_NEUTRAL_OPCODE(roi_align, RoiAlign, 0x126)
DEFINE_NEUTRAL_OPCODE(compare, Compare, 0x127)
DEFINE_NEUTRAL_OPCODE(softmax, Softmax, 0x128)
DEFINE_NEUTRAL_OPCODE(gru, GRU, 0x129)
DEFINE_NEUTRAL_OPCODE(tflite_detection_postprocess, TfliteDetectionPostprocess, 0x12A)
74 changes: 74 additions & 0 deletions include/nncase/ir/ops/tflite_detection_postprocess.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/* Copyright 2019-2021 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../node.h"
#include <xtensor/xtensor.hpp>

namespace nncase::ir
{
class NNCASE_API tflite_detection_postprocess : public node
{
public:
DEFINE_NODE_OPCODE(op_tflite_detection_postprocess);

input_connector &boxes() { return input_at(0); }
input_connector &scores() { return input_at(1); }
input_connector &anchors() { return input_at(2); }
output_connector &output_locations() { return output_at(0); }
output_connector &output_classes() { return output_at(1); }
output_connector &output_scores() { return output_at(2); }
output_connector &output_num_detections() { return output_at(3); }

int32_t max_detections() const noexcept { return max_detections_; }
int32_t max_classes_per_detection() const noexcept { return max_classes_per_detection_; }
int32_t detections_per_class() const noexcept { return detections_per_class_; }
bool use_regular_non_max_suppression() const noexcept { return use_regular_non_max_suppression_; }
float nms_score_threshold() const noexcept { return nms_score_threshold_; }
float nms_iou_threshold() const noexcept { return nms_iou_threshold_; };
int32_t num_classes() const noexcept { return num_classes_; };
float y_scale() const noexcept { return y_scale_; };
float x_scale() const noexcept { return x_scale_; };
float h_scale() const noexcept { return h_scale_; };
float w_scale() const noexcept { return w_scale_; };

tflite_detection_postprocess(
shape_t boxes_shape, shape_t scores_shape, shape_t anchors_shape,
shape_t output_shape_0, shape_t output_shape_1, shape_t output_shape_2, shape_t output_shape_3,
int32_t max_detections,
int32_t max_classes_per_detection,
int32_t detections_per_class,
bool use_regular_non_max_suppression,
float nms_score_threshold,
float nms_iou_threshold,
int32_t num_classes,
float y_scale, float x_scale, float h_scale, float w_scale);

protected:
bool properties_equal(node &other) const override;

private:
int32_t max_detections_;
int32_t max_classes_per_detection_;
int32_t detections_per_class_;
bool use_regular_non_max_suppression_;
float nms_score_threshold_;
float nms_iou_threshold_;
int32_t num_classes_;
float y_scale_;
float x_scale_;
float h_scale_;
float w_scale_;
};
}
7 changes: 7 additions & 0 deletions include/nncase/kernels/cpu/reference/tensor_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,11 @@ NNCASE_API result<void> trilu(const T *input, T *output, const runtime_shape_t &
template <typename T>
NNCASE_API result<void> gru(const T *input, const T *w, const T *r, const T *b, T *initial_h, T *output, T *output_h, const runtime_shape_t &input_shape, const runtime_shape_t &w_shape, int mode) noexcept;

template <typename T>
NNCASE_API result<void> tflite_detection_postprocess(const T *boxes, const T *scores, const T *anchors, T *output_locations, T *output_classes, T *output_scores, T *output_num_detections,
const runtime_shape_t &boxes_shape, const runtime_shape_t &scores_shape, const runtime_shape_t &anchors_shape,
const int32_t max_detections, const int32_t max_classes_per_detection, const int32_t detections_per_class,
const bool use_regular_non_max_suppression, const float nms_score_threshold, const float nms_iou_threshold,
const int32_t num_classes, const float y_scale, const float x_scale, const float h_scale, const float w_scale) noexcept;

END_NS_NNCASE_KERNELS_CPU_REF
Loading

0 comments on commit 3a3a658

Please sign in to comment.