Skip to content

Commit

Permalink
Refactor String and Audio operators with status-return prototype. (#576)
Browse files Browse the repository at this point in the history
* Refactor String and Audio operators with status-return prototype.

* complete the whole text domain

---------

Co-authored-by: Sayan Shaw <[email protected]>
  • Loading branch information
wenbingl and sayanshaw24 authored Oct 19, 2023
1 parent 4d2930e commit c71e2ae
Show file tree
Hide file tree
Showing 46 changed files with 551 additions and 609 deletions.
27 changes: 23 additions & 4 deletions includes/onnxruntime_customop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,25 @@ inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info
return instance()->KernelInfoGetAttribute_float(&info, name, &value);
}

template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
size_t size = 0;
std::string out;
// Feed nullptr for the data buffer to query the true size of the string attribute
OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
if (status == nullptr) {
out.resize(size);
status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
out.resize(size - 1); // remove the terminating character '\0'
}

if (status == nullptr) {
value = std::move(out);
}

return status;
}

template <class T>
static OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
Expand Down Expand Up @@ -398,7 +417,7 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
};
}

#if ORT_API_VERSION > 15
#if ORT_API_VERSION > 16
template <typename... Args>
void DefineCallbackFunctions(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
OrtCustomOp::CreateKernel = nullptr;
Expand Down Expand Up @@ -450,7 +469,7 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
};
}
#endif // ORT_API_VERSION > 15
#endif // ORT_API_VERSION > 16

OrtLiteCustomStructV2(const char* op_name,
const char* execution_provider,
Expand All @@ -459,11 +478,11 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {

ParseArgs(&CustomOpKernel::Compute);

#if ORT_API_VERSION > 15
#if ORT_API_VERSION > 16
if (OrtCustomOp::version > 15) {
DefineCallbackFunctions(&CustomOpKernel::Compute, fn_compute);
} else
#endif // ORT_API_VERSION > 15
#endif // ORT_API_VERSION > 16

{
DefineCallbackFunctionsLegacy(&CustomOpKernel::Compute, fn_compute);
Expand Down
2 changes: 1 addition & 1 deletion operators/audio/audio.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []()-> CustomOpArray& {
[]() { return nullptr; }
#ifdef ENABLE_DR_LIBS
,
CustomCpuStruct("AudioDecoder", AudioDecoder)
CustomCpuStructV2("AudioDecoder", AudioDecoder)
#endif
);

Expand Down
48 changes: 32 additions & 16 deletions operators/audio/audio_decoder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@
#include "string_tensor.h"
#include "sampling.h"

struct AudioDecoder : public BaseKernel {
struct AudioDecoder{
public:
AudioDecoder(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info),
downsample_rate_(TryToGetAttributeWithDefault<int64_t>("downsampling_rate", 0)),
stereo_mixer_(TryToGetAttributeWithDefault<int64_t>("stereo_to_mono", 0)) {

OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
auto status = OrtW::GetOpAttribute(info, "downsampling_rate", downsample_rate_);
if (!status) {
status = OrtW::GetOpAttribute(info, "stereo_to_mono", stereo_mixer_);
}

return status;
}

enum class AudioStreamType {
Expand All @@ -38,7 +42,7 @@ struct AudioDecoder : public BaseKernel {
kFLAC
};

AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format) const {
AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format, OrtStatusPtr& status) const {
static const std::map<std::string, AudioStreamType> format_mapping = {
{"default", AudioStreamType::kDefault},
{"wav", AudioStreamType::kWAV},
Expand All @@ -49,9 +53,11 @@ struct AudioDecoder : public BaseKernel {
if (str_format.length() > 0) {
auto pos = format_mapping.find(str_format);
if (pos == format_mapping.end()) {
ORTX_CXX_API_THROW(MakeString(
"[AudioDecoder]: Unknown audio stream format: ", str_format),
status = OrtW::CreateStatus(MakeString(
"[AudioDecoder]: Unknown audio stream format: ", str_format)
.c_str(),
ORT_INVALID_ARGUMENT);
return stream_format;
}
stream_format = pos->second;
}
Expand All @@ -68,7 +74,7 @@ struct AudioDecoder : public BaseKernel {
// only detect the 8 + 3 bits sync word
stream_format = AudioStreamType::kMP3;
} else {
ORTX_CXX_API_THROW("[AudioDecoder]: Cannot detect audio stream format", ORT_INVALID_ARGUMENT);
status = OrtW::CreateStatus("[AudioDecoder]: Cannot detect audio stream format", ORT_INVALID_ARGUMENT);
}
}

Expand Down Expand Up @@ -96,20 +102,25 @@ struct AudioDecoder : public BaseKernel {
return total_buf_size;
}

void Compute(const ortc::Tensor<uint8_t>& input,
OrtStatusPtr Compute(const ortc::Tensor<uint8_t>& input,
const std::optional<std::string> format,
ortc::Tensor<float>& output0) const {
const uint8_t* p_data = input.Data();
auto input_dim = input.Shape();
OrtStatusPtr status = nullptr;
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
ORTX_CXX_API_THROW("[AudioDecoder]: Expect input dimension [n] or [1,n].", ORT_INVALID_ARGUMENT);
status = OrtW::CreateStatus("[AudioDecoder]: Expect input dimension [n] or [1,n].", ORT_INVALID_ARGUMENT);
return status;
}

std::string str_format;
if (format) {
str_format = *format;
}
auto stream_format = ReadStreamFormat(p_data, str_format);
auto stream_format = ReadStreamFormat(p_data, str_format, status);
if (status) {
return status;
}

int64_t total_buf_size = 0;
std::list<std::vector<float>> lst_frames;
Expand All @@ -119,7 +130,8 @@ struct AudioDecoder : public BaseKernel {
if (stream_format == AudioStreamType::kMP3) {
auto mp3_obj_ptr = std::make_unique<drmp3>();
if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input.NumberOfElement(), nullptr)) {
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on MP3 stream.", ORT_RUNTIME_EXCEPTION);
status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on MP3 stream.", ORT_RUNTIME_EXCEPTION);
return status;
}
orig_sample_rate = mp3_obj_ptr->sampleRate;
orig_channels = mp3_obj_ptr->channels;
Expand All @@ -129,7 +141,8 @@ struct AudioDecoder : public BaseKernel {
drflac* flac_obj = drflac_open_memory(p_data, input.NumberOfElement(), nullptr);
auto flac_obj_closer = gsl::finally([flac_obj]() { drflac_close(flac_obj); });
if (flac_obj == nullptr) {
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on FLAC stream.", ORT_RUNTIME_EXCEPTION);
status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on FLAC stream.", ORT_RUNTIME_EXCEPTION);
return status;
}
orig_sample_rate = flac_obj->sampleRate;
orig_channels = flac_obj->channels;
Expand All @@ -138,7 +151,8 @@ struct AudioDecoder : public BaseKernel {
} else {
drwav wav_obj;
if (!drwav_init_memory(&wav_obj, p_data, input.NumberOfElement(), nullptr)) {
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on WAV stream.", ORT_RUNTIME_EXCEPTION);
status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on WAV stream.", ORT_RUNTIME_EXCEPTION);
return status;
}
orig_sample_rate = wav_obj.sampleRate;
orig_channels = wav_obj.channels;
Expand All @@ -147,7 +161,8 @@ struct AudioDecoder : public BaseKernel {

if (downsample_rate_ != 0 &&
orig_sample_rate < downsample_rate_) {
ORTX_CXX_API_THROW("[AudioDecoder]: only down-sampling supported.", ORT_INVALID_ARGUMENT);
status = OrtW::CreateStatus("[AudioDecoder]: only down-sampling supported.", ORT_INVALID_ARGUMENT);
return status;
}

// join all frames
Expand Down Expand Up @@ -182,6 +197,7 @@ struct AudioDecoder : public BaseKernel {
std::vector<int64_t> dim_out = {1, ort_extensions::narrow<int64_t>(buf.size())};
float* p_output = output0.Allocate(dim_out);
std::copy(buf.begin(), buf.end(), p_output);
return status;
}

private:
Expand Down
12 changes: 8 additions & 4 deletions operators/text/masked_fill.cc
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "masked_fill.hpp"
#include "string_functions.h"
#include "string_tensor.h"
#include <vector>
#include <locale>
#include <codecvt>
#include <algorithm>

void masked_fill(const ortc::Tensor<std::string>& input,
OrtStatusPtr masked_fill(const ortc::Tensor<std::string>& input,
const ortc::Tensor<bool>& input_mask,
ortc::Tensor<std::string>& output) {
OrtStatusPtr status = nullptr;
auto& value_dimensions = input.Shape();
auto& mask_dimensions = input_mask.Shape();
if (!(value_dimensions.empty() || mask_dimensions.size() == 1)) {
ORTX_CXX_API_THROW("[MaskedFill]: the dimension of input value should be vector or scalar.", ORT_INVALID_ARGUMENT);
status = OrtW::CreateStatus("[MaskedFill]: the dimension of input value should be vector or scalar.", ORT_INVALID_ARGUMENT);
return status;
}

if (value_dimensions != mask_dimensions) {
ORTX_CXX_API_THROW("[MaskedFill]: the dimension of input value and mask should be same.", ORT_INVALID_ARGUMENT);
status = OrtW::CreateStatus("[MaskedFill]: the dimension of input value and mask should be same.", ORT_INVALID_ARGUMENT);
return status;
}

auto& value = input.Data();
Expand All @@ -36,4 +39,5 @@ void masked_fill(const ortc::Tensor<std::string>& input,
}
result_dimension.push_back(result.size());
output.SetStringOutput(result, result_dimension);
return status;
}
12 changes: 0 additions & 12 deletions operators/text/masked_fill.hpp

This file was deleted.

16 changes: 0 additions & 16 deletions operators/text/op_equal.cc

This file was deleted.

15 changes: 0 additions & 15 deletions operators/text/op_equal.hpp

This file was deleted.

Loading

0 comments on commit c71e2ae

Please sign in to comment.