
Commit f5b2070

some changes
ashwinagrl committed Jan 10, 2025
1 parent c0e8ccc commit f5b2070
Showing 3 changed files with 77 additions and 42 deletions.
2 changes: 1 addition & 1 deletion MLModelRunner/gRPCModelRunner/CMakeLists.txt
@@ -13,7 +13,7 @@ endif()
# Find gRPC installation
# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation.

-find_package(gRPC 1.66.0 EXACT CONFIG REQUIRED)
+find_package(gRPC 1.34.0 EXACT CONFIG REQUIRED)
message(STATUS "Using gRPC ${gRPC_VERSION}")

set(_GRPC_GRPCPP gRPC::grpc++)
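Note: the pin above now requires CMake to find exactly gRPC 1.34.0. As a quick sanity check when the configure step picks up the wrong install, the library's reported version can be compared against the pin. The snippet below is a hypothetical standalone helper, not part of this commit; it assumes the gRPC C++ headers are available and uses grpc::Version() from <grpcpp/grpcpp.h>:

// check_grpc_version.cpp -- hypothetical helper, not in this repository.
// Prints the gRPC C++ library version so it can be compared with the
// find_package(gRPC 1.34.0 EXACT CONFIG REQUIRED) pin above.
#include <grpcpp/grpcpp.h>
#include <iostream>

int main() {
  std::cout << "gRPC C++ version: " << grpc::Version() << "\n";
  return 0;
}

Built and linked against the same gRPC installation the project uses (e.g. via gRPC::grpc++), this should print 1.34.0 when the pinned package is the one actually being found.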
22 changes: 5 additions & 17 deletions include/MLModelRunner/PTModelRunner.h
@@ -50,24 +50,12 @@ class PTModelRunner final : public MLModelRunner {


void* PTModelRunner::evaluateUntyped() {
+  auto outputs = SerDes->CompiledModel->run((*(this->SerDes->inputTensors)));
+  for (auto i = outputs.begin(); i != outputs.end(); ++i)
+    (*(this->SerDes->outputTensors)).push_back(*i);
+  void* rawData = SerDes->deserializeUntyped(SerDes->outputTensors);
+  return rawData;
-
-  if ((*(this->SerDes->inputTensors)).empty()) {
-    llvm::errs() << "Input vector is empty.\n";
-    return nullptr;
-  }
-
-  try {
-    // Run the model with the input tensors
-    auto outputs = SerDes->CompiledModel->run((*(this->SerDes->inputTensors)));
-
-    //Store the above output in the outputTensors, outputTensors is a pointer to the vector of tensors, already initialized in the constructor
-    for (auto i = outputs.begin(); i != outputs.end(); ++i)
-      (*(this->SerDes->outputTensors)).push_back(*i);
-
-    // Convert to raw data format using deserializeUntyped
-    void* rawData = SerDes->deserializeUntyped(SerDes->outputTensors);
-
-    return rawData;
-  } catch (const c10::Error& e) {
-    llvm::errs() << "Error during model evaluation: " << e.what() << "\n";
-    return nullptr;
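Note: the simplified evaluateUntyped() above no longer checks for empty inputs or catches c10::Error, so any exception from the compiled model now propagates to the caller (the updated test below wraps inference in a try/catch for the same reason). A minimal caller-side guard might look like the sketch below; runGuarded is a hypothetical helper, not part of this commit, and it assumes c10::Error derives from std::exception as in current PyTorch releases:

// Hypothetical caller-side wrapper around PTModelRunner::evaluateUntyped().
#include "MLModelRunner/PTModelRunner.h"
#include <iostream>
#include <vector>

std::vector<float> *runGuarded(MLBridge::PTModelRunner &runner) {
  try {
    // evaluateUntyped() may now throw if the compiled model fails.
    void *raw = runner.evaluateUntyped();
    // The serializer returns the outputs as a std::vector<float>, as consumed in the test below.
    return reinterpret_cast<std::vector<float> *>(raw);
  } catch (const std::exception &e) {
    std::cerr << "Error during model evaluation: " << e.what() << "\n";
    return nullptr;
  }
}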
95 changes: 71 additions & 24 deletions test/MLBridgeTest.cpp
@@ -5,53 +5,100 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "MLModelRunner/MLModelRunner.h"
#include "MLModelRunner/PTModelRunner.h"
// #include "MLModelRunner/TFModelRunner.h"
#include "SerDes/pytorchSerDes.h"
#include "llvm/Support/CommandLine.h"
+#include <algorithm>
+#include <chrono>
+#include <fstream>
+#include <iostream>
+#include <iterator>
#include <memory>
+#include <random>
#include <string>
-#include <iostream>
-#include <chrono>
#include <vector>

using namespace MLBridge;
using namespace std;

-int main(int argc, char **argv) {
-
-    printf("Starting PTModelRunner test...\n");
-    llvm::LLVMContext Ctx;
+int main(int argc, char **argv)
+{
+    if (argc != 2)
+    {
+        cerr << "Usage: " << argv[0] << " <size>" << endl;
+        return 1;
+    }

-    string modelPath = "model.so";
-    unique_ptr<PTModelRunner> runner = make_unique<PTModelRunner>(modelPath, Ctx);
+    int sz = stoi(argv[1]);
+    cout << "Running inference for size: " << sz << endl;

+    // Create random input data of size sz
+    vector<float> inputData(sz);
+    random_device rd;
+    mt19937 gen(rd());
+    uniform_real_distribution<float> dis(0.0, 1.0);
+    generate(inputData.begin(), inputData.end(), [&]()
+             { return dis(gen); });

+    llvm::LLVMContext Ctx;
+    string modelPath = "models_pt/model_" + to_string(sz) + ".so";

-    // Create some input data
-    vector<float> inputData(10, 0.5f); // Example input with size 10
+    try
+    {
+        // Start measuring time just before loading the model
+        auto start = chrono::high_resolution_clock::now();

+        unique_ptr<PTModelRunner> runner = make_unique<PTModelRunner>(modelPath, Ctx);

        // Prepare the input
        basic_string<char> input_str = "input";
-        pair< string, vector<float>& > inputPair = make_pair(input_str, ref(inputData));

+        pair<string, vector<float> &> inputPair = make_pair(input_str, ref(inputData));
        runner->populateFeatures(inputPair);

-        void* result = runner->evaluateUntyped();
+        // Perform inference
+        void *result = runner->evaluateUntyped();

-    if (result) {
-        cout << "Model evaluation succeeded." << endl;
-        vector<float> *output = reinterpret_cast<vector<float> *>(result);
-        cout << "Output: ";
-        for (auto &v : *output) {
-            cout << v << " ";
-        }
-        cout << endl;
+        auto end = chrono::high_resolution_clock::now();
+        chrono::duration<double, milli> inferenceTime = end - start;

-    } else {
-        cerr << "Model evaluation failed." << endl;
+        // Check result
+        if (result)
+        {
+            vector<float> *output = reinterpret_cast<vector<float> *>(result);
+            cout << "Output: ";
+            for (auto &v : *output)
+            {
+                cout << v << " ";
+            }
+            cout << endl;
+        }
+        else
+        {
+            cerr << "Model evaluation failed." << endl;
+            return 1;
+        }

+        // // Log the results
+        // ofstream resultsFile("results.csv", ios::app);
+        // if (resultsFile.is_open())
+        // {
+        //     resultsFile << sz << "," << inferenceTime.count() << endl;
+        //     resultsFile.close();
+        // }
+        // else
+        // {
+        //     cerr << "Failed to open results.csv" << endl;
+        // }

+        cout << "Inference completed in: " << inferenceTime.count() << " ms" << endl;
+    }
+    catch (const exception &ex)
+    {
+        cerr << "Error: " << ex.what() << endl;
+        return 1;
+    }

    return 0;
}
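Note: the results-logging block above is left commented out in this commit. If it were enabled, the same logic could live in a small helper such as the sketch below; logResult is a hypothetical name, and it reuses the sz and inferenceTime values computed in the test:

// Hypothetical helper mirroring the commented-out logging block above.
#include <fstream>
#include <iostream>

// Appends one "<size>,<milliseconds>" row per run to results.csv.
void logResult(int sz, double inferenceTimeMs) {
  std::ofstream resultsFile("results.csv", std::ios::app);
  if (resultsFile.is_open()) {
    resultsFile << sz << "," << inferenceTimeMs << std::endl;
  } else {
    std::cerr << "Failed to open results.csv" << std::endl;
  }
}

In MLBridgeTest.cpp it would be called as logResult(sz, inferenceTime.count()) right after the timing is printed.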
