From bdcd341d6886c508249ae247d4c9fe4523de388e Mon Sep 17 00:00:00 2001 From: Curio Yang Date: Thu, 30 Jun 2022 16:48:10 +0800 Subject: [PATCH] Feature/detection postprocess (#615) * add tflite_detection_postprocess * add tflite_detection_postprocess stackvm op * delete unused * rename output * delete unused * fix bug * add onnx op doc * apply code-format changes * modify func index * modify func index * apply code-format changes Co-authored-by: curioyang --- docs/onnx_ops.md | 1 + include/nncase/codegen/stackvm/op_writer.h | 55 ++- include/nncase/ir/opcode.def | 1 + .../ir/ops/tflite_detection_postprocess.h | 74 ++++ .../kernels/cpu/reference/tensor_compute.h | 7 + .../nncase/kernels/neutral/neutral_kernels.h | 346 +++++++++++++++- include/nncase/kernels/tensor_compute.h | 7 + include/nncase/runtime/stackvm/op_reader.h | 61 ++- include/nncase/runtime/stackvm/opcode.h | 111 ++++-- .../transforms/neutral/fix_output_shape.h | 28 ++ src/codegen/stackvm/CMakeLists.txt | 75 ++-- src/codegen/stackvm/module_builder.h | 1 + src/codegen/stackvm/op_writer.cpp | 17 +- src/codegen/stackvm/ops.def | 1 + .../ops/tflite_detection_postprocess.cpp | 47 +++ src/evaluator/ops/neutral/neutral_ops.cpp | 17 + src/importer/tflite/ops/custom.cpp | 48 +++ src/ir/ops/CMakeLists.txt | 3 +- src/ir/ops/tflite_detection_postprocess.cpp | 57 +++ src/kernels/cpu/reference/CMakeLists.txt | 1 + .../tflite_detection_postprocess.cpp | 376 ++++++++++++++++++ src/kernels/tensor_compute.cpp | 20 + src/runtime/stackvm/CMakeLists.txt | 93 ++--- src/runtime/stackvm/op_reader.cpp | 23 +- .../tensor.tflite_detection_postprocess.cpp | 44 ++ src/runtime/stackvm/runtime_function.h | 1 + src/targets/neutral_target.cpp | 9 + src/transforms/neutral/CMakeLists.txt | 1 + src/transforms/neutral/fix_output_shape.cpp | 93 +++++ tools/stackvm_gen/IsaGen/Instructions.cs | 107 ++++- 30 files changed, 1530 insertions(+), 195 deletions(-) create mode 100644 include/nncase/ir/ops/tflite_detection_postprocess.h create mode 100644 include/nncase/transforms/neutral/fix_output_shape.h create mode 100644 src/codegen/stackvm/ops/tflite_detection_postprocess.cpp create mode 100644 src/ir/ops/tflite_detection_postprocess.cpp create mode 100644 src/kernels/cpu/reference/tflite_detection_postprocess.cpp create mode 100644 src/runtime/stackvm/ops/tensor.tflite_detection_postprocess.cpp create mode 100644 src/transforms/neutral/fix_output_shape.cpp diff --git a/docs/onnx_ops.md b/docs/onnx_ops.md index 5616760c97..f1c9d305f6 100644 --- a/docs/onnx_ops.md +++ b/docs/onnx_ops.md @@ -90,6 +90,7 @@ | ReverseSequence | ✅ | | RoiAlign | ✅ | | Round | ✅ | +| Rsqrt | ✅ | | Selu | ✅ | | Shape | ✅ | | Sign | ✅ | diff --git a/include/nncase/codegen/stackvm/op_writer.h b/include/nncase/codegen/stackvm/op_writer.h index e8d3464199..3481999457 100644 --- a/include/nncase/codegen/stackvm/op_writer.h +++ b/include/nncase/codegen/stackvm/op_writer.h @@ -1,4 +1,4 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 5/31/2022 2:39:39 PM +08:00. +/* This file is generated by tools/stackvm_gen/IsaGen at 6/30/2022 4:30:43 PM +08:00. * * Copyright 2019-2021 Canaan Inc. * @@ -1087,19 +1087,6 @@ struct op_writer } }; -template <> -struct op_writer -{ - void operator()(const nncase::runtime::stackvm::tensor_gru_op_t &op, binary_writer &writer) const - { - writer.write(static_cast(op.opcode)); - writer.write(static_cast(op.funct)); - writer.write(op.input_shape_src); - writer.write(op.w_shape_src); - writer.write(op.direction); - } -}; - template <> struct op_writer { @@ -1462,6 +1449,43 @@ struct op_writer } }; +template <> +struct op_writer +{ + void operator()(const nncase::runtime::stackvm::tensor_gru_op_t &op, binary_writer &writer) const + { + writer.write(static_cast(op.opcode)); + writer.write(static_cast(op.funct)); + writer.write(op.input_shape_src); + writer.write(op.w_shape_src); + writer.write(op.direction); + } +}; + +template <> +struct op_writer +{ + void operator()(const nncase::runtime::stackvm::tensor_tflite_detection_postprocess_op_t &op, binary_writer &writer) const + { + writer.write(static_cast(op.opcode)); + writer.write(static_cast(op.funct)); + writer.write(op.box_shape_src); + writer.write(op.score_shape_src); + writer.write(op.anchor_shape_src); + writer.write(op.max_detections); + writer.write(op.max_classes_per_detection); + writer.write(op.detections_per_class); + writer.write(op.use_regular_non_max_suppression); + writer.write(op.nms_score_threshold); + writer.write(op.nms_iou_threshold); + writer.write(op.num_classes); + writer.write(op.y_scale); + writer.write(op.x_scale); + writer.write(op.h_scale); + writer.write(op.w_scale); + } +}; + class NNCASE_API op_builder { public: @@ -1575,7 +1599,6 @@ class NNCASE_API op_builder void tensor_dequantize_(datatype_t in_datatype, datatype_t dst_datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest); void tensor_gather_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_indices, uint8_t axis); void tensor_gather_nd_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_indices, uint8_t batch_dims); - void tensor_gru_(uint8_t input_shape_src, uint8_t w_shape_src, uint8_t direction); void tensor_hardmax_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, int32_t axis); void tensor_lut1d_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint16_t table_len); void tensor_matmul_(uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rshape_dest, uint8_t rstride_dest, float fused_clamp_low, float fused_clamp_high); @@ -1598,6 +1621,8 @@ class NNCASE_API op_builder void tensor_trilu_(datatype_t datatype, uint8_t rshape_src, bool upper, int64_t k); void tensor_unary_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, unary_op_t unary_op); void tensor_transpose_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_perm); + void tensor_gru_(uint8_t input_shape_src, uint8_t w_shape_src, uint8_t direction); + void tensor_tflite_detection_postprocess_(uint8_t box_shape_src, uint8_t score_shape_src, uint8_t anchor_shape_src, int32_t max_detections, int32_t max_classes_per_detection, int32_t detections_per_class, bool use_regular_non_max_suppression, float nms_score_threshold, float nms_iou_threshold, int32_t num_classes, float y_scale, float x_scale, float h_scale, float w_scale); private: section_writer &writer_; diff --git a/include/nncase/ir/opcode.def b/include/nncase/ir/opcode.def index 64aeaccdb0..1686dd8e03 100644 --- a/include/nncase/ir/opcode.def +++ b/include/nncase/ir/opcode.def @@ -48,3 +48,4 @@ DEFINE_NEUTRAL_OPCODE(roi_align, RoiAlign, 0x126) DEFINE_NEUTRAL_OPCODE(compare, Compare, 0x127) DEFINE_NEUTRAL_OPCODE(softmax, Softmax, 0x128) DEFINE_NEUTRAL_OPCODE(gru, GRU, 0x129) +DEFINE_NEUTRAL_OPCODE(tflite_detection_postprocess, TfliteDetectionPostprocess, 0x12A) diff --git a/include/nncase/ir/ops/tflite_detection_postprocess.h b/include/nncase/ir/ops/tflite_detection_postprocess.h new file mode 100644 index 0000000000..1d88c3abfa --- /dev/null +++ b/include/nncase/ir/ops/tflite_detection_postprocess.h @@ -0,0 +1,74 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "../node.h" +#include + +namespace nncase::ir +{ +class NNCASE_API tflite_detection_postprocess : public node +{ +public: + DEFINE_NODE_OPCODE(op_tflite_detection_postprocess); + + input_connector &boxes() { return input_at(0); } + input_connector &scores() { return input_at(1); } + input_connector &anchors() { return input_at(2); } + output_connector &output_locations() { return output_at(0); } + output_connector &output_classes() { return output_at(1); } + output_connector &output_scores() { return output_at(2); } + output_connector &output_num_detections() { return output_at(3); } + + int32_t max_detections() const noexcept { return max_detections_; } + int32_t max_classes_per_detection() const noexcept { return max_classes_per_detection_; } + int32_t detections_per_class() const noexcept { return detections_per_class_; } + bool use_regular_non_max_suppression() const noexcept { return use_regular_non_max_suppression_; } + float nms_score_threshold() const noexcept { return nms_score_threshold_; } + float nms_iou_threshold() const noexcept { return nms_iou_threshold_; }; + int32_t num_classes() const noexcept { return num_classes_; }; + float y_scale() const noexcept { return y_scale_; }; + float x_scale() const noexcept { return x_scale_; }; + float h_scale() const noexcept { return h_scale_; }; + float w_scale() const noexcept { return w_scale_; }; + + tflite_detection_postprocess( + shape_t boxes_shape, shape_t scores_shape, shape_t anchors_shape, + shape_t output_shape_0, shape_t output_shape_1, shape_t output_shape_2, shape_t output_shape_3, + int32_t max_detections, + int32_t max_classes_per_detection, + int32_t detections_per_class, + bool use_regular_non_max_suppression, + float nms_score_threshold, + float nms_iou_threshold, + int32_t num_classes, + float y_scale, float x_scale, float h_scale, float w_scale); + +protected: + bool properties_equal(node &other) const override; + +private: + int32_t max_detections_; + int32_t max_classes_per_detection_; + int32_t detections_per_class_; + bool use_regular_non_max_suppression_; + float nms_score_threshold_; + float nms_iou_threshold_; + int32_t num_classes_; + float y_scale_; + float x_scale_; + float h_scale_; + float w_scale_; +}; +} diff --git a/include/nncase/kernels/cpu/reference/tensor_compute.h b/include/nncase/kernels/cpu/reference/tensor_compute.h index 2adeb7924d..0f3d15dee0 100644 --- a/include/nncase/kernels/cpu/reference/tensor_compute.h +++ b/include/nncase/kernels/cpu/reference/tensor_compute.h @@ -152,4 +152,11 @@ NNCASE_API result trilu(const T *input, T *output, const runtime_shape_t & template NNCASE_API result gru(const T *input, const T *w, const T *r, const T *b, T *initial_h, T *output, T *output_h, const runtime_shape_t &input_shape, const runtime_shape_t &w_shape, int mode) noexcept; +template +NNCASE_API result tflite_detection_postprocess(const T *boxes, const T *scores, const T *anchors, T *output_locations, T *output_classes, T *output_scores, T *output_num_detections, + const runtime_shape_t &boxes_shape, const runtime_shape_t &scores_shape, const runtime_shape_t &anchors_shape, + const int32_t max_detections, const int32_t max_classes_per_detection, const int32_t detections_per_class, + const bool use_regular_non_max_suppression, const float nms_score_threshold, const float nms_iou_threshold, + const int32_t num_classes, const float y_scale, const float x_scale, const float h_scale, const float w_scale) noexcept; + END_NS_NNCASE_KERNELS_CPU_REF diff --git a/include/nncase/kernels/neutral/neutral_kernels.h b/include/nncase/kernels/neutral/neutral_kernels.h index 95db642e0e..57f1f8f0f2 100644 --- a/include/nncase/kernels/neutral/neutral_kernels.h +++ b/include/nncase/kernels/neutral/neutral_kernels.h @@ -833,7 +833,6 @@ void gru(const T *CXX_RESTRICT input, const T *CXX_RESTRICT w, const T *CXX_REST auto tanh = [&](float x) { return std::tanh(x); }; - // copy input to output runtime_shape_t out_shape { (size_t)seq_length, (size_t)num_direction, (size_t)batch_size, (size_t)hidden_size }; auto x_gate_size = batch_size * input_size; @@ -968,4 +967,349 @@ void gru(const T *CXX_RESTRICT input, const T *CXX_RESTRICT w, const T *CXX_REST } } +template +void tflite_detection_postprocess(const T *CXX_RESTRICT boxes, const T *CXX_RESTRICT scores, const T *CXX_RESTRICT anchors, T *CXX_RESTRICT output_locations, T *CXX_RESTRICT output_classes, T *CXX_RESTRICT output_scores, T *CXX_RESTRICT output_num_detections, + const runtime_shape_t &boxes_shape, const runtime_shape_t &scores_shape, NNCASE_UNUSED const runtime_shape_t &anchors_shape, + const int32_t max_detections, const int32_t max_classes_per_detection, const int32_t detections_per_class, + const bool use_regular_non_max_suppression, const float nms_score_threshold, const float nms_iou_threshold, + const int32_t num_classes, const float y_scale, const float x_scale, const float h_scale, const float w_scale) +{ + struct CenterSizeEncoding + { + float y; + float x; + float h; + float w; + }; + struct BoxCornerEncoding + { + float ymin; + float xmin; + float ymax; + float xmax; + }; + struct BoxInfo + { + int index; + float score; + }; + + auto compute_iou = [&](const std::vector &box, const int &i, const int &j) { + auto &box_i = box[i]; + auto &box_j = box[j]; + const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin); + const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin); + if (area_i <= 0 || area_j <= 0) + return 0.f; + const float intersection_y_min = std::max(box_i.ymin, box_j.ymin); + const float intersection_x_min = std::max(box_i.xmin, box_j.xmin); + const float intersection_y_max = std::min(box_i.ymax, box_j.ymax); + const float intersection_x_max = std::min(box_i.xmax, box_j.xmax); + const float intersection_area = std::max(intersection_y_max - intersection_y_min, 0.0) * std::max(intersection_x_max - intersection_x_min, 0.0); + return intersection_area / (area_i + area_j - intersection_area); + }; + + const auto num_boxes = (int)anchors_shape[0]; + const auto num_classes_with_background = (int)scores_shape[2]; // num_classes + background + const auto num_detections_per_class = std::min(detections_per_class, max_detections); + int label_offset = num_classes_with_background - num_classes; + // DecodeCenterSizeBoxes: get decoded_boxes + std::vector decoded_boxes(boxes_shape[1]); + { + + CenterSizeEncoding box_center_size; + CenterSizeEncoding scale_values { y_scale, x_scale, h_scale, w_scale }; + CenterSizeEncoding anchor; + + for (int index = 0; index < num_boxes; index++) + { + const auto box_encoding_index = index * boxes_shape[2]; + box_center_size = *reinterpret_cast(boxes + box_encoding_index); + anchor = *reinterpret_cast(anchors + box_encoding_index); + + auto y_center = static_cast(static_cast(box_center_size.y) / static_cast(scale_values.y) * static_cast(anchor.h) + static_cast(anchor.y)); + auto x_center = static_cast(static_cast(box_center_size.x) / static_cast(scale_values.x) * static_cast(anchor.w) + static_cast(anchor.x)); + auto half_h = static_cast(0.5 * (std::exp(static_cast(box_center_size.h) / static_cast(scale_values.h))) * static_cast(anchor.h)); + auto half_w = static_cast(0.5 * (std::exp(static_cast(box_center_size.w) / static_cast(scale_values.w))) * static_cast(anchor.w)); + decoded_boxes[index].ymin = y_center - half_h; + decoded_boxes[index].xmin = x_center - half_w; + decoded_boxes[index].ymax = y_center + half_h; + decoded_boxes[index].xmax = x_center + half_w; + } + } + // NMS MultiClass + { + if (use_regular_non_max_suppression) + { + // NMS Regular + int sorted_indices_size = 0; + std::vector box_info_after_regular_nms(max_detections + num_detections_per_class); + std::vector num_selected(num_classes); + + // compute nms + std::vector class_scores(num_boxes); + std::vector selected; + selected.reserve(num_detections_per_class); + + for (auto col = 0; col < num_classes - 1; col++) + { + const float *scores_base = scores + col + label_offset; + for (int row = 0; row < num_boxes; row++) + { + // Get scores of boxes corresponding to all anchors for single class + class_scores[row] = *scores_base; + scores_base += num_classes_with_background; + } + // Perform non-maximal suppression on single class + selected.clear(); + + // NMS SingleClass + { + std::vector keep_indices; + std::vector keep_scores; + // select detection box score above score threshold + { + for (size_t i = 0; i < class_scores.size(); i++) + { + if (class_scores[i] >= nms_score_threshold) + { + keep_scores.emplace_back(class_scores[i]); + keep_indices.emplace_back(i); + } + } + } + + int num_scores_kept = (int)keep_scores.size(); + std::vector sorted_indices; + sorted_indices.resize(num_scores_kept); + // DecreasingArgSort + { + std::iota(sorted_indices.begin(), sorted_indices.begin() + num_scores_kept, 0); + std::stable_sort( + sorted_indices.begin(), sorted_indices.begin() + num_scores_kept, + [&keep_scores](const int i, const int j) { return keep_scores[i] > keep_scores[j]; }); + } + + const int output_size = std::min(num_scores_kept, max_detections); + selected.clear(); + int num_active_candidate = num_scores_kept; + std::vector active_box_candidate(num_scores_kept, 1); + for (int i = 0; i < num_scores_kept; ++i) + { + if (num_active_candidate == 0 || (int)selected.size() >= output_size) + break; + if (active_box_candidate[i] == 1) + { + selected.push_back(keep_indices[sorted_indices[i]]); + active_box_candidate[i] = 0; + num_active_candidate--; + } + else + { + continue; + } + for (int j = i + 1; j < num_scores_kept; ++j) + { + if (active_box_candidate[j] == 1) + { + + float iou = compute_iou( + decoded_boxes, keep_indices[sorted_indices[i]], + keep_indices[sorted_indices[j]]); + + if (iou > nms_iou_threshold) + { + active_box_candidate[j] = 0; + num_active_candidate--; + } + } + } + } + } + // end NMS SingleClass + + if (selected.empty()) + { + continue; + } + for (size_t i = 0; i < selected.size(); ++i) + { + box_info_after_regular_nms[sorted_indices_size + i].score = class_scores[selected[i]]; + box_info_after_regular_nms[sorted_indices_size + i].index = (selected[i] * num_classes_with_background + col + label_offset); + } + + // In-place merge the original boxes and new selected boxes which are both + // sorted by scores. + std::inplace_merge(box_info_after_regular_nms.begin(), box_info_after_regular_nms.begin() + sorted_indices_size, + box_info_after_regular_nms.begin() + sorted_indices_size + selected.size(), + [](const BoxInfo &a, const BoxInfo &b) { return a.score >= b.score; }); + + sorted_indices_size = std::min(sorted_indices_size + static_cast(selected.size()), max_detections); + } + // end compute nms result + + // Allocate output tensors + for (int output_box_index = 0; output_box_index < max_detections; output_box_index++) + { + if (output_box_index < sorted_indices_size) + { + const int anchor_index = floor( + box_info_after_regular_nms[output_box_index].index / num_classes_with_background); + const int class_index = box_info_after_regular_nms[output_box_index].index - anchor_index * num_classes_with_background - label_offset; + const float selected_score = box_info_after_regular_nms[output_box_index].score; + // detection_boxes + reinterpret_cast(output_locations)[output_box_index] = decoded_boxes[anchor_index]; + // detection_classes + output_classes[output_box_index] = class_index; + // detection_scores + output_scores[output_box_index] = selected_score; + } + else + { + // detection_boxes + reinterpret_cast(output_locations)[output_box_index] = { 0.0f, 0.0f, 0.0f, 0.0f }; + // detection_classes + output_classes[output_box_index] = 0.0f; + // detection_scores + output_scores[output_box_index] = 0.0f; + } + } + output_num_detections[0] = sorted_indices_size; + box_info_after_regular_nms.clear(); + } + else + { + // Fast NMS + + const int max_categories_per_anchor = max_classes_per_detection; + const int num_categories_per_anchor = std::min(max_categories_per_anchor, num_classes); + + std::vector max_scores; + max_scores.resize(num_boxes); + std::vector sorted_class_indices; + sorted_class_indices.resize(num_boxes * num_categories_per_anchor); + + for (int row = 0; row < num_boxes; row++) + { + const float *box_scores = scores + row * num_classes_with_background + label_offset; + int *class_indices = sorted_class_indices.data() + row * num_categories_per_anchor; + + // DecreasingPartialArgSort + if (num_categories_per_anchor == 1) + { + auto arg_max_vector = [&](const T *input_data, int size) { + T max_value = input_data[0]; + int max_index = 0; + for (int i = 1; i < size; ++i) + { + // const T curr_value = input_data[i]; + if (input_data[i] > max_value) + { + max_value = input_data[i]; + max_index = i; + } + } + return max_index; + }; + class_indices[0] = arg_max_vector(box_scores, num_classes); + } + else + { + std::iota(class_indices, class_indices + num_classes, 0); + std::partial_sort( + class_indices, class_indices + num_categories_per_anchor, class_indices + num_classes, + [&box_scores](const int i, const int j) { return box_scores[i] > box_scores[j]; }); + } + // end DecreasingPartialArgSort + + max_scores[row] = box_scores[class_indices[0]]; + } + std::vector selected; + // NMS SingleClass + { + std::vector keep_indices; + std::vector keep_scores; + // select detection box score above score threshold + { + for (size_t i = 0; i < max_scores.size(); i++) + { + if (max_scores[i] >= nms_score_threshold) + { + keep_scores.emplace_back(max_scores[i]); + keep_indices.emplace_back(i); + } + } + } + + int num_scores_kept = (int)keep_scores.size(); + std::vector sorted_indices; + sorted_indices.resize(num_scores_kept); + // DecreasingArgSort + { + std::iota(sorted_indices.begin(), sorted_indices.begin() + num_scores_kept, 0); + std::stable_sort( + sorted_indices.begin(), sorted_indices.begin() + num_scores_kept, + [&keep_scores](const int i, const int j) { return keep_scores[i] > keep_scores[j]; }); + } + const int output_size = std::min(num_scores_kept, max_detections); + selected.clear(); + int num_active_candidate = num_scores_kept; + std::vector active_box_candidate(num_scores_kept, 1); + for (int i = 0; i < num_scores_kept; ++i) + { + if (num_active_candidate == 0 || (int)selected.size() >= output_size) + break; + if (active_box_candidate[i] == 1) + { + selected.push_back(keep_indices[sorted_indices[i]]); + active_box_candidate[i] = 0; + num_active_candidate--; + } + else + { + continue; + } + for (int j = i + 1; j < num_scores_kept; ++j) + { + if (active_box_candidate[j] == 1) + { + + float iou = compute_iou( + decoded_boxes, keep_indices[sorted_indices[i]], + keep_indices[sorted_indices[j]]); + if (iou > nms_iou_threshold) + { + active_box_candidate[j] = 0; + num_active_candidate--; + } + } + } + } + } + // end NMS SingleClass + + // Allocate output tensors + int output_box_index = 0; + for (const auto &selected_index : selected) + { + const float *box_scores = scores + selected_index * num_classes_with_background + label_offset; + const int *class_indices = sorted_class_indices.data() + selected_index * num_categories_per_anchor; + + for (int col = 0; col < num_categories_per_anchor; ++col) + { + int box_offset = max_categories_per_anchor * output_box_index + col; + // detection_boxes + reinterpret_cast(output_locations)[box_offset] = decoded_boxes[selected_index]; + // detection_classes + output_classes[box_offset] = class_indices[col]; + // detection_scores + output_scores[box_offset] = box_scores[class_indices[col]]; + } + output_box_index++; + } + output_num_detections[0] = output_box_index; + } + } +} + } diff --git a/include/nncase/kernels/tensor_compute.h b/include/nncase/kernels/tensor_compute.h index de1e626d01..9a98c67160 100644 --- a/include/nncase/kernels/tensor_compute.h +++ b/include/nncase/kernels/tensor_compute.h @@ -155,4 +155,11 @@ NNCASE_API result trilu(const T *input, T *output, const runtime_shape_t & template NNCASE_API result gru(const T *input, const T *w, const T *r, const T *b, T *initial_h, T *output, T *output_h, const runtime_shape_t &input_shape, const runtime_shape_t &w_shape, int mode) noexcept; +template +NNCASE_API result tflite_detection_postprocess(const T *boxes, const T *scores, const T *anchors, T *output_locations, T *output_classes, T *output_scores, T *output_num_detections, + const runtime_shape_t &boxes_shape, const runtime_shape_t &scores_shape, const runtime_shape_t &anchors_shape, + const int32_t max_detections, const int32_t max_classes_per_detection, const int32_t detections_per_class, + const bool use_regular_non_max_suppression, const float nms_score_threshold, const float nms_iou_threshold, + const int32_t num_classes, const float y_scale, const float x_scale, const float h_scale, const float w_scale) noexcept; + END_NS_NNCASE_KERNELS diff --git a/include/nncase/runtime/stackvm/op_reader.h b/include/nncase/runtime/stackvm/op_reader.h index 84e9c20b80..41ad79269a 100644 --- a/include/nncase/runtime/stackvm/op_reader.h +++ b/include/nncase/runtime/stackvm/op_reader.h @@ -1,4 +1,4 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 5/31/2022 2:39:39 PM +08:00. +/* This file is generated by tools/stackvm_gen/IsaGen at 6/30/2022 4:30:43 PM +08:00. * * Copyright 2019-2021 Canaan Inc. * @@ -1301,21 +1301,6 @@ struct op_reader } }; -template <> -struct op_reader -{ - tensor_gru_op_t operator()(span_reader &reader) const - { - tensor_gru_op_t op(default_init); - op.opcode = static_cast(reader.read_unaligned()); - op.funct = static_cast(reader.read_unaligned()); - op.input_shape_src = reader.read_unaligned(); - op.w_shape_src = reader.read_unaligned(); - op.direction = reader.read_unaligned(); - return op; - } -}; - template <> struct op_reader { @@ -1722,6 +1707,47 @@ struct op_reader } }; +template <> +struct op_reader +{ + tensor_gru_op_t operator()(span_reader &reader) const + { + tensor_gru_op_t op(default_init); + op.opcode = static_cast(reader.read_unaligned()); + op.funct = static_cast(reader.read_unaligned()); + op.input_shape_src = reader.read_unaligned(); + op.w_shape_src = reader.read_unaligned(); + op.direction = reader.read_unaligned(); + return op; + } +}; + +template <> +struct op_reader +{ + tensor_tflite_detection_postprocess_op_t operator()(span_reader &reader) const + { + tensor_tflite_detection_postprocess_op_t op(default_init); + op.opcode = static_cast(reader.read_unaligned()); + op.funct = static_cast(reader.read_unaligned()); + op.box_shape_src = reader.read_unaligned(); + op.score_shape_src = reader.read_unaligned(); + op.anchor_shape_src = reader.read_unaligned(); + op.max_detections = reader.read_unaligned(); + op.max_classes_per_detection = reader.read_unaligned(); + op.detections_per_class = reader.read_unaligned(); + op.use_regular_non_max_suppression = reader.read_unaligned(); + op.nms_score_threshold = reader.read_unaligned(); + op.nms_iou_threshold = reader.read_unaligned(); + op.num_classes = reader.read_unaligned(); + op.y_scale = reader.read_unaligned(); + op.x_scale = reader.read_unaligned(); + op.h_scale = reader.read_unaligned(); + op.w_scale = reader.read_unaligned(); + return op; + } +}; + class NNCASE_API op_visitor { public: @@ -1840,7 +1866,6 @@ class NNCASE_API op_visitor virtual result visit(NNCASE_UNUSED const tensor_dequantize_op_t &op) noexcept { return ok(); } virtual result visit(NNCASE_UNUSED const tensor_gather_op_t &op) noexcept { return ok(); } virtual result visit(NNCASE_UNUSED const tensor_gather_nd_op_t &op) noexcept { return ok(); } - virtual result visit(NNCASE_UNUSED const tensor_gru_op_t &op) noexcept { return ok(); } virtual result visit(NNCASE_UNUSED const tensor_hardmax_op_t &op) noexcept { return ok(); } virtual result visit(NNCASE_UNUSED const tensor_lut1d_op_t &op) noexcept { return ok(); } virtual result visit(NNCASE_UNUSED const tensor_matmul_op_t &op) noexcept { return ok(); } @@ -1863,6 +1888,8 @@ class NNCASE_API op_visitor virtual result visit(NNCASE_UNUSED const tensor_trilu_op_t &op) noexcept { return ok(); } virtual result visit(NNCASE_UNUSED const tensor_unary_op_t &op) noexcept { return ok(); } virtual result visit(NNCASE_UNUSED const tensor_transpose_op_t &op) noexcept { return ok(); } + virtual result visit(NNCASE_UNUSED const tensor_gru_op_t &op) noexcept { return ok(); } + virtual result visit(NNCASE_UNUSED const tensor_tflite_detection_postprocess_op_t &op) noexcept { return ok(); } protected: bool interrupted_; diff --git a/include/nncase/runtime/stackvm/opcode.h b/include/nncase/runtime/stackvm/opcode.h index 4215ce9dfd..3fd0109d69 100644 --- a/include/nncase/runtime/stackvm/opcode.h +++ b/include/nncase/runtime/stackvm/opcode.h @@ -1,4 +1,4 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 5/31/2022 2:39:39 PM +08:00. +/* This file is generated by tools/stackvm_gen/IsaGen at 6/30/2022 4:30:43 PM +08:00. * * Copyright 2019-2021 Canaan Inc. * @@ -136,32 +136,33 @@ enum class tensor_function_t DEQUANTIZE = 0x000B, GATHER = 0x000C, GATHER_ND = 0x000D, - GRU = 0x000E, - HARDMAX = 0x000F, - LOGISTIC = 0x0010, - LUT1D = 0x0011, - MATMUL = 0x0012, - ONEHOT = 0x0013, - PAD = 0x0014, - QUANTIZE = 0x0015, - RANDOM_NORMAL = 0x0016, - RANDOM_UNIFORM = 0x0017, - REDUCE = 0x0018, - REDUCE_ARG = 0x0019, - REDUCE_PROD = 0x001A, - REDUCE_WINDOW2D = 0x001B, - RESIZE_IMAGE = 0x001C, - ROI_ALIGN = 0x001D, - SIGMOID = 0x001E, - SLICE = 0x001F, - SOFTMAX = 0x0020, - SPACE_TO_BATCH = 0x0021, - TAKE = 0x0022, - TERNARY = 0x0023, - TOPK = 0x0024, - TRANSPOSE = 0x0025, - TRILU = 0x0026, - UNARY = 0x0027, + HARDMAX = 0x000E, + LOGISTIC = 0x000F, + LUT1D = 0x0010, + MATMUL = 0x0011, + ONEHOT = 0x0012, + PAD = 0x0013, + QUANTIZE = 0x0014, + RANDOM_NORMAL = 0x0015, + RANDOM_UNIFORM = 0x0016, + REDUCE = 0x0017, + REDUCE_ARG = 0x0018, + REDUCE_PROD = 0x0019, + REDUCE_WINDOW2D = 0x001A, + RESIZE_IMAGE = 0x001B, + ROI_ALIGN = 0x001C, + SIGMOID = 0x001D, + SLICE = 0x001E, + SOFTMAX = 0x001F, + SPACE_TO_BATCH = 0x0020, + TAKE = 0x0021, + TERNARY = 0x0022, + TOPK = 0x0023, + TRANSPOSE = 0x0024, + TRILU = 0x0025, + UNARY = 0x0026, + GRU = 0x0027, + TFLITE_DETECTION_POSTPROCESS = 0x0028, }; // Instructions @@ -1442,21 +1443,6 @@ struct tensor_gather_nd_op_t } }; -struct tensor_gru_op_t -{ - opcode_t opcode; - tensor_function_t funct; - uint8_t input_shape_src; - uint8_t w_shape_src; - uint8_t direction; - - tensor_gru_op_t(default_init_t) noexcept { } - explicit tensor_gru_op_t(uint8_t input_shape_src, uint8_t w_shape_src, uint8_t direction) noexcept - : opcode(opcode_t::TENSOR), funct(tensor_function_t::GRU), input_shape_src(input_shape_src), w_shape_src(w_shape_src), direction(direction) - { - } -}; - struct tensor_hardmax_op_t { opcode_t opcode; @@ -1863,4 +1849,45 @@ struct tensor_transpose_op_t } }; +struct tensor_gru_op_t +{ + opcode_t opcode; + tensor_function_t funct; + uint8_t input_shape_src; + uint8_t w_shape_src; + uint8_t direction; + + tensor_gru_op_t(default_init_t) noexcept { } + explicit tensor_gru_op_t(uint8_t input_shape_src, uint8_t w_shape_src, uint8_t direction) noexcept + : opcode(opcode_t::TENSOR), funct(tensor_function_t::GRU), input_shape_src(input_shape_src), w_shape_src(w_shape_src), direction(direction) + { + } +}; + +struct tensor_tflite_detection_postprocess_op_t +{ + opcode_t opcode; + tensor_function_t funct; + uint8_t box_shape_src; + uint8_t score_shape_src; + uint8_t anchor_shape_src; + int32_t max_detections; + int32_t max_classes_per_detection; + int32_t detections_per_class; + bool use_regular_non_max_suppression; + float nms_score_threshold; + float nms_iou_threshold; + int32_t num_classes; + float y_scale; + float x_scale; + float h_scale; + float w_scale; + + tensor_tflite_detection_postprocess_op_t(default_init_t) noexcept { } + explicit tensor_tflite_detection_postprocess_op_t(uint8_t box_shape_src, uint8_t score_shape_src, uint8_t anchor_shape_src, int32_t max_detections, int32_t max_classes_per_detection, int32_t detections_per_class, bool use_regular_non_max_suppression, float nms_score_threshold, float nms_iou_threshold, int32_t num_classes, float y_scale, float x_scale, float h_scale, float w_scale) noexcept + : opcode(opcode_t::TENSOR), funct(tensor_function_t::TFLITE_DETECTION_POSTPROCESS), box_shape_src(box_shape_src), score_shape_src(score_shape_src), anchor_shape_src(anchor_shape_src), max_detections(max_detections), max_classes_per_detection(max_classes_per_detection), detections_per_class(detections_per_class), use_regular_non_max_suppression(use_regular_non_max_suppression), nms_score_threshold(nms_score_threshold), nms_iou_threshold(nms_iou_threshold), num_classes(num_classes), y_scale(y_scale), x_scale(x_scale), h_scale(h_scale), w_scale(w_scale) + { + } +}; + END_NS_NNCASE_RT_MODULE diff --git a/include/nncase/transforms/neutral/fix_output_shape.h b/include/nncase/transforms/neutral/fix_output_shape.h new file mode 100644 index 0000000000..f44e304804 --- /dev/null +++ b/include/nncase/transforms/neutral/fix_output_shape.h @@ -0,0 +1,28 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "../transform.h" + +namespace nncase::ir::transforms +{ +class NNCASE_API tflite_detection_postprocess_transform : public transform +{ +public: + void process(transform_context &context) override; + +protected: + bool on_try_match(ir::node &node, transform_context &context) override; +}; +} diff --git a/src/codegen/stackvm/CMakeLists.txt b/src/codegen/stackvm/CMakeLists.txt index a3b83fbc67..805f58230d 100644 --- a/src/codegen/stackvm/CMakeLists.txt +++ b/src/codegen/stackvm/CMakeLists.txt @@ -1,42 +1,43 @@ -cmake_minimum_required (VERSION 3.8) +cmake_minimum_required(VERSION 3.8) set(SRCS module_builder.cpp - op_writer.cpp - ops/batch_to_space.cpp - ops/binary.cpp - ops/broadcast.cpp - ops/call.cpp - ops/compare.cpp - ops/conv2d.cpp - ops/convert.cpp - ops/copy.cpp - ops/cumsum.cpp - ops/dequantize.cpp - ops/gather.cpp - ops/gather_nd.cpp - ops/gru.cpp - ops/hardmax.cpp - ops/matmul.cpp - ops/onehot.cpp - ops/pad.cpp - ops/quantize.cpp - ops/random_normal.cpp - ops/random_uniform.cpp - ops/reduce.cpp - ops/reduce_arg.cpp - ops/reduce_prod.cpp - ops/reduce_window2d.cpp - ops/resize_image.cpp - ops/roi_align.cpp - ops/slice.cpp - ops/sigmoid.cpp - ops/softmax.cpp - ops/table_lookup1d.cpp - ops/ternary.cpp - ops/topk.cpp - ops/transpose.cpp - ops/trilu.cpp - ops/unary.cpp) + op_writer.cpp + ops/batch_to_space.cpp + ops/binary.cpp + ops/broadcast.cpp + ops/call.cpp + ops/compare.cpp + ops/conv2d.cpp + ops/convert.cpp + ops/copy.cpp + ops/cumsum.cpp + ops/dequantize.cpp + ops/gather.cpp + ops/gather_nd.cpp + ops/gru.cpp + ops/hardmax.cpp + ops/matmul.cpp + ops/onehot.cpp + ops/pad.cpp + ops/quantize.cpp + ops/random_normal.cpp + ops/random_uniform.cpp + ops/reduce.cpp + ops/reduce_arg.cpp + ops/reduce_prod.cpp + ops/reduce_window2d.cpp + ops/resize_image.cpp + ops/roi_align.cpp + ops/slice.cpp + ops/sigmoid.cpp + ops/softmax.cpp + ops/table_lookup1d.cpp + ops/ternary.cpp + ops/topk.cpp + ops/transpose.cpp + ops/trilu.cpp + ops/tflite_detection_postprocess.cpp + ops/unary.cpp) add_library(codegen_stackvm OBJECT ${SRCS}) target_link_libraries(codegen_stackvm PUBLIC ir schedule) diff --git a/src/codegen/stackvm/module_builder.h b/src/codegen/stackvm/module_builder.h index b51d24b7e6..f38b14a7ff 100644 --- a/src/codegen/stackvm/module_builder.h +++ b/src/codegen/stackvm/module_builder.h @@ -46,6 +46,7 @@ #include #include #include +#include #include #include #include diff --git a/src/codegen/stackvm/op_writer.cpp b/src/codegen/stackvm/op_writer.cpp index 45cb72d90d..317921d57b 100644 --- a/src/codegen/stackvm/op_writer.cpp +++ b/src/codegen/stackvm/op_writer.cpp @@ -1,4 +1,4 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 5/31/2022 2:39:39 PM +08:00. +/* This file is generated by tools/stackvm_gen/IsaGen at 6/30/2022 4:30:43 PM +08:00. * * Copyright 2019-2021 Canaan Inc. * @@ -558,11 +558,6 @@ void op_builder::tensor_gather_nd_(datatype_t datatype, uint8_t rshape_src, uint op_writer()(tensor_gather_nd_op_t(datatype, rshape_src, rshape_dest, rstride_src, rstride_dest, rshape_indices, batch_dims), writer_); } -void op_builder::tensor_gru_(uint8_t input_shape_src, uint8_t w_shape_src, uint8_t direction) -{ - op_writer()(tensor_gru_op_t(input_shape_src, w_shape_src, direction), writer_); -} - void op_builder::tensor_hardmax_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, int32_t axis) { op_writer()(tensor_hardmax_op_t(datatype, rshape_src, rstride_src, axis), writer_); @@ -672,3 +667,13 @@ void op_builder::tensor_transpose_(datatype_t datatype, uint8_t rshape_src, uint { op_writer()(tensor_transpose_op_t(datatype, rshape_src, rstride_src, rstride_dest, rshape_perm), writer_); } + +void op_builder::tensor_gru_(uint8_t input_shape_src, uint8_t w_shape_src, uint8_t direction) +{ + op_writer()(tensor_gru_op_t(input_shape_src, w_shape_src, direction), writer_); +} + +void op_builder::tensor_tflite_detection_postprocess_(uint8_t box_shape_src, uint8_t score_shape_src, uint8_t anchor_shape_src, int32_t max_detections, int32_t max_classes_per_detection, int32_t detections_per_class, bool use_regular_non_max_suppression, float nms_score_threshold, float nms_iou_threshold, int32_t num_classes, float y_scale, float x_scale, float h_scale, float w_scale) +{ + op_writer()(tensor_tflite_detection_postprocess_op_t(box_shape_src, score_shape_src, anchor_shape_src, max_detections, max_classes_per_detection, detections_per_class, use_regular_non_max_suppression, nms_score_threshold, nms_iou_threshold, num_classes, y_scale, x_scale, h_scale, w_scale), writer_); +} diff --git a/src/codegen/stackvm/ops.def b/src/codegen/stackvm/ops.def index e2046f853f..88ad7bd151 100644 --- a/src/codegen/stackvm/ops.def +++ b/src/codegen/stackvm/ops.def @@ -32,4 +32,5 @@ DEFINE_OP(ternary) DEFINE_OP(topk) DEFINE_OP(transpose) DEFINE_OP(trilu) +DEFINE_OP(tflite_detection_postprocess) DEFINE_OP(unary) \ No newline at end of file diff --git a/src/codegen/stackvm/ops/tflite_detection_postprocess.cpp b/src/codegen/stackvm/ops/tflite_detection_postprocess.cpp new file mode 100644 index 0000000000..5da2d2afb9 --- /dev/null +++ b/src/codegen/stackvm/ops/tflite_detection_postprocess.cpp @@ -0,0 +1,47 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "../module_builder.h" + +using namespace nncase; +using namespace nncase::codegen; +using namespace nncase::codegen::stackvm; +using namespace nncase::ir; + +void stackvm_module_builder::emit(tflite_detection_postprocess &node, stackvm_op_builder &builder) +{ + auto &box = allocation(node.boxes()); + auto &score = allocation(node.scores()); + auto &anchor = allocation(node.anchors()); + auto &output_locations = allocation(node.output_locations()); + auto &output_classes = allocation(node.output_classes()); + auto &output_scores = allocation(node.output_scores()); + auto &output_num_detections = allocation(node.output_num_detections()); + + builder.lea_buffer(box); + builder.lea_buffer(score); + builder.lea_buffer(anchor); + builder.lea_buffer(output_locations); + builder.lea_buffer(output_classes); + builder.lea_buffer(output_scores); + builder.lea_buffer(output_num_detections); + + builder.stshape(0, box.shape); + builder.stshape(1, score.shape); + builder.stshape(2, anchor.shape); + + builder.tensor_tflite_detection_postprocess_(0, 1, 2, node.max_detections(), node.max_classes_per_detection(), node.detections_per_class(), + node.use_regular_non_max_suppression(), node.nms_score_threshold(), node.nms_iou_threshold(), + node.num_classes(), node.y_scale(), node.x_scale(), node.h_scale(), node.w_scale()); +} diff --git a/src/evaluator/ops/neutral/neutral_ops.cpp b/src/evaluator/ops/neutral/neutral_ops.cpp index 762df27e1f..90b0f7a116 100644 --- a/src/evaluator/ops/neutral/neutral_ops.cpp +++ b/src/evaluator/ops/neutral/neutral_ops.cpp @@ -49,6 +49,7 @@ #include #include #include +#include #include #include #include @@ -819,6 +820,22 @@ void register_neutral_evaluators() B.buffer().as_span().data(), initial_h.buffer().as_span().data(), output.buffer().as_span().data(), output_h.buffer().as_span().data(), input.shape(), W.shape(), rnode.direction()) .unwrap_or_throw(); }); + + register_evaluator(op_tflite_detection_postprocess, [](ir::node &node, function_evaluate_context &context) { + auto &rnode = static_cast(node); + auto box = context.memory_at(rnode.boxes()); + auto score = context.memory_at(rnode.scores()); + auto anchor = context.memory_at(rnode.anchors()); + auto output_locations = context.memory_at(rnode.output_locations()); + auto output_classes = context.memory_at(rnode.output_classes()); + auto output_scores = context.memory_at(rnode.output_scores()); + auto output_num_detections = context.memory_at(rnode.output_num_detections()); + kernels::tflite_detection_postprocess(box.buffer().as_span().data(), score.buffer().as_span().data(), anchor.buffer().as_span().data(), + output_locations.buffer().as_span().data(), output_classes.buffer().as_span().data(), output_scores.buffer().as_span().data(), output_num_detections.buffer().as_span().data(), + box.shape(), score.shape(), anchor.shape(), rnode.max_detections(), rnode.max_classes_per_detection(), + rnode.detections_per_class(), rnode.use_regular_non_max_suppression(), rnode.nms_score_threshold(), rnode.nms_iou_threshold(), + rnode.num_classes(), rnode.y_scale(), rnode.x_scale(), rnode.h_scale(), rnode.w_scale()) + .unwrap_or_throw(); }); } } diff --git a/src/importer/tflite/ops/custom.cpp b/src/importer/tflite/ops/custom.cpp index 2637db2557..3ef635d077 100644 --- a/src/importer/tflite/ops/custom.cpp +++ b/src/importer/tflite/ops/custom.cpp @@ -15,6 +15,7 @@ #include "../tflite_importer.h" #include #include +#include using namespace nncase; using namespace nncase::importer; @@ -43,6 +44,53 @@ DEFINE_TFLITE_LOWER(CUSTOM) node->name(output.name()->string_view()); link_output_tensor(op.outputs()->Get(0), &node->output()); } + else if (custom_code == "TFLite_Detection_PostProcess") + { + auto &input_decoded_boxes = get_tensor(op.inputs(), 0); + auto &input_scores = get_tensor(op.inputs(), 1); + auto &input_anchors = get_tensor(op.inputs(), 2); + + // get_shape(output_x.shape()): get error shape, ignore it in this step. fix it in independent transform + auto &output_locations = get_tensor(op.outputs(), 0); //detection_boxes (1, num_detected_boxes, 4) + auto &output_classes = get_tensor(op.outputs(), 1); //detection_classes (1, num_detected_boxes) + auto &output_scores = get_tensor(op.outputs(), 2); //detection_scores (1, num_detected_boxes) + auto &output_num_detections = get_tensor(op.outputs(), 3); //num_detections (1) + + auto custom_options = op.custom_options(); + + const auto &m = flexbuffers::GetRoot(custom_options->data(), custom_options->size()).AsMap(); + auto max_detections = m["max_detections"].AsInt32(); + auto max_classes_per_detection = m["max_classes_per_detection"].AsInt32(); + + int32_t detections_per_class = 100; + if (!m["detections_per_class"].IsNull()) + detections_per_class = m["detections_per_class"].AsInt32(); + + bool use_regular_non_max_suppression = false; + if (!m["use_regular_nms"].IsNull()) + use_regular_non_max_suppression = m["use_regular_nms"].AsBool(); + + auto non_max_suppression_score_threshold = m["nms_score_threshold"].AsFloat(); + auto intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat(); + auto num_classes = m["num_classes"].AsInt32(); + auto y = m["y_scale"].AsFloat(); + auto x = m["x_scale"].AsFloat(); + auto h = m["h_scale"].AsFloat(); + auto w = m["w_scale"].AsFloat(); + + auto node = graph_.emplace(get_shape(input_decoded_boxes.shape()), get_shape(input_scores.shape()), get_shape(input_anchors.shape()), + get_shape(output_locations.shape()), get_shape(output_classes.shape()), get_shape(output_scores.shape()), get_shape(output_num_detections.shape()), + max_detections, max_classes_per_detection, detections_per_class, use_regular_non_max_suppression, non_max_suppression_score_threshold, + intersection_over_union_threshold, num_classes, y, x, h, w); + + link_input_tensor(&node->boxes(), op.inputs()->Get(0)); + link_input_tensor(&node->scores(), op.inputs()->Get(1)); + link_input_tensor(&node->anchors(), op.inputs()->Get(2)); + link_output_tensor(op.outputs()->Get(0), &node->output_locations()); + link_output_tensor(op.outputs()->Get(1), &node->output_classes()); + link_output_tensor(op.outputs()->Get(2), &node->output_scores()); + link_output_tensor(op.outputs()->Get(3), &node->output_num_detections()); + } else { throw std::runtime_error(std::string("Unsupported tflite CUSTOM code: ") + custom_code); diff --git a/src/ir/ops/CMakeLists.txt b/src/ir/ops/CMakeLists.txt index 1d0f9c9d82..37c66b60a7 100644 --- a/src/ir/ops/CMakeLists.txt +++ b/src/ir/ops/CMakeLists.txt @@ -46,4 +46,5 @@ target_sources(ir PRIVATE ternary.cpp topk.cpp trilu.cpp - gru.cpp) + gru.cpp + tflite_detection_postprocess.cpp) diff --git a/src/ir/ops/tflite_detection_postprocess.cpp b/src/ir/ops/tflite_detection_postprocess.cpp new file mode 100644 index 0000000000..78b91ff94f --- /dev/null +++ b/src/ir/ops/tflite_detection_postprocess.cpp @@ -0,0 +1,57 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +using namespace nncase; +using namespace nncase::ir; + +tflite_detection_postprocess::tflite_detection_postprocess( + shape_t boxes_shape, shape_t scores_shape, shape_t anchors_shape, + shape_t output_shape_0, shape_t output_shape_1, shape_t output_shape_2, shape_t output_shape_3, + int32_t max_detections, + int32_t max_classes_per_detection, + int32_t detections_per_class, + bool use_regular_non_max_suppression, + float nms_score_threshold, + float nms_iou_threshold, + int32_t num_classes, + float y_scale, + float x_scale, + float h_scale, + float w_scale) + : max_detections_(max_detections), max_classes_per_detection_(max_classes_per_detection), detections_per_class_(detections_per_class), use_regular_non_max_suppression_(use_regular_non_max_suppression), nms_score_threshold_(nms_score_threshold), nms_iou_threshold_(nms_iou_threshold), num_classes_(num_classes), y_scale_(y_scale), x_scale_(x_scale), h_scale_(h_scale), w_scale_(w_scale) +{ + add_input("boxes", dt_float32, boxes_shape); + add_input("scores", dt_float32, scores_shape); + add_input("anchors", dt_float32, anchors_shape); + add_output("output_locations", dt_float32, output_shape_0); + add_output("output_classes", dt_float32, output_shape_1); + add_output("output_scores", dt_float32, output_shape_2); + add_output("output_num_detections", dt_float32, output_shape_3); +} + +bool tflite_detection_postprocess::properties_equal(node &other) const +{ + auto &r = static_cast(other); + return max_detections() == r.max_detections() + && max_classes_per_detection() == r.max_classes_per_detection() + && detections_per_class() == r.detections_per_class() + && use_regular_non_max_suppression() == r.use_regular_non_max_suppression() + && nms_score_threshold() == r.nms_score_threshold() + && nms_iou_threshold() == r.nms_iou_threshold() && num_classes() == r.num_classes() + && y_scale() == r.y_scale() && x_scale() == r.x_scale() && h_scale() == r.h_scale() && w_scale() == r.w_scale(); +} diff --git a/src/kernels/cpu/reference/CMakeLists.txt b/src/kernels/cpu/reference/CMakeLists.txt index 697797a49d..9ea09ae9d6 100644 --- a/src/kernels/cpu/reference/CMakeLists.txt +++ b/src/kernels/cpu/reference/CMakeLists.txt @@ -34,5 +34,6 @@ set(SRCS batch_to_space.cpp topk.cpp transpose.cpp trilu.cpp + tflite_detection_postprocess.cpp unary.cpp) target_sources(kernels PRIVATE ${SRCS}) diff --git a/src/kernels/cpu/reference/tflite_detection_postprocess.cpp b/src/kernels/cpu/reference/tflite_detection_postprocess.cpp new file mode 100644 index 0000000000..05251d1d6a --- /dev/null +++ b/src/kernels/cpu/reference/tflite_detection_postprocess.cpp @@ -0,0 +1,376 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +using namespace nncase; +using namespace nncase::runtime; +using namespace nncase::kernels; +using namespace nncase::kernels::cpu; +using namespace nncase::kernels::cpu::reference; + +template result reference::tflite_detection_postprocess(const float *boxes, const float *scores, const float *anchors, float *output_locations, float *output_classes, float *output_scores, float *output_num_detections, + const runtime_shape_t &boxes_shape, const runtime_shape_t &scores_shape, const runtime_shape_t &anchors_shape, + const int32_t max_detections, const int32_t max_classes_per_detection, const int32_t detections_per_class, + const bool use_regular_non_max_suppression, const float nms_score_threshold, const float nms_iou_threshold, + const int32_t num_classes, const float y_scale, const float x_scale, const float h_scale, const float w_scale) noexcept; + +template +result reference::tflite_detection_postprocess(const T *boxes, const T *scores, const T *anchors, T *output_locations, T *output_classes, T *output_scores, T *output_num_detections, + const runtime_shape_t &boxes_shape, const runtime_shape_t &scores_shape, const runtime_shape_t &anchors_shape, + const int32_t max_detections, const int32_t max_classes_per_detection, const int32_t detections_per_class, + const bool use_regular_non_max_suppression, const float nms_score_threshold, const float nms_iou_threshold, + const int32_t num_classes, const float y_scale, const float x_scale, const float h_scale, const float w_scale) noexcept +{ + struct CenterSizeEncoding + { + float y; + float x; + float h; + float w; + }; + struct BoxCornerEncoding + { + float ymin; + float xmin; + float ymax; + float xmax; + }; + struct BoxInfo + { + int index; + float score; + }; + + auto compute_iou = [&](const std::vector &box, const int &i, const int &j) { + auto &box_i = box[i]; + auto &box_j = box[j]; + const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin); + const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin); + if (area_i <= 0 || area_j <= 0) + return 0.f; + const float intersection_y_min = std::max(box_i.ymin, box_j.ymin); + const float intersection_x_min = std::max(box_i.xmin, box_j.xmin); + const float intersection_y_max = std::min(box_i.ymax, box_j.ymax); + const float intersection_x_max = std::min(box_i.xmax, box_j.xmax); + const float intersection_area = std::max(intersection_y_max - intersection_y_min, 0.0) * std::max(intersection_x_max - intersection_x_min, 0.0); + return intersection_area / (area_i + area_j - intersection_area); + }; + + const auto num_boxes = (int)anchors_shape[0]; + const auto num_classes_with_background = (int)scores_shape[2]; // num_classes + background + const auto num_detections_per_class = std::min(detections_per_class, max_detections); + int label_offset = num_classes_with_background - num_classes; + // DecodeCenterSizeBoxes: get decoded_boxes + std::vector decoded_boxes(boxes_shape[1]); + { + CenterSizeEncoding box_center_size; + CenterSizeEncoding scale_values { y_scale, x_scale, h_scale, w_scale }; + CenterSizeEncoding anchor; + + for (int index = 0; index < num_boxes; index++) + { + const auto box_encoding_index = index * boxes_shape[2]; + box_center_size = *reinterpret_cast(boxes + box_encoding_index); + anchor = *reinterpret_cast(anchors + box_encoding_index); + + auto y_center = static_cast(static_cast(box_center_size.y) / static_cast(scale_values.y) * static_cast(anchor.h) + static_cast(anchor.y)); + auto x_center = static_cast(static_cast(box_center_size.x) / static_cast(scale_values.x) * static_cast(anchor.w) + static_cast(anchor.x)); + auto half_h = static_cast(0.5 * (std::exp(static_cast(box_center_size.h) / static_cast(scale_values.h))) * static_cast(anchor.h)); + auto half_w = static_cast(0.5 * (std::exp(static_cast(box_center_size.w) / static_cast(scale_values.w))) * static_cast(anchor.w)); + decoded_boxes[index].ymin = y_center - half_h; + decoded_boxes[index].xmin = x_center - half_w; + decoded_boxes[index].ymax = y_center + half_h; + decoded_boxes[index].xmax = x_center + half_w; + } + } + // NMS MultiClass + { + if (use_regular_non_max_suppression) + { + // NMS Regular + int sorted_indices_size = 0; + std::vector box_info_after_regular_nms(max_detections + num_detections_per_class); + std::vector num_selected(num_classes); + + // compute nms + std::vector class_scores(num_boxes); + std::vector selected; + selected.reserve(num_detections_per_class); + + for (auto col = 0; col < num_classes - 1; col++) + { + const float *scores_base = scores + col + label_offset; + for (int row = 0; row < num_boxes; row++) + { + // Get scores of boxes corresponding to all anchors for single class + class_scores[row] = *scores_base; + scores_base += num_classes_with_background; + } + // Perform non-maximal suppression on single class + selected.clear(); + + // NMS SingleClass + { + std::vector keep_indices; + std::vector keep_scores; + // select detection box score above score threshold + { + for (size_t i = 0; i < class_scores.size(); i++) + { + if (class_scores[i] >= nms_score_threshold) + { + keep_scores.emplace_back(class_scores[i]); + keep_indices.emplace_back(i); + } + } + } + + int num_scores_kept = (int)keep_scores.size(); + std::vector sorted_indices; + sorted_indices.resize(num_scores_kept); + // DecreasingArgSort + { + std::iota(sorted_indices.begin(), sorted_indices.begin() + num_scores_kept, 0); + std::stable_sort( + sorted_indices.begin(), sorted_indices.begin() + num_scores_kept, + [&keep_scores](const int i, const int j) { return keep_scores[i] > keep_scores[j]; }); + } + + const int output_size = std::min(num_scores_kept, max_detections); + selected.clear(); + int num_active_candidate = num_scores_kept; + std::vector active_box_candidate(num_scores_kept, 1); + for (int i = 0; i < num_scores_kept; ++i) + { + if (num_active_candidate == 0 || (int)selected.size() >= output_size) + break; + if (active_box_candidate[i] == 1) + { + selected.push_back(keep_indices[sorted_indices[i]]); + active_box_candidate[i] = 0; + num_active_candidate--; + } + else + { + continue; + } + for (int j = i + 1; j < num_scores_kept; ++j) + { + if (active_box_candidate[j] == 1) + { + + float iou = compute_iou( + decoded_boxes, keep_indices[sorted_indices[i]], + keep_indices[sorted_indices[j]]); + + if (iou > nms_iou_threshold) + { + active_box_candidate[j] = 0; + num_active_candidate--; + } + } + } + } + } + // end NMS SingleClass + + if (selected.empty()) + { + continue; + } + for (size_t i = 0; i < selected.size(); ++i) + { + box_info_after_regular_nms[sorted_indices_size + i].score = class_scores[selected[i]]; + box_info_after_regular_nms[sorted_indices_size + i].index = (selected[i] * num_classes_with_background + col + label_offset); + } + + // In-place merge the original boxes and new selected boxes which are both + // sorted by scores. + std::inplace_merge(box_info_after_regular_nms.begin(), box_info_after_regular_nms.begin() + sorted_indices_size, + box_info_after_regular_nms.begin() + sorted_indices_size + selected.size(), + [](const BoxInfo &a, const BoxInfo &b) { return a.score >= b.score; }); + + sorted_indices_size = std::min(sorted_indices_size + static_cast(selected.size()), max_detections); + } + // end compute nms result + + // Allocate output tensors + for (int output_box_index = 0; output_box_index < max_detections; output_box_index++) + { + if (output_box_index < sorted_indices_size) + { + const int anchor_index = floor( + box_info_after_regular_nms[output_box_index].index / num_classes_with_background); + const int class_index = box_info_after_regular_nms[output_box_index].index - anchor_index * num_classes_with_background - label_offset; + const float selected_score = box_info_after_regular_nms[output_box_index].score; + // detection_boxes + reinterpret_cast(output_locations)[output_box_index] = decoded_boxes[anchor_index]; + // detection_classes + output_classes[output_box_index] = class_index; + // detection_scores + output_scores[output_box_index] = selected_score; + } + else + { + // detection_boxes + reinterpret_cast(output_locations)[output_box_index] = { 0.0f, 0.0f, 0.0f, 0.0f }; + // detection_classes + output_classes[output_box_index] = 0.0f; + // detection_scores + output_scores[output_box_index] = 0.0f; + } + } + output_num_detections[0] = sorted_indices_size; + box_info_after_regular_nms.clear(); + } + else + { + // Fast NMS + + const int max_categories_per_anchor = max_classes_per_detection; + const int num_categories_per_anchor = std::min(max_categories_per_anchor, num_classes); + + std::vector max_scores; + max_scores.resize(num_boxes); + std::vector sorted_class_indices; + sorted_class_indices.resize(num_boxes * num_categories_per_anchor); + + for (int row = 0; row < num_boxes; row++) + { + const float *box_scores = scores + row * num_classes_with_background + label_offset; + int *class_indices = sorted_class_indices.data() + row * num_categories_per_anchor; + + // DecreasingPartialArgSort + if (num_categories_per_anchor == 1) + { + auto arg_max_vector = [&](const T *input_data, int size) { + T max_value = input_data[0]; + int max_index = 0; + for (int i = 1; i < size; ++i) + { + // const T curr_value = input_data[i]; + if (input_data[i] > max_value) + { + max_value = input_data[i]; + max_index = i; + } + } + return max_index; + }; + class_indices[0] = arg_max_vector(box_scores, num_classes); + } + else + { + std::iota(class_indices, class_indices + num_classes, 0); + std::partial_sort( + class_indices, class_indices + num_categories_per_anchor, class_indices + num_classes, + [&box_scores](const int i, const int j) { return box_scores[i] > box_scores[j]; }); + } + // end DecreasingPartialArgSort + + max_scores[row] = box_scores[class_indices[0]]; + } + std::vector selected; + // NMS SingleClass + { + std::vector keep_indices; + std::vector keep_scores; + // select detection box score above score threshold + { + for (size_t i = 0; i < max_scores.size(); i++) + { + if (max_scores[i] >= nms_score_threshold) + { + keep_scores.emplace_back(max_scores[i]); + keep_indices.emplace_back(i); + } + } + } + + int num_scores_kept = (int)keep_scores.size(); + std::vector sorted_indices; + sorted_indices.resize(num_scores_kept); + // DecreasingArgSort + { + std::iota(sorted_indices.begin(), sorted_indices.begin() + num_scores_kept, 0); + std::stable_sort( + sorted_indices.begin(), sorted_indices.begin() + num_scores_kept, + [&keep_scores](const int i, const int j) { return keep_scores[i] > keep_scores[j]; }); + } + const int output_size = std::min(num_scores_kept, max_detections); + selected.clear(); + int num_active_candidate = num_scores_kept; + std::vector active_box_candidate(num_scores_kept, 1); + for (int i = 0; i < num_scores_kept; ++i) + { + if (num_active_candidate == 0 || (int)selected.size() >= output_size) + break; + if (active_box_candidate[i] == 1) + { + selected.push_back(keep_indices[sorted_indices[i]]); + active_box_candidate[i] = 0; + num_active_candidate--; + } + else + { + continue; + } + for (int j = i + 1; j < num_scores_kept; ++j) + { + if (active_box_candidate[j] == 1) + { + + float iou = compute_iou( + decoded_boxes, keep_indices[sorted_indices[i]], + keep_indices[sorted_indices[j]]); + if (iou > nms_iou_threshold) + { + active_box_candidate[j] = 0; + num_active_candidate--; + } + } + } + } + } + // end NMS SingleClass + + // Allocate output tensors + int output_box_index = 0; + for (const auto &selected_index : selected) + { + const float *box_scores = scores + selected_index * num_classes_with_background + label_offset; + const int *class_indices = sorted_class_indices.data() + selected_index * num_categories_per_anchor; + + for (int col = 0; col < num_categories_per_anchor; ++col) + { + int box_offset = max_categories_per_anchor * output_box_index + col; + // detection_boxes + reinterpret_cast(output_locations)[box_offset] = decoded_boxes[selected_index]; + // detection_classes + output_classes[box_offset] = class_indices[col]; + // detection_scores + output_scores[box_offset] = box_scores[class_indices[col]]; + } + output_box_index++; + } + output_num_detections[0] = output_box_index; + } + } + + return ok(); +} diff --git a/src/kernels/tensor_compute.cpp b/src/kernels/tensor_compute.cpp index 7d83eacfce..ac8d5ea702 100644 --- a/src/kernels/tensor_compute.cpp +++ b/src/kernels/tensor_compute.cpp @@ -474,4 +474,24 @@ template result kernels::gru(const T *input, const T *w, const T *r, const T *b, T *initial_h, T *output, T *output_h, const runtime_shape_t &input_shape, const runtime_shape_t &w_shape, int mode) noexcept { return cpu::reference::gru(input, w, r, b, initial_h, output, output_h, input_shape, w_shape, mode); +} + +template result kernels::tflite_detection_postprocess(const float *boxes, const float *scores, const float *anchors, float *output_locations, float *output_classes, float *output_scores, float *output_num_detections, + const runtime_shape_t &boxes_shape, const runtime_shape_t &scores_shape, const runtime_shape_t &anchors_shape, + const int32_t max_detections, const int32_t max_classes_per_detection, const int32_t detections_per_class, + const bool use_regular_non_max_suppression, const float nms_score_threshold, const float nms_iou_threshold, + const int32_t num_classes, const float y_scale, const float x_scale, const float h_scale, const float w_scale) noexcept; + +template +result kernels::tflite_detection_postprocess(const T *boxes, const T *scores, const T *anchors, T *output_locations, T *output_classes, T *output_scores, T *output_num_detections, + const runtime_shape_t &boxes_shape, const runtime_shape_t &scores_shape, const runtime_shape_t &anchors_shape, + const int32_t max_detections, const int32_t max_classes_per_detection, const int32_t detections_per_class, + const bool use_regular_non_max_suppression, const float nms_score_threshold, const float nms_iou_threshold, + const int32_t num_classes, const float y_scale, const float x_scale, const float h_scale, const float w_scale) noexcept +{ + return cpu::reference::tflite_detection_postprocess(boxes, scores, anchors, output_locations, output_classes, output_scores, output_num_detections, + boxes_shape, scores_shape, anchors_shape, + max_detections, max_classes_per_detection, detections_per_class, + use_regular_non_max_suppression, nms_score_threshold, nms_iou_threshold, + num_classes, y_scale, x_scale, h_scale, w_scale); } \ No newline at end of file diff --git a/src/runtime/stackvm/CMakeLists.txt b/src/runtime/stackvm/CMakeLists.txt index 4c69ca65c5..1b2866f4f3 100644 --- a/src/runtime/stackvm/CMakeLists.txt +++ b/src/runtime/stackvm/CMakeLists.txt @@ -1,49 +1,50 @@ -cmake_minimum_required (VERSION 3.13) +cmake_minimum_required(VERSION 3.13) set(SRCS runtime_module.cpp - runtime_function.cpp - op_reader.cpp - evaluate_stack.cpp - ops/control.cpp - ops/loadstore.cpp - ops/stack.cpp - ops/scalar.cpp - ops/conversion.cpp - ops/tensor.batch_to_space.cpp - ops/tensor.binary.cpp - ops/tensor.broadcast.cpp - ops/tensor.call.cpp - ops/tensor.compare.cpp - ops/tensor.conv2d.cpp - ops/tensor.convert.cpp - ops/tensor.copy.cpp - ops/tensor.cumsum.cpp - ops/tensor.dequantize.cpp - ops/tensor.gather.cpp - ops/tensor.gather_nd.cpp - ops/tensor.gru.cpp - ops/tensor.hardmax.cpp - ops/tensor.lut1d.cpp - ops/tensor.matmul.cpp - ops/tensor.onehot.cpp - ops/tensor.pad.cpp - ops/tensor.quantize.cpp - ops/tensor.random_normal.cpp - ops/tensor.random_uniform.cpp - ops/tensor.reduce.cpp - ops/tensor.reduce_arg.cpp - ops/tensor.reduce_prod.cpp - ops/tensor.reduce_window2d.cpp - ops/tensor.resize_image.cpp - ops/tensor.roi_align.cpp - ops/tensor.sigmoid.cpp - ops/tensor.slice.cpp - ops/tensor.softmax.cpp - ops/tersor.ternary.cpp - ops/tensor.topk.cpp - ops/tensor.transpose.cpp - ops/tensor.trilu.cpp - ops/tensor.unary.cpp) + runtime_function.cpp + op_reader.cpp + evaluate_stack.cpp + ops/control.cpp + ops/loadstore.cpp + ops/stack.cpp + ops/scalar.cpp + ops/conversion.cpp + ops/tensor.batch_to_space.cpp + ops/tensor.binary.cpp + ops/tensor.broadcast.cpp + ops/tensor.call.cpp + ops/tensor.compare.cpp + ops/tensor.conv2d.cpp + ops/tensor.convert.cpp + ops/tensor.copy.cpp + ops/tensor.cumsum.cpp + ops/tensor.dequantize.cpp + ops/tensor.gather.cpp + ops/tensor.gather_nd.cpp + ops/tensor.gru.cpp + ops/tensor.hardmax.cpp + ops/tensor.lut1d.cpp + ops/tensor.matmul.cpp + ops/tensor.onehot.cpp + ops/tensor.pad.cpp + ops/tensor.quantize.cpp + ops/tensor.random_normal.cpp + ops/tensor.random_uniform.cpp + ops/tensor.reduce.cpp + ops/tensor.reduce_arg.cpp + ops/tensor.reduce_prod.cpp + ops/tensor.reduce_window2d.cpp + ops/tensor.resize_image.cpp + ops/tensor.roi_align.cpp + ops/tensor.sigmoid.cpp + ops/tensor.slice.cpp + ops/tensor.softmax.cpp + ops/tersor.ternary.cpp + ops/tensor.topk.cpp + ops/tensor.transpose.cpp + ops/tensor.trilu.cpp + ops/tensor.tflite_detection_postprocess.cpp + ops/tensor.unary.cpp) if (BUILDING_RUNTIME) add_library(runtime_stackvm OBJECT ${SRCS}) @@ -51,9 +52,9 @@ if (BUILDING_RUNTIME) target_link_libraries(runtime_stackvm PRIVATE kernels) set_property(TARGET runtime_stackvm PROPERTY POSITION_INDEPENDENT_CODE ON) install(TARGETS runtime_stackvm EXPORT nncaseruntimeTargets) -else() +else () add_library(simulator_stackvm OBJECT ${SRCS}) target_link_libraries(simulator_stackvm PUBLIC simulator) target_link_libraries(simulator_stackvm PRIVATE kernels) set_property(TARGET simulator_stackvm PROPERTY POSITION_INDEPENDENT_CODE ON) -endif() +endif () diff --git a/src/runtime/stackvm/op_reader.cpp b/src/runtime/stackvm/op_reader.cpp index 3c7aba23be..c9426627c0 100644 --- a/src/runtime/stackvm/op_reader.cpp +++ b/src/runtime/stackvm/op_reader.cpp @@ -1,4 +1,4 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 6/9/2022 4:13:50 PM +08:00. +/* This file is generated by tools/stackvm_gen/IsaGen at 6/30/2022 4:30:43 PM +08:00. * * Copyright 2019-2021 Canaan Inc. * @@ -113,13 +113,6 @@ result op_visitor::next() noexcept #endif return visit(op_reader()(reader_)); } - case tensor_function_t::GRU: - { -#if defined ENABLE_OP_PROFILE - op_profile st("tensor_gru"); -#endif - return visit(op_reader()(reader_)); - } case tensor_function_t::HARDMAX: { #if defined ENABLE_OP_PROFILE @@ -274,6 +267,20 @@ result op_visitor::next() noexcept #endif return visit(op_reader()(reader_)); } + case tensor_function_t::GRU: + { +#if defined ENABLE_OP_PROFILE + op_profile st("tensor_gru"); +#endif + return visit(op_reader()(reader_)); + } + case tensor_function_t::TFLITE_DETECTION_POSTPROCESS: + { +#if defined ENABLE_OP_PROFILE + op_profile st("tensor_tflite_detection_postprocess"); +#endif + return visit(op_reader()(reader_)); + } default: break; } diff --git a/src/runtime/stackvm/ops/tensor.tflite_detection_postprocess.cpp b/src/runtime/stackvm/ops/tensor.tflite_detection_postprocess.cpp new file mode 100644 index 0000000000..633b6305fa --- /dev/null +++ b/src/runtime/stackvm/ops/tensor.tflite_detection_postprocess.cpp @@ -0,0 +1,44 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "../runtime_function.h" +#include +#include +#include + +using namespace nncase; +using namespace nncase::runtime; +using namespace nncase::runtime::stackvm; + +result stackvm_runtime_function::visit(const tensor_tflite_detection_postprocess_op_t &op) noexcept +{ + try_var(output_num_detections, pop_addr()); + try_var(output_scores, pop_addr()); + try_var(output_classes, pop_addr()); + try_var(output_locations, pop_addr()); + try_var(anchor, pop_addr()); + try_var(score, pop_addr()); + try_var(box, pop_addr()); + + try_var(box_shape, module().shape_reg(op.box_shape_src)); + try_var(score_shape, module().shape_reg(op.score_shape_src)); + try_var(anchor_shape, module().shape_reg(op.anchor_shape_src)); + + return kernels::tflite_detection_postprocess(reinterpret_cast(box), reinterpret_cast(score), + reinterpret_cast(anchor), reinterpret_cast(output_locations), + reinterpret_cast(output_classes), reinterpret_cast(output_scores), + reinterpret_cast(output_num_detections), box_shape, score_shape, anchor_shape, op.max_detections, op.max_classes_per_detection, op.detections_per_class, + op.use_regular_non_max_suppression, op.nms_score_threshold, op.nms_iou_threshold, + op.num_classes, op.y_scale, op.x_scale, op.h_scale, op.w_scale); +} diff --git a/src/runtime/stackvm/runtime_function.h b/src/runtime/stackvm/runtime_function.h index 8175805696..c0ec79c95b 100644 --- a/src/runtime/stackvm/runtime_function.h +++ b/src/runtime/stackvm/runtime_function.h @@ -171,6 +171,7 @@ class stackvm_runtime_function : public runtime_function, private op_visitor result visit(const tensor_topk_op_t &op) noexcept override; result visit(const tensor_transpose_op_t &op) noexcept override; result visit(const tensor_trilu_op_t &op) noexcept override; + result visit(const tensor_tflite_detection_postprocess_op_t &op) noexcept override; result visit(const tensor_unary_op_t &op) noexcept override; private: diff --git a/src/targets/neutral_target.cpp b/src/targets/neutral_target.cpp index e6c6e6e1f2..29c5baffe8 100644 --- a/src/targets/neutral_target.cpp +++ b/src/targets/neutral_target.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -185,12 +186,20 @@ void neutral_target::register_target_independent_passes(const module_type_t &typ using namespace nncase::ir; using namespace nncase::ir::transforms; + // fix tflite_detection_postprocess shape error in tflite + { + transform_pass p("fix_shape_tdp"); + p.emplace(); + pass_mgr.add_pass(std::move(p)); + } + // fold quant node in source model { transform_pass p("fold_quantize_in_source_model"); p.emplace(); pass_mgr.add_pass(std::move(p)); } + if (type == runtime::stackvm::stackvm_module_type) { // fold_pad_conv diff --git a/src/transforms/neutral/CMakeLists.txt b/src/transforms/neutral/CMakeLists.txt index da6b09b628..37f85a1b22 100644 --- a/src/transforms/neutral/CMakeLists.txt +++ b/src/transforms/neutral/CMakeLists.txt @@ -48,4 +48,5 @@ target_sources(transforms PRIVATE merge_binary_before_conv.cpp fold_matmul_add.cpp squeeze_dims.cpp + fix_output_shape.cpp ) diff --git a/src/transforms/neutral/fix_output_shape.cpp b/src/transforms/neutral/fix_output_shape.cpp new file mode 100644 index 0000000000..6df7642621 --- /dev/null +++ b/src/transforms/neutral/fix_output_shape.cpp @@ -0,0 +1,93 @@ +/* Copyright 2019-2021 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include + +using namespace nncase; +using namespace nncase::ir; +using namespace nncase::ir::transforms; + +bool tflite_detection_postprocess_transform::on_try_match(node &node, transform_context &context) +{ + if (auto tdp = node_cast(node)) + { + if (tdp->output_locations().shape() == shape_t { 1, (size_t)tdp->max_detections(), 4 }) + return false; + context.inputs.emplace_back(&tdp->boxes()); + context.inputs.emplace_back(&tdp->scores()); + context.inputs.emplace_back(&tdp->anchors()); + + context.outputs.emplace_back(&tdp->output_locations()); + context.outputs.emplace_back(&tdp->output_classes()); + context.outputs.emplace_back(&tdp->output_scores()); + context.outputs.emplace_back(&tdp->output_num_detections()); + + context.matched_nodes.emplace_back(tdp); + return true; + } + + return false; +} + +void tflite_detection_postprocess_transform::process(transform_context &context) +{ + auto &box = *context.inputs[0]->connection(); + auto &score = *context.inputs[1]->connection(); + auto &anchor = *context.inputs[2]->connection(); + auto output_locations = context.outputs[0]->connections(); + auto output_classes = context.outputs[1]->connections(); + auto output_scores = context.outputs[2]->connections(); + auto output_num_detections = context.outputs[3]->connections(); + + auto &old_tdp = static_cast(*context.matched_nodes[0]); + shape_t new_output_shape_0 { 1, (size_t)old_tdp.max_detections(), 4 }; + shape_t new_output_shape_1 { 1, (size_t)old_tdp.max_detections() }; + shape_t new_output_shape_2 { 1, (size_t)old_tdp.max_detections() }; + shape_t new_output_shape_3 { 1 }; + + context.graph.outputs(); + auto new_output_node_0 = context.graph.emplace(output_locations[0]->type(), new_output_shape_0); + auto new_output_node_1 = context.graph.emplace(output_classes[0]->type(), new_output_shape_1); + auto new_output_node_2 = context.graph.emplace(output_scores[0]->type(), new_output_shape_2); + auto new_output_node_3 = context.graph.emplace(output_num_detections[0]->type(), new_output_shape_3); + new_output_node_0->name("output_locations"); + new_output_node_1->name("output_classes"); + new_output_node_2->name("output_scores"); + new_output_node_3->name("output_num_detections"); + + auto new_tdp = context.graph.emplace(old_tdp.boxes().shape(), old_tdp.scores().shape(), old_tdp.anchors().shape(), + new_output_shape_0, new_output_shape_1, new_output_shape_2, new_output_shape_3, old_tdp.max_detections(), old_tdp.max_classes_per_detection(), + old_tdp.detections_per_class(), old_tdp.use_regular_non_max_suppression(), old_tdp.nms_score_threshold(), old_tdp.nms_iou_threshold(), + old_tdp.num_classes(), old_tdp.y_scale(), old_tdp.x_scale(), old_tdp.h_scale(), old_tdp.w_scale()); + new_tdp->name(old_tdp.name()); + + for (auto &i : context.graph.outputs()) + { + i->input().clear_connection(); + } + + new_tdp->boxes().connect(box); + new_tdp->scores().connect(score); + new_tdp->anchors().connect(anchor); + + new_output_node_0->input().connect(new_tdp->output_locations()); + new_output_node_1->input().connect(new_tdp->output_classes()); + new_output_node_2->input().connect(new_tdp->output_scores()); + new_output_node_3->input().connect(new_tdp->output_num_detections()); + + context.graph.dce(); +} \ No newline at end of file diff --git a/tools/stackvm_gen/IsaGen/Instructions.cs b/tools/stackvm_gen/IsaGen/Instructions.cs index 1d55ae47a2..21cf4e3a43 100644 --- a/tools/stackvm_gen/IsaGen/Instructions.cs +++ b/tools/stackvm_gen/IsaGen/Instructions.cs @@ -158,7 +158,6 @@ public enum TensorFunction DEQUANTIZE, GATHER, GATHER_ND, - GRU, HARDMAX, LOGISTIC, LUT1D, @@ -184,6 +183,8 @@ public enum TensorFunction TRANSPOSE, TRILU, UNARY, + GRU, + TFLITE_DETECTION_POSTPROCESS, } [BitLength(8)] @@ -1509,27 +1510,6 @@ public class GatherNDInstruction : TensorInstruction public byte Batchdims { get; set; } } - [DisplayName("TENSOR.GRU")] - [Category("Tensor Instructions")] - [Description("Gru")] - public class GruInstruction : TensorInstruction - { - public override TensorFunction Function => TensorFunction.GRU; - - [DisplayName("input_shape_src")] - [Description("Input shape register")] - public byte RshapeSrc1 { get; set; } - - [DisplayName("w_shape_src")] - [Description("W shape register")] - public byte RshapeSrc2 { get; set; } - - [DisplayName("direction")] - [Description("direction register")] - public byte Direction { get; set; } - - } - [DisplayName("TENSOR.HARDMAX")] [Category("Tensor Instructions")] [Description("Hardmax")] @@ -2273,5 +2253,88 @@ public class TransposeInstruction : TensorInstruction [Description("Perm shape register")] public byte RshapePerm { get; set; } } + [DisplayName("TENSOR.GRU")] + [Category("Tensor Instructions")] + [Description("Gru")] + public class GruInstruction : TensorInstruction + { + public override TensorFunction Function => TensorFunction.GRU; + + [DisplayName("input_shape_src")] + [Description("Input shape register")] + public byte RshapeSrc1 { get; set; } + + [DisplayName("w_shape_src")] + [Description("W shape register")] + public byte RshapeSrc2 { get; set; } + + [DisplayName("direction")] + [Description("direction register")] + public byte Direction { get; set; } + + } + [DisplayName("TENSOR.TFLITE_DETECTION_POSTPROCESS")] + [Category("Tensor Instructions")] + [Description("Tflite_Detection_Postprocess")] + public class TfliteDetectionPostprocessInstruction : TensorInstruction + { + public override TensorFunction Function => TensorFunction.TFLITE_DETECTION_POSTPROCESS; + + [DisplayName("box_shape_src")] + [Description("Box shape register")] + public byte RshapeSrc1 { get; set; } + + [DisplayName("score_shape_src")] + [Description("Score shape register")] + public byte RshapeSrc2 { get; set; } + + [DisplayName("anchor_shape_src")] + [Description("Anchor shape register")] + public byte RshapeSrc3 { get; set; } + + [DisplayName("max_detections")] + [Description("max_detections register")] + public int MaxDetections { get; set; } + + [DisplayName("max_classes_per_detection")] + [Description("max_classes_per_detection register")] + public int MaxClassesPerDetection { get; set; } + + [DisplayName("detections_per_class")] + [Description("detections_per_class register")] + public int DetectionsPerClass { get; set; } + + [DisplayName("use_regular_non_max_suppression")] + [Description("use_regular_non_max_suppression register")] + public bool UseRegularNonMaxSuppression { get; set; } + + [DisplayName("nms_score_threshold")] + [Description("nms_score_threshold register")] + public float NmsScoreThreshold { get; set; } + + [DisplayName("nms_iou_threshold")] + [Description("nms_iou_threshold register")] + public float NmsIouThreshold { get; set; } + + [DisplayName("num_classes")] + [Description("num_classes register")] + public int NumClasses { get; set; } + + [DisplayName("y_scale")] + [Description("y_scale register")] + public float YScale { get; set; } + + [DisplayName("x_scale")] + [Description("x_scale register")] + public float XScale { get; set; } + + [DisplayName("h_scale")] + [Description("h_scale register")] + public float HScale { get; set; } + + [DisplayName("w_scale")] + [Description("w_scale register")] + public float WScale { get; set; } + } } }