Merge branch 'lp_normalization' into 'master'
add LpNorm and MVNorm ops for Caffe, enhance BiasAdd and Reshape ops

See merge request !1224
yejw5 committed Nov 7, 2019
2 parents 8a0abca + 72a3751 commit a610d50
Showing 32 changed files with 1,658 additions and 66 deletions.
4 changes: 3 additions & 1 deletion mace/core/memory_optimizer.cc
@@ -90,7 +90,9 @@ MemoryBlock MemoryOptimizer::CreateMemoryBlock(
  if (shape.size() == 2) {
    shape = {shape[0], 1, 1, shape[1]};
  } else {
-    MACE_CHECK(shape.size() == 4) << "GPU only support 2D/4D input";
+    MACE_CHECK(shape.size() == 4) << "GPU only supports 2D/4D input, "
+                                  << "op name: " << op_def->name() << ", "
+                                  << MakeString(shape);
  }
  OpenCLUtil::CalImage2DShape(shape, buffer_type, &image_shape);
  block.set_x(image_shape[0]);
68 changes: 48 additions & 20 deletions mace/ops/arm/fp32/bias_add.cc
@@ -62,34 +62,62 @@ void BiasAdd::AddBias(const OpContext *context,

  utils::ThreadPool
      &thread_pool = context->device()->cpu_runtime()->thread_pool();
-  thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
-                            index_t start1, index_t end1, index_t step1) {
-    for (index_t b = start0; b < end0; b += step0) {
-      for (index_t c = start1; c < end1; c += step1) {
-        const index_t offset = (b * channels + c) * image_size;
-        auto input_ptr = input_data + offset;
-        auto output_ptr = output_data + offset;
-        const float bias = bias_data[c];
-        float32x4_t vbias = vdupq_n_f32(bias);
-
-        for (index_t i = 0; i < block_count; ++i) {
-          float32x4_t v = vld1q_f32(input_ptr);
-          v = vaddq_f32(v, vbias);
-          vst1q_f32(output_ptr, v);
-
-          input_ptr += 4;
-          output_ptr += 4;
-        }
-        for (index_t i = 0; i < remain; ++i) {
-          (*output_ptr++) = (*input_ptr++) + bias;
-        }
-      }
-    }
-  }, 0, batch, 1, 0, channels, 1);
+  if (bias->dim_size() == 1) {
+    thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
+                              index_t start1, index_t end1, index_t step1) {
+      for (index_t b = start0; b < end0; b += step0) {
+        const index_t b_offset = b * channels;
+        for (index_t c = start1; c < end1; c += step1) {
+          const index_t offset = (b_offset + c) * image_size;
+          auto input_ptr = input_data + offset;
+          auto output_ptr = output_data + offset;
+          const float bias = bias_data[c];
+          float32x4_t vbias = vdupq_n_f32(bias);
+
+          for (index_t i = 0; i < block_count; ++i) {
+            float32x4_t v = vld1q_f32(input_ptr);
+            v = vaddq_f32(v, vbias);
+            vst1q_f32(output_ptr, v);
+
+            input_ptr += 4;
+            output_ptr += 4;
+          }
+          for (index_t i = 0; i < remain; ++i) {
+            (*output_ptr++) = (*input_ptr++) + bias;
+          }
+        }
+      }
+    }, 0, batch, 1, 0, channels, 1);
+  } else {
+    thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
+                              index_t start1, index_t end1, index_t step1) {
+      for (index_t b = start0; b < end0; b += step0) {
+        const index_t b_offset = b * channels;
+        for (index_t c = start1; c < end1; c += step1) {
+          const index_t offset = (b_offset + c) * image_size;
+          auto input_ptr = input_data + offset;
+          auto output_ptr = output_data + offset;
+          const float bias = bias_data[b * channels + c];
+          float32x4_t vbias = vdupq_n_f32(bias);
+
+          for (index_t i = 0; i < block_count; ++i) {
+            float32x4_t v = vld1q_f32(input_ptr);
+            v = vaddq_f32(v, vbias);
+            vst1q_f32(output_ptr, v);
+
+            input_ptr += 4;
+            output_ptr += 4;
+          }
+          for (index_t i = 0; i < remain; ++i) {
+            (*output_ptr++) = (*input_ptr++) + bias;
+          }
+        }
+      }
+    }, 0, batch, 1, 0, channels, 1);
+  }
 }

} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
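
The branch added above distinguishes two bias layouts: a 1-D bias holds one value per channel, shared across the batch, while a 2-D (Caffe-style) bias holds one value per batch/channel pair. As a rough scalar sketch of those semantics for NCHW data, with plain loops standing in for the NEON intrinsics (illustrative only; the helper name is ours, not MACE's):

// Scalar reference for the two bias layouts handled by BiasAdd::AddBias.
void BiasAddRefNCHW(const float *in, const float *bias, float *out,
                    int batch, int channels, int image_size, bool bias_is_2d) {
  for (int b = 0; b < batch; ++b) {
    for (int c = 0; c < channels; ++c) {
      // 1-D bias: bias[c]; 2-D Caffe-style bias: bias[b * channels + c].
      const float v = bias_is_2d ? bias[b * channels + c] : bias[c];
      const int offset = (b * channels + c) * image_size;
      for (int i = 0; i < image_size; ++i) {
        out[offset + i] = in[offset + i] + v;
      }
    }
  }
}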

79 changes: 57 additions & 22 deletions mace/ops/bias_add.cc
@@ -49,15 +49,18 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation {
    MACE_UNUSED(context);
    const Tensor *input = this->Input(0);
    const Tensor *bias = this->Input(1);
-
-    MACE_CHECK(bias->dim_size() == 1, "bias must be 1-dimensional. ",
-               bias->dim_size());

    Tensor *output = this->Output(0);

-    if (input->dim_size() == 4 && has_data_format_) {
+    if (input->dim_size() == 4 && (has_data_format_
+        || input->data_format() == DataFormat::NCHW)) {  // NCHW
+      MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2,
+                 "bias must be 1-dimensional or n*c for caffe.",
+                 MakeString(bias->shape()));
      bias_add_delegator_.Compute(context, input, bias, output);
-    } else {
+    } else {  // NHWC
+      MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2,
+                 "bias must be 1-dimensional or 2-dimensional for caffe.",
+                 bias->dim_size(), MakeString(bias->shape()));
      // TODO(liyin): remove it and transform bias to add (eltwise)
      MACE_RETURN_IF_ERROR(output->ResizeLike(input));

@@ -70,16 +73,40 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation {
      float *output_ptr = output->mutable_data<float>();

      const std::vector<index_t> &shape = input->shape();
-      const index_t fused_batch = std::accumulate(
-          shape.begin(), shape.end() - 1, 1, std::multiplies<index_t>());
      const index_t channels = *shape.rbegin();

-      for (index_t n = 0; n < fused_batch; ++n) {
-        index_t pos = n * channels;
-        for (index_t c = 0; c < channels; ++c) {
-          output_ptr[pos] = input_ptr[pos] + bias_ptr[c];
-          ++pos;
-        }
-      }
+      utils::ThreadPool
+          &thread_pool = context->device()->cpu_runtime()->thread_pool();
+      if (bias->dim_size() == 1) {
+        const index_t fused_batch = std::accumulate(
+            shape.begin(), shape.end() - 1, 1, std::multiplies<index_t>());
+        thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
+          for (index_t n = start; n < end; n += step) {
+            index_t pos = n * channels;
+            for (index_t c = 0; c < channels; ++c) {
+              output_ptr[pos] = input_ptr[pos] + bias_ptr[c];
+              ++pos;
+            }
+          }
+        }, 0, fused_batch, 1);
+      } else {  // bias is 2d
+        const auto n = shape[0];
+        MACE_CHECK(n == bias->shape()[0]);
+        const index_t fused_hw = std::accumulate(
+            shape.begin() + 1, shape.end() - 1, 1, std::multiplies<index_t>());
+        const auto ch_size = bias->shape()[1];
+        thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
+                                  index_t start1, index_t end1, index_t step1) {
+          for (index_t i = start0; i < end0; i += step0) {
+            auto offset = i * fused_hw;
+            auto bias_offset = i * ch_size;
+            for (index_t j = start1; j < end1; j += step1) {
+              index_t pos = (offset + j) * channels;
+              for (index_t c = 0; c < channels; ++c, ++pos) {
+                output_ptr[pos] = input_ptr[pos] + bias_ptr[bias_offset + c];
+              }
+            }
+          }
+        }, 0, n, 1, 0, fused_hw, 1);
+      }
    }

@@ -109,21 +136,25 @@ class BiasAddOp<DeviceType::GPU, float> : public Operation {
    } else {
      MACE_NOT_IMPLEMENTED;
    }
-    MACE_CHECK(TransformFilter(
-        context, operator_def_.get(), 1, OpenCLBufferType::ARGUMENT, mem_type)
-               == MaceStatus::MACE_SUCCESS);
+
+    // for const bias tensor
+    if (context->workspace()->GetTensor(operator_def_->input(1)) != nullptr) {
+      MACE_CHECK(TransformFilter(
+          context, operator_def_.get(), 1, OpenCLBufferType::ARGUMENT, mem_type)
+          == MaceStatus::MACE_SUCCESS, "TransformFilter failed");
+    }
  }

  MaceStatus Run(OpContext *context) override {
    const Tensor *input = this->Input(0);
    const Tensor *bias = this->Input(1);
-
-    MACE_CHECK(bias->dim_size() == 1, "bias must be 1-dimensional. ",
-               bias->dim_size());

    Tensor *output = this->Output(0);
    MACE_RETURN_IF_ERROR(output->ResizeLike(input));
+    MACE_CHECK(input->dim_size() == 4 && has_data_format_,
+               "gpu only supports biasadd for 4-dimensional NHWC format tensors");
+    MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2,
+               "bias must be 1-dimensional or 2-dimensional for caffe. ",
+               MakeString(bias->shape()));
    return kernel_->Compute(context, input, bias, output);
  }

@@ -151,6 +182,10 @@ void RegisterBiasAdd(OpRegistryBase *op_registry) {
                *op, "has_data_format", 0);
            if (!has_data_format ||
                op->output_shape(0).dims_size() != 4) {
+              LOG(INFO) << "BiasAdd only supports cpu, has_data_format="
+                        << has_data_format
+                        << ", op->output_shape(0).dims_size()="
+                        << op->output_shape(0).dims_size();
              return {DeviceType::CPU};
            }
            return {DeviceType::CPU, DeviceType::GPU};
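
On the NHWC CPU path above, the 2-D-bias branch indexes the input as a flattened [N, H*W, C] block and selects a bias row per batch. A minimal scalar sketch of that indexing (illustrative only; assumes bias shape [N, C] with C equal to the input's channel count):

// Scalar reference for the NHWC path with a Caffe-style 2-D bias.
void BiasAddRefNHWC(const float *in, const float *bias, float *out,
                    int n, int fused_hw, int channels) {
  for (int i = 0; i < n; ++i) {            // batch index
    for (int j = 0; j < fused_hw; ++j) {   // fused spatial index (H*W)
      const int pos = (i * fused_hw + j) * channels;
      for (int c = 0; c < channels; ++c) {
        out[pos + c] = in[pos + c] + bias[i * channels + c];
      }
    }
  }
}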
157 changes: 157 additions & 0 deletions mace/ops/lpnorm.cc
@@ -0,0 +1,157 @@
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#include <functional>
#include <memory>

#include "mace/core/operator.h"

#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/lpnorm.h"
#endif // MACE_ENABLE_OPENCL

/**
 * LpNormOp is a normalization op supporting the L1 and L2 norms. It is a
 * custom Caffe op (it does not exist in official Caffe); for reference, see:
 * https://github.com/freesouls/caffe/blob/master/src/caffe/layers/normalization_layer.cpp
 */
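
/**
 * In effect, for p in {1, 2} and a given axis, the op computes
 *   y_i = x_i / ((sum_j |x_j|^p)^(1/p) + 1e-6),
 * where the sum over j runs across the flattened dimensions from the axis
 * onward.
 */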

namespace mace {
namespace ops {

template<DeviceType D, typename T>
class LpNormOp;

template<>
class LpNormOp<DeviceType::CPU, float> : public Operation {
public:
explicit LpNormOp(OpConstructContext *context)
: Operation(context),
p_(Operation::GetOptionalArg<int>("p", 2)),
axis_(Operation::GetOptionalArg<int>("axis", -1)) {}

MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);

if (axis_ < 0) {
axis_ += input->dim_size();
}
MACE_CHECK(axis_ < input->dim_size() && axis_ >= 0,
"axis_ must be smaller than the input's dim size");
const std::vector<index_t> &input_shape = input->shape();
MACE_RETURN_IF_ERROR(output->Resize(input_shape));

Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_output(output);

const auto *input_data = input->data<float>();
auto *output_data = output->mutable_data<float>();
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
auto outer_loop = std::accumulate(input_shape.begin(),
input_shape.begin() + axis_, 1,
std::multiplies<index_t>());
auto inner_loop = std::accumulate(input_shape.begin() + axis_,
input_shape.end(), 1,
std::multiplies<index_t>());
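// outer_loop is the product of the dims before axis_, inner_loop the product
// of the dims from axis_ onward; each of the outer_loop slices is normalized
// independently over its inner_loop elements.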

if (p_ == 1) {
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
output_data[i] = std::abs(input_data[i]);
}
}, 0, input->size(), 1);
} else if (p_ == 2) {
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
output_data[i] = input_data[i] * input_data[i];
}
}, 0, input->size(), 1);
} else {
LOG(FATAL) << "LpNorm's p should be 1 or 2, current p is: " << p_;
}

const float power = 1 / static_cast<float>(p_);
auto norm_buffer = context->device()->scratch_buffer();
norm_buffer->Rewind();
MACE_RETURN_IF_ERROR(norm_buffer->GrowSize(outer_loop * sizeof(float)));
float *norm_ptr = norm_buffer->mutable_data<float>();
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
auto output_data_base = output_data + inner_loop * i;
norm_ptr[i] = std::accumulate(output_data_base,
output_data_base + inner_loop, 0.0f);
norm_ptr[i] = std::pow(norm_ptr[i], power);
norm_ptr[i] += 1e-6;
}
}, 0, outer_loop, 1);

thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
const auto offset = i * inner_loop;
for (index_t j = start1; j < end1; j += step1) {
output_data[offset + j] = input_data[offset + j] / norm_ptr[i];
}
}
}, 0, outer_loop, 1, 0, inner_loop, 1);

return MaceStatus::MACE_SUCCESS;
}

private:
int p_;
int axis_;
};

#ifdef MACE_ENABLE_OPENCL
template<>
class LpNormOp<DeviceType::GPU, float> : public Operation {
public:
explicit LpNormOp(OpConstructContext *context)
: Operation(context) {
const auto p = Operation::GetOptionalArg<int>("p", 2);
const auto axis = Operation::GetOptionalArg<int>("axis", -1);
if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::LpNormKernel>(p, axis);
} else {
MACE_NOT_IMPLEMENTED;
}
}
MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));

return kernel_->Compute(context, input, output);
}

private:
std::unique_ptr<OpenCLLpNormKernel> kernel_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
#endif // MACE_ENABLE_OPENCL

void RegisterLpNorm(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "LpNorm", LpNormOp,
DeviceType::CPU, float);
MACE_REGISTER_GPU_OP(op_registry, "LpNorm", LpNormOp);
}

} // namespace ops
} // namespace mace
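
For intuition, a tiny self-contained sketch (illustrative only, not part of the commit) of what the CPU path computes for p = 2, axis = 1, and an input of shape {1, 3}:

#include <cmath>
#include <cstdio>

int main() {
  const float x[3] = {1.0f, 2.0f, 2.0f};
  float sum = 0.0f;
  for (float v : x) sum += v * v;                  // p == 2: sum of squares
  const float norm = std::pow(sum, 0.5f) + 1e-6f;  // sum^(1/p) + epsilon
  for (float v : x) std::printf("%.3f ", v / norm);  // 0.333 0.667 0.667
  return 0;
}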