Disable quant for remote infer
Tracked-On:
Signed-off-by: Ratnesh Kumar Rai <[email protected]>
rairatne committed May 24, 2023
1 parent 4139a4c commit fad628a
Showing 3 changed files with 128 additions and 141 deletions.
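In short: the driver now offloads inference over gRPC only when every model input is a FLOAT32, TENSOR_FLOAT32, TENSOR_FLOAT16, or TENSOR_INT32 operand; models with quantized (or other) input types run on the local OpenVINO plugin instead. A minimal standalone sketch of that gate, assuming the NNAPI-style OperandType enum is in scope; the helper name remoteInferAllowed is hypothetical, and the real check lives inline in BasePreparedModel::initialize() below:

    #include <vector>

    // Hypothetical helper illustrating the offload gate added in initialize().
    bool remoteInferAllowed(const std::vector<OperandType>& inputTypes) {
        for (auto type : inputTypes) {
            switch (type) {
                case OperandType::FLOAT32:
                case OperandType::TENSOR_FLOAT32:
                case OperandType::TENSOR_FLOAT16:
                case OperandType::TENSOR_INT32:
                    break;              // remote-friendly input type
                default:
                    return false;       // quantized/unsupported type: infer locally
            }
        }
        return true;
    }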
265 changes: 126 additions & 139 deletions BasePreparedModel.cpp
@@ -44,10 +44,11 @@ void BasePreparedModel::deinitialize() {
if ((ret_xml != 0) || (ret_bin != 0)) {
ALOGW("%s Deletion status of xml:%d, bin:%d", __func__, ret_xml, ret_bin);
}
auto reply = mDetectionClient->release(is_success);
ALOGI("GRPC release response is %d : %s", is_success, reply.c_str());
setRemoteEnabled(false);

if (mRemoteCheck && mDetectionClient) {
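// Tear down the gRPC session and mark remote inference as unavailable.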
auto reply = mDetectionClient->release(is_success);
ALOGI("GRPC release response is %d : %s", is_success, reply.c_str());
setRemoteEnabled(false);
}
ALOGV("Exiting %s", __func__);
}

@@ -64,12 +65,10 @@ bool BasePreparedModel::initialize() {
ALOGE("Failed to initialize Model runtime parameters!!");
return false;
}

setRemoteEnabled(checkRemoteConnection());
mNgraphNetCreator = std::make_shared<NgraphNetworkCreator>(mModelInfo, mTargetDevice);

if (!mNgraphNetCreator->validateOperations()) return false;
ALOGI("Generating IR Graph");
ALOGI("Generating IR Graph for Model %u", mFileId);
auto ov_model = mNgraphNetCreator->generateGraph();
if (ov_model == nullptr) {
ALOGE("%s Openvino model generation failed", __func__);
@@ -78,17 +77,29 @@ bool BasePreparedModel::initialize() {
try {
mPlugin = std::make_unique<IENetwork>(mTargetDevice, ov_model);
mPlugin->loadNetwork(mXmlFile, mBinFile);
if(mRemoteCheck) {
auto resp = loadRemoteModel(mXmlFile, mBinFile);
ALOGD("%s Load Remote Model returns %d", __func__, resp);
} else {
ALOGD("%s Remote connection unavailable", __func__);
}
} catch (const std::exception& ex) {
ALOGE("%s Exception !!! %s", __func__, ex.what());
return false;
}

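// Gate remote offload on the model's input types: only plain float and
// int32 tensors are sent to the gRPC service; any other operand type
// (quantized tensors in particular) disables offload for this model.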
bool disableOffload = false;
for (auto i : mModelInfo->getModelInputIndexes()) {
auto& nnapiOperandType = mModelInfo->getOperand(i).type;
switch (nnapiOperandType) {
case OperandType::FLOAT32:
case OperandType::TENSOR_FLOAT32:
case OperandType::TENSOR_FLOAT16:
case OperandType::TENSOR_INT32:
break;
default:
ALOGD("GRPC Remote Infer not enabled for %d", nnapiOperandType);
disableOffload = true;
break;
}
if (disableOffload) break;
}
if (!disableOffload) {
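// All inputs are remote-friendly, so ship the generated IR files to the remote service.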
loadRemoteModel(mXmlFile, mBinFile);
}
ALOGV("Exiting %s", __func__);
return true;
}
@@ -103,9 +114,11 @@ bool BasePreparedModel::checkRemoteConnection() {
args.SetMaxSendMessageSize(INT_MAX);
mDetectionClient = std::make_shared<DetectionClient>(
grpc::CreateCustomChannel(grpc_prop, grpc::InsecureChannelCredentials(), args), mFileId);
if(mDetectionClient) {
if (mDetectionClient) {
auto reply = mDetectionClient->prepare(is_success);
ALOGI("GRPC(TCP) prepare response is %d : %s", is_success, reply.c_str());
} else {
ALOGE("%s mDetectionClient is null", __func__);
}
}
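// The TCP channel did not succeed; fall back to a unix-domain socket when one is configured.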
if (!is_success && getGrpcSocketPath(grpc_prop)) {
@@ -115,30 +128,27 @@
args.SetMaxSendMessageSize(INT_MAX);
mDetectionClient = std::make_shared<DetectionClient>(
grpc::CreateCustomChannel(std::string("unix:") + grpc_prop, grpc::InsecureChannelCredentials(), args), mFileId);
if(mDetectionClient) {
if (mDetectionClient) {
auto reply = mDetectionClient->prepare(is_success);
ALOGI("GRPC(unix) prepare response is %d : %s", is_success, reply.c_str());
} else {
ALOGE("%s mDetectionClient is null", __func__);
}
}
setRemoteEnabled(is_success);
return is_success;
}

bool BasePreparedModel::loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin) {
ALOGI("Entering %s", __func__);
void BasePreparedModel::loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin) {
ALOGI("Entering %s for Model %u", __func__, mFileId);
bool is_success = false;
if(mDetectionClient) {
if(checkRemoteConnection() && mDetectionClient) {
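// Re-verify the gRPC connection before streaming the IR xml/bin files.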
auto reply = mDetectionClient->sendIRs(is_success, ir_xml, ir_bin);
ALOGI("sendIRs response GRPC %d %s", is_success, reply.c_str());
if (reply == "status False") {
ALOGE("%s Model Load Failed",__func__);
}
setRemoteEnabled(is_success);
}
else {
ALOGE("%s mDetectionClient is null",__func__);
}
setRemoteEnabled(is_success);
return is_success;
}

void BasePreparedModel::setRemoteEnabled(bool flag) {
@@ -287,20 +297,14 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod
ALOGD("%s Run", __func__);

if (measure == MeasureTiming::YES) deviceStart = now();
if(preparedModel->mRemoteCheck) {
ALOGI("%s GRPC Remote Infer", __func__);
auto reply = preparedModel->mDetectionClient->remote_infer();
ALOGI("***********GRPC server response************* %s", reply.c_str());
}
if (!preparedModel->mRemoteCheck || !preparedModel->mDetectionClient->get_status()){
try {
plugin->infer();
} catch (const std::exception& ex) {
ALOGE("%s Exception !!! %s", __func__, ex.what());
notify(callback, ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
return;
}
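// Inference in this path runs on the local OpenVINO plugin.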
try {
plugin->infer();
} catch (const std::exception& ex) {
ALOGE("%s Exception !!! %s", __func__, ex.what());
notify(callback, ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
return;
}

if (measure == MeasureTiming::YES) deviceEnd = now();

tensorIndex = 0;
@@ -351,50 +355,45 @@ void asyncExecute(const Request& request, MeasureTiming measure, BasePreparedMod
return;
}

if (preparedModel->mRemoteCheck && preparedModel->mDetectionClient && preparedModel->mDetectionClient->get_status()) {
preparedModel->mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr,
ngraphNw->getOutputShape(outIndex), expectedLength);
} else {
switch (operandType) {
case OperandType::TENSOR_INT32:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int32_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_FLOAT32:
std::memcpy((uint8_t*)destPtr, srcTensor.data<float>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_BOOL8:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<bool>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT8_ASYMM:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint8_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT8_SYMM:
case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
std::memcpy((int8_t*)destPtr, (int8_t*)srcTensor.data<int8_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_FLOAT16:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<ov::float16>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT16_SYMM:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int16_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT16_ASYMM:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint16_t>(),
srcTensor.get_byte_size());
break;
default:
std::memcpy((uint8_t*)destPtr, srcTensor.data<uint8_t>(),
srcTensor.get_byte_size());
break;
}
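// Copy each output from the OpenVINO tensor into the request buffer, by operand type.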
switch (operandType) {
case OperandType::TENSOR_INT32:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int32_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_FLOAT32:
std::memcpy((uint8_t*)destPtr, srcTensor.data<float>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_BOOL8:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<bool>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT8_ASYMM:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint8_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT8_SYMM:
case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
std::memcpy((int8_t*)destPtr, (int8_t*)srcTensor.data<int8_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_FLOAT16:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<ov::float16>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT16_SYMM:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int16_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT16_ASYMM:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint16_t>(),
srcTensor.get_byte_size());
break;
default:
std::memcpy((uint8_t*)destPtr, srcTensor.data<uint8_t>(),
srcTensor.get_byte_size());
break;
}
}

@@ -843,19 +842,12 @@ Return<void> BasePreparedModel::executeFenced(const V1_3::Request& request1_3,

time_point deviceStart, deviceEnd;
if (measure == MeasureTiming::YES) deviceStart = now();
if(mRemoteCheck) {
ALOGI("%s GRPC Remote Infer", __func__);
auto reply = mDetectionClient->remote_infer();
ALOGI("***********GRPC server response************* %s", reply.c_str());
}
if (!mRemoteCheck || !mDetectionClient->get_status()){
try {
mPlugin->infer();
} catch (const std::exception& ex) {
ALOGE("%s Exception !!! %s", __func__, ex.what());
cb(V1_3::ErrorStatus::GENERAL_FAILURE, hidl_handle(nullptr), nullptr);
return Void();
}
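// Local OpenVINO inference, with failures reported through the callback.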
try {
mPlugin->infer();
} catch (const std::exception& ex) {
ALOGE("%s Exception !!! %s", __func__, ex.what());
cb(V1_3::ErrorStatus::GENERAL_FAILURE, hidl_handle(nullptr), nullptr);
return Void();
}
if (measure == MeasureTiming::YES) deviceEnd = now();

@@ -893,50 +885,45 @@ Return<void> BasePreparedModel::executeFenced(const V1_3::Request& request1_3,
mModelInfo->updateOutputshapes(i, outDims);
}

if (mRemoteCheck && mDetectionClient && mDetectionClient->get_status()) {
mDetectionClient->get_output_data(std::to_string(i), (uint8_t*)destPtr,
mNgraphNetCreator->getOutputShape(outIndex), expectedLength);
} else {
switch (operandType) {
case OperandType::TENSOR_INT32:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int32_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_FLOAT32:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<float>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_BOOL8:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<bool>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT8_ASYMM:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint8_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT8_SYMM:
case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
std::memcpy((int8_t*)destPtr, (int8_t*)srcTensor.data<int8_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_FLOAT16:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<ov::float16>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT16_SYMM:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int16_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT16_ASYMM:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint16_t>(),
srcTensor.get_byte_size());
break;
default:
std::memcpy((uint8_t*)destPtr, srcTensor.data<uint8_t>(),
srcTensor.get_byte_size());
break;
}
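// As in asyncExecute: copy outputs from OpenVINO tensors into the request buffers.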
switch (operandType) {
case OperandType::TENSOR_INT32:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int32_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_FLOAT32:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<float>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_BOOL8:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<bool>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT8_ASYMM:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint8_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT8_SYMM:
case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
std::memcpy((int8_t*)destPtr, (int8_t*)srcTensor.data<int8_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_FLOAT16:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<ov::float16>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT16_SYMM:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<int16_t>(),
srcTensor.get_byte_size());
break;
case OperandType::TENSOR_QUANT16_ASYMM:
std::memcpy((uint8_t*)destPtr, (uint8_t*)srcTensor.data<uint16_t>(),
srcTensor.get_byte_size());
break;
default:
std::memcpy((uint8_t*)destPtr, srcTensor.data<uint8_t>(),
srcTensor.get_byte_size());
break;
}
}

2 changes: 1 addition & 1 deletion BasePreparedModel.h
@@ -88,7 +88,7 @@ class BasePreparedModel : public V1_3::IPreparedModel {

virtual bool initialize();
virtual bool checkRemoteConnection();
virtual bool loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin);
virtual void loadRemoteModel(const std::string& ir_xml, const std::string& ir_bin);
virtual void setRemoteEnabled(bool flag);

std::shared_ptr<NnapiModelInfo> getModelInfo() { return mModelInfo; }
2 changes: 1 addition & 1 deletion DetectionClient.cpp
Expand Up @@ -8,7 +8,7 @@ std::string DetectionClient::prepare(bool& flag) {
request.mutable_token()->set_data(mToken);
ReplyStatus reply;
ClientContext context;
time_point deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(100);
time_point deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(10000);
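// Allow the remote service up to 10 s to answer prepare() before the RPC deadline expires.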
context.set_deadline(deadline);

Status status = stub_->prepare(&context, request, &reply);
