Skip to content

Commit

Permalink
Merge pull request #12 from IITH-Compilers/mlbridge-test
Browse files Browse the repository at this point in the history
Added grpc and onnx tests
  • Loading branch information
svkeerthy authored Feb 19, 2024
2 parents d03a012 + d3dcda3 commit fe54fa3
Show file tree
Hide file tree
Showing 21 changed files with 430 additions and 165 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,5 @@ else()
endif(LLVM_MLBRIDGE)

install(DIRECTORY include/ DESTINATION include)
install(DIRECTORY CompilerInterface DESTINATION include/python/MLCompilerBridge)
install(DIRECTORY CompilerInterface DESTINATION MLModelRunner/CompilerInterface)
file(COPY ${CMAKE_CURRENT_SOURCE_DIR}/CompilerInterface DESTINATION ${CMAKE_BINARY_DIR}/MLModelRunner/)
4 changes: 2 additions & 2 deletions CompilerInterface/GrpcCompilerInterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def start_server(self):
"{}:{}".format(self.host, self.server_port)
)

if str(added_port) == self.server_port:
if added_port == self.server_port:
server.start()
print("Server Running")
server.wait_for_termination()
Expand All @@ -100,7 +100,7 @@ def start_server(self):
retries += 1
print(
"The port",
self.port,
self.server_port,
"is already in use retrying! attempt: ",
retries,
)
Expand Down
2 changes: 1 addition & 1 deletion MLModelRunner/gRPCModelRunner/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ if(LLVM_MLBRIDGE)
${proto_python_srcs_list}
)
else()
add_library(gRPCModelRunnerLib OBJECT
add_library(gRPCModelRunnerLib
${cc_files}
${proto_srcs_list}
${grpc_srcs_list}
Expand Down
13 changes: 13 additions & 0 deletions SerDes/protobufSerDes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ void *ProtobufSerDes::deserializeUntyped(void *data) {
this->MessageLength = ref.size() * sizeof(int32_t);
return ret->data();
}
if (field->type() == FieldDescriptor::Type::TYPE_INT64) {
auto &ref = reflection->GetRepeatedField<int64_t>(*Response, field);
std::vector<int64_t> *ret =
new std::vector<int64_t>(ref.begin(), ref.end());
this->MessageLength = ref.size() * sizeof(int64_t);
return ret->data();
}
if (field->type() == FieldDescriptor::Type::TYPE_FLOAT) {
auto ref = reflection->GetRepeatedField<float>(*Response, field);
std::vector<float> *ret = new std::vector<float>(ref.begin(), ref.end());
Expand Down Expand Up @@ -199,6 +206,12 @@ void *ProtobufSerDes::deserializeUntyped(void *data) {
this->MessageLength = sizeof(int32_t);
return ptr;
}
if (field->type() == FieldDescriptor::Type::TYPE_INT64) {
int64_t value = reflection->GetInt64(*Response, field);
int64_t *ptr = new int64_t(value);
this->MessageLength = sizeof(int64_t);
return ptr;
}
if (field->type() == FieldDescriptor::Type::TYPE_FLOAT) {
float value = reflection->GetFloat(*Response, field);
float *ptr = new float(value);
Expand Down
5 changes: 3 additions & 2 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ file(GLOB MODEL_OBJECTS ${CMAKE_CURRENT_SOURCE_DIR}/tf_models/*.o)
foreach(MODEL_OBJECT ${MODEL_OBJECTS})
target_link_libraries(MLBridgeCPPTest PRIVATE ${MODEL_OBJECT})
endforeach()
target_link_libraries(MLBridgeCPPTest PRIVATE ModelRunnerUtils)
target_include_directories(MLBridgeCPPTest PRIVATE ${CMAKE_BINARY_DIR}/include ${TENSORFLOW_AOT_PATH}/include)
target_link_libraries(MLBridgeCPPTest PRIVATE MLCompilerBridge )
target_include_directories(MLBridgeCPPTest PRIVATE ${CMAKE_BINARY_DIR}/include ${TENSORFLOW_AOT_PATH}/include ${CMAKE_CURRENT_SOURCE_DIR}/include)
target_link_libraries(MLBridgeCPPTest PRIVATE tf_xla_runtime)
213 changes: 149 additions & 64 deletions test/MLBridgeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,18 @@
//
//===----------------------------------------------------------------------===//

#include "HelloMLBridge_Env.h"
#include "MLModelRunner/MLModelRunner.h"
#include "MLModelRunner/ONNXModelRunner/ONNXModelRunner.h"
#include "MLModelRunner/PipeModelRunner.h"
#include "MLModelRunner/TFModelRunner.h"
#include "MLModelRunner/Utils/DataTypes.h"
#include "MLModelRunner/Utils/MLConfig.h"
#include "MLModelRunner/gRPCModelRunner.h"
// #include "grpc/helloMLBridgeTest/helloMLBridgeTest.grpc.pb.h"
// #include "grpc/helloMLBridgeTest/helloMLBridgeTest.pb.h"
#include "grpcpp/impl/codegen/status.h"
#include "inference/HelloMLBridge_Env.h"
#include "ProtosInclude.h"
#include "llvm/Support/CommandLine.h"
// #include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <fstream>
#include <google/protobuf/text_format.h>
#include <iostream>
#include <iterator>
#include <memory>
Expand All @@ -31,6 +27,20 @@
#define debug_out \
if (!silent) \
std::cout
using namespace grpc;

#define gRPCModelRunnerInit(datatype) \
increment_port(1); \
MLBridgeTestgRPC_##datatype::Reply response; \
MLBridgeTestgRPC_##datatype::Request request; \
MLRunner = std::make_unique< \
gRPCModelRunner<MLBridgeTestgRPC_##datatype::MLBridgeTestService, \
MLBridgeTestgRPC_##datatype::MLBridgeTestService::Stub, \
MLBridgeTestgRPC_##datatype::Request, \
MLBridgeTestgRPC_##datatype::Reply>>( \
server_address, &request, &response, nullptr); \
MLRunner->setRequest(&request); \
MLRunner->setResponse(&response)

static llvm::cl::opt<std::string>
cl_server_address("test-server-address", llvm::cl::Hidden,
Expand All @@ -41,6 +51,10 @@ static llvm::cl::opt<std::string>
cl_pipe_name("test-pipe-name", llvm::cl::Hidden, llvm::cl::init(""),
llvm::cl::desc("Name for pipe file"));

static llvm::cl::opt<std::string>
cl_onnx_path("onnx-model-path", llvm::cl::Hidden, llvm::cl::init(""),
llvm::cl::desc("Path to onnx model"));

static llvm::cl::opt<std::string> cl_test_config(
"test-config", llvm::cl::Hidden,
llvm::cl::desc("Method for communication with python model"));
Expand All @@ -55,9 +69,9 @@ std::string basename;
BaseSerDes::Kind SerDesType;

std::string test_config;
std::string data_format;
std::string pipe_name;
std::string server_address;
std::string onnx_path;

// send value of type T1. Test received value of type T2 against expected value
template <typename T1, typename T2>
Expand Down Expand Up @@ -96,77 +110,149 @@ void testVector(std::string label, std::vector<T1> value,
debug_out << "\n";
}

void runTests() {
if (data_format != "json") {
testPrimitive<int, int>("int", 11, 12);
testPrimitive<long, long>("long", 1234567890, 1234567891);
testPrimitive<float, float>("float", 3.14, 4.14);
testPrimitive<double, double>("double", 0.123456789123456789,
1.123456789123456789);
testPrimitive<char, char>("char", 'a', 'b');
testPrimitive<bool, bool>("bool", true, false);
testVector<int, int>("vec_int", {11, 22, 33}, {12, 23, 34});
testVector<long, long>("vec_long", {123456780, 222, 333},
{123456780, 123456781, 123456782});
testVector<float, float>("vec_float", {11.1, 22.2, 33.3},
{1.11, 2.22, -3.33, 0});
testVector<double, double>("vec_double",
{-1.1111111111, -2.2222222222, -3.3333333333},
{1.12345678912345670, -1.12345678912345671});
} else if (data_format == "json") {
testPrimitive<int, IntegerType>("int", 11, 12);
testPrimitive<long, IntegerType>("long", 1234567890, 12345);
testPrimitive<float, RealType>("float", 3.14, 4.14);
testPrimitive<double, RealType>("double", 0.123456789123456789,
1.123456789123456789);
testPrimitive<char, char>("char", 'a', 'b');
testPrimitive<bool, bool>("bool", true, false);
testVector<int, IntegerType>("vec_int", {11, 22, 33}, {12, 23, 34});
testVector<long, IntegerType>("vec_long", {123456780, 222, 333},
{6780, 6781, 6782});
testVector<float, RealType>("vec_float", {11.1, 22.2, 33.3},
{1.11, 2.22, -3.33, 0});
testVector<double, RealType>("vec_double",
{-1.1111111111, -2.2222222222, -3.3333333333},
{1.12345678912345670, -1.12345678912345671});
int testPipeBytes() {
if (pipe_name == "") {
std::cerr
<< "Pipe name must be specified via --test-pipe-name=<filename>\n";
exit(1);
}
basename = "./" + pipe_name;
SerDesType = BaseSerDes::Kind::Bitstream;
MLRunner = std::make_unique<PipeModelRunner>(
basename + ".out", basename + ".in", SerDesType, nullptr);
testPrimitive<int, int>("int", 11, 12);
testPrimitive<long, long>("long", 1234567890, 1234567891);
testPrimitive<float, float>("float", 3.14, 4.14);
testPrimitive<double, double>("double", 0.123456789123456789,
1.123456789123456789);
testPrimitive<char, char>("char", 'a', 'b');
testPrimitive<bool, bool>("bool", true, false);
testVector<int, int>("vec_int", {11, 22, 33}, {12, 23, 34});
testVector<long, long>("vec_long", {123456780, 222, 333},
{123456780, 123456781, 123456782});
testVector<float, float>("vec_float", {11.1, 22.2, 33.3},
{1.11, 2.22, -3.33, 0});
testVector<double, double>("vec_double",
{-1.1111111111, -2.2222222222, -3.3333333333},
{1.12345678912345670, -1.12345678912345671});
return 0;
}

int testPipes() {
int testPipeJSON() {
if (pipe_name == "") {
std::cerr
<< "Pipe name must be specified via --test-pipe-name=<filename>\n";
exit(1);
}
basename = "/tmp/" + pipe_name;
if (data_format == "json")
SerDesType = BaseSerDes::Kind::Json;
else if (data_format == "protobuf")
SerDesType = BaseSerDes::Kind::Protobuf;
else if (data_format == "bytes")
SerDesType = BaseSerDes::Kind::Bitstream;
else {
std::cout << "Invalid data format\n";
exit(1);
}

basename = "./" + pipe_name;
SerDesType = BaseSerDes::Kind::Json;
MLRunner = std::make_unique<PipeModelRunner>(
basename + ".out", basename + ".in", SerDesType, nullptr);

runTests();
testPrimitive<int, IntegerType>("int", 11, 12);
testPrimitive<long, IntegerType>("long", 1234567890, 12345);
testPrimitive<float, RealType>("float", 3.14, 4.14);
testPrimitive<double, RealType>("double", 0.123456789123456789,
1.123456789123456789);
testPrimitive<char, char>("char", 'a', 'b');
testPrimitive<bool, bool>("bool", true, false);
testVector<int, IntegerType>("vec_int", {11, 22, 33}, {12, 23, 34});
testVector<long, IntegerType>("vec_long", {123456780, 222, 333},
{6780, 6781, 6782});
testVector<float, RealType>("vec_float", {11.1, 22.2, 33.3},
{1.11, 2.22, -3.33, 0});
testVector<double, RealType>("vec_double",
{-1.1111111111, -2.2222222222, -3.3333333333},
{1.12345678912345670, -1.12345678912345671});
return 0;
}

void increment_port(int delta) {
int split = server_address.find(":");
int port = stoi(server_address.substr(split + 1));
server_address =
server_address.substr(0, split) + ":" + to_string(port + delta);
}

int testGRPC() {
if (server_address == "") {
std::cerr << "Server Address must be specified via "
"--test-server-address=<ip>:<port>\n";
"--test-server-address=\"<ip>:<port>\"\n";
exit(1);
}
{
gRPCModelRunnerInit(int);
testPrimitive<int, int>("int", 11, 12);
}
{
gRPCModelRunnerInit(long);
testPrimitive<long, long>("long", 1234567890, 1234567891);
}
{
gRPCModelRunnerInit(float);
testPrimitive<float, float>("float", 3.14, 4.14);
}
{
gRPCModelRunnerInit(double);
testPrimitive<double, double>("double", 0.123456789123456789,
1.123456789123456789);
}
increment_port(1);
{
gRPCModelRunnerInit(bool);
testPrimitive<bool, bool>("bool", true, false);
}
{
gRPCModelRunnerInit(vec_int);
testVector<int, int>("vec_int", {11, 22, 33}, {12, 23, 34});
}
{
gRPCModelRunnerInit(vec_long);
testVector<long, long>("vec_long", {123456780, 222, 333},
{123456780, 123456781, 123456782});
}
{
gRPCModelRunnerInit(vec_float);
testVector<float, float>("vec_float", {11.1, 22.2, 33.3},
{1.11, 2.22, -3.33, 0});
}
{
gRPCModelRunnerInit(vec_double);
testVector<double, double>("vec_double",
{-1.1111111111, -2.2222222222, -3.3333333333},
{1.12345678912345670, -1.12345678912345671});
}
return 0;
}

int testONNX() { return 0; }
class ONNXTest : public MLBridgeTestEnv {
public:
int run(int expectedAction) {
onnx_path = cl_onnx_path.getValue();
if (onnx_path == "") {
std::cerr << "ONNX model path must be specified via "
"--onnx-model-path=<filepath>\n";
exit(1);
}
FeatureVector.clear();
int n = 100;
for (int i = 0; i < n; i++) {
float delta = (float)(i - expectedAction) / n;
FeatureVector.push_back(delta * delta);
}

Agent *agent = new Agent(onnx_path);
std::map<std::string, Agent *> agents;
agents["agent"] = agent;
MLRunner = std::make_unique<ONNXModelRunner>(this, agents, nullptr);
MLRunner->evaluate<int>();
if (lastAction != expectedAction) {
std::cerr << "Error: Expected action: " << expectedAction
<< ", Computed action: " << lastAction << "\n";
exit(1);
}
return 0;
}
};

} // namespace

Expand All @@ -176,18 +262,17 @@ int main(int argc, char **argv) {

if (test_config == "pipe-bytes") {
pipe_name = cl_pipe_name.getValue();
data_format = "bytes";
testPipes();
testPipeBytes();
} else if (test_config == "pipe-json") {
pipe_name = cl_pipe_name.getValue();
data_format = "json";
testPipes();
testPipeJSON();
} else if (test_config == "grpc") {
server_address = cl_server_address.getValue();
testGRPC();
} else if (test_config == "onnx")
testONNX();
else
} else if (test_config == "onnx") {
ONNXTest t;
t.run(20);
} else
std::cerr << "--test-config must be provided from [pipe-bytes, pipe-json, "
"grpc, onnx]\n";
return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,29 @@
#include "llvm/Support/raw_ostream.h"

using namespace MLBridge;
class HelloMLBridgeEnv : public Environment {
class MLBridgeTestEnv : public Environment {
Observation CurrObs;

public:
HelloMLBridgeEnv() { setNextAgent("agent"); };
MLBridgeTestEnv() { setNextAgent("agent"); };
Observation &reset() override;
Observation &step(Action) override;
Action lastAction;

protected:
std::vector<float> FeatureVector;
};

Observation &HelloMLBridgeEnv::step(Action Action) {
Observation &MLBridgeTestEnv::step(Action Action) {
CurrObs.clear();
std::copy(FeatureVector.begin(), FeatureVector.end(),
std::back_inserter(CurrObs));
llvm::outs() << "Action: " << Action << "\n";
lastAction = Action;
setDone();
return CurrObs;
}

Observation &HelloMLBridgeEnv::reset() {
Observation &MLBridgeTestEnv::reset() {
std::copy(FeatureVector.begin(), FeatureVector.end(),
std::back_inserter(CurrObs));
return CurrObs;
Expand Down
Loading

0 comments on commit fe54fa3

Please sign in to comment.