Skip to content

Commit

Permalink
Added test for pipe communication
Browse files Browse the repository at this point in the history
  • Loading branch information
RajivChitale committed Jan 13, 2024
1 parent dc718b3 commit ebf37ec
Show file tree
Hide file tree
Showing 8 changed files with 417 additions and 16 deletions.
8 changes: 7 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ set(protobuf_MODULE_COMPATIBLE TRUE)
find_package(Protobuf CONFIG REQUIRED)

include_directories(${Protobuf_INCLUDE_DIRS} ${CMAKE_CURRENT_SOURCE_DIR}/include)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-rtti -fPIC")
set (CMAKE_CXX_STANDARD 17)

option(LLVM_MLBRIDGE "MLCompilerBridge install for LLVM" OFF)
Expand All @@ -22,6 +22,7 @@ endif()

add_subdirectory(MLModelRunner)
add_subdirectory(SerDes)
add_subdirectory(test)

if(LLVM_MLBRIDGE)
include(AddLLVM)
Expand Down Expand Up @@ -49,6 +50,11 @@ else()
add_library(MLCompilerBridge STATIC tools.cpp)
target_link_libraries(MLCompilerBridge PUBLIC SerDesLib ModelRunnerLib ONNXModelRunnerLib LLVM-10 ${llvm_libs})

add_executable(MLCompilerBridgeTest $<TARGET_OBJECTS:LLVMMLBridgeTest>)
# add_library(MLCompilerBridgeTest SHARED $<TARGET_OBJECTS:LLVMHelloMLBridgeTest>)
target_link_libraries(MLCompilerBridgeTest PUBLIC MLCompilerBridge)
set_property(TARGET MLCompilerBridge PROPERTY POSITION_INDEPENDENT_CODE 1)

add_library(MLCompilerBridgeC STATIC $<TARGET_OBJECTS:ModelRunnerCWrapper>)
target_link_libraries(MLCompilerBridgeC PUBLIC SerDesCLib ModelRunnerCLib ONNXModelRunnerLib LLVM-10 ${llvm_libs})
target_include_directories(MLCompilerBridgeC PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include ${LLVM_INCLUDE_DIRS})
Expand Down
8 changes: 5 additions & 3 deletions MLModelRunner/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ if(LLVM_MLBRIDGE)
MLConfig.cpp
)
target_include_directories(ModelRunnerUtils PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
endif(LLVM_MLBRIDGE)

add_library(ModelRunnerCUtils OBJECT MLConfig.cpp)

else()
add_library(ModelRunnerCUtils OBJECT MLConfig.cpp)
add_library(ModelRunnerUtils OBJECT MLConfig.cpp)
endif(LLVM_MLBRIDGE)
16 changes: 8 additions & 8 deletions SerDes/jsonSerDes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ void *JsonSerDes::desJson(json::Value *V) {
return nullptr;
case json::Value::Kind::Number: {
if (auto x = V->getAsInteger()) {
int *ret = new int();
long *ret = new long();
*ret = x.value();
this->MessageLength = sizeof(int);
this->MessageLength = sizeof(long);
return ret;
} else if (auto x = V->getAsNumber()) {
float *ret = new float();
double *ret = new double();
*ret = x.value();
this->MessageLength = sizeof(float);
this->MessageLength = sizeof(double);
return ret;
} else {
llvm::errs() << "Error in desJson: Number is not int, or double\n";
Expand Down Expand Up @@ -78,18 +78,18 @@ void *JsonSerDes::desJson(json::Value *V) {
switch (first->kind()) {
case json::Value::Kind::Number: {
if (auto x = first->getAsInteger()) {
std::vector<int> *ret = new std::vector<int>();
std::vector<long> *ret = new std::vector<long>();
for (auto it : *arr) {
ret->push_back(it.getAsInteger().value());
}
this->MessageLength = ret->size() * sizeof(int);
this->MessageLength = ret->size() * sizeof(long);
return ret->data();
} else if (auto x = first->getAsNumber()) {
std::vector<float> *ret = new std::vector<float>();
std::vector<double> *ret = new std::vector<double>();
for (auto it : *arr) {
ret->push_back(it.getAsNumber().value());
}
this->MessageLength = ret->size() * sizeof(float);
this->MessageLength = ret->size() * sizeof(double);
return ret->data();
} else {
llvm::errs() << "Error in desJson: Number is not int, or double\n";
Expand Down
19 changes: 15 additions & 4 deletions python-interface/SerDes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import json
import log_reader
import struct


import ctypes
class SerDes:
def __init__(self, data_format, pipe_name):
self.data_format = data_format
Expand Down Expand Up @@ -81,7 +80,7 @@ def sendData(self, data):
self.tc.flush()

def serializeDataJson(self, data):
msg = json.dumps({"out": data}).encode("utf-8")
msg = json.dumps({"out": data}, cls=NpEncoder).encode("utf-8")
hdr = len(msg).to_bytes(8, "little")
out = hdr + msg
return out
Expand All @@ -93,7 +92,11 @@ def _pack(data):
elif isinstance(data, float):
return struct.pack("f", data)
elif isinstance(data, str) and len(data) == 1:
return struct.pack("c", data)
return struct.pack('c', data)
elif isinstance(data, ctypes.c_double):
return struct.pack('d', data.value)
elif isinstance(data, ctypes.c_long):
return struct.pack('l', data.value)
elif isinstance(data, list):
return b"".join([_pack(x) for x in data])

Expand All @@ -104,3 +107,11 @@ def _pack(data):

def serializeDataProtobuf(self, data):
raise NotImplementedError

class NpEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, ctypes.c_long):
return obj.value
if isinstance(obj, ctypes.c_double):
return obj.value
return super(NpEncoder, self).default(obj)
176 changes: 176 additions & 0 deletions python-interface/mlbridge-test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# Test driver for MLCompilerBridge pipe/gRPC communication.
# Talks to the C++ side (MLCompilerBridgeTest) either over a named pipe
# (--use_pipe) or as a gRPC server (--use_grpc).
import log_reader
import argparse
import os, io, json
import SerDes
import ctypes
import numpy as np

import sys
import torch, torch.nn as nn

# Make the generated gRPC stubs importable.
sys.path.append(
    "../MLModelRunner/gRPCModelRunner/Python-Utilities"
)
import helloMLBridge_pb2, helloMLBridge_pb2_grpc, grpc
from concurrent import futures

parser = argparse.ArgumentParser()
# NOTE(review): argparse `type=bool` is a known pitfall — bool("False") is True,
# so any non-empty value enables the flag. An `action="store_true"` flag would be
# safer, but changing it would alter the CLI; confirm callers before switching.
parser.add_argument("--use_pipe", type=bool, default=False, help="Use pipe or not", required=False)
parser.add_argument(
    "--data_format",
    type=str,
    choices=["json", "protobuf", "bytes"],
    help="Data format to use for communication",
)
parser.add_argument(
    "--pipe_name",
    type=str,
    help="Pipe Name",
)
parser.add_argument(
    "--use_grpc",
    action="store_true",
    help="Use grpc communication",
    required=False,
    default=False,
)
parser.add_argument(
    "--server_port",
    type=int,
    help="Server Port",
    default=5050,
)
args = parser.parse_args()

class DummyModel(nn.Module):
    """Minimal one-layer network used as a placeholder model in bridge tests.

    Maps an ``input_dim``-feature vector to a single scalar via one linear layer.
    """

    def __init__(self, input_dim=10):
        super().__init__()
        # Single fully-connected layer: input_dim features -> 1 output.
        self.fc1 = nn.Linear(input_dim, 1)

    def forward(self, input):
        # Pure pass-through of the linear layer; no activation.
        return self.fc1(input)


# Test-case tables, keyed by 1-based test index. Each index corresponds to one
# round-trip with the C++ side: the C++ process sends a value of the type named
# in expected_type; we check it against expected_data and reply with the entry
# from returned_data.
expected_type = {
    1: 'int',
    2: 'long',
    3: 'float',
    4: 'double',
    5: 'char',
    6: 'bool',
    7: 'vec_int',
    8: 'vec_long',
    9: 'vec_float',
    10: 'vec_double'
}

# Values the C++ side is expected to send, per test index.
# Chars are compared as their ordinal values.
expected_data = {
    1: 11,
    2: 1234567890,
    3: 3.14,
    4: 0.123456789123456789,
    5: ord('a'),
    6: True,
    7: [11,22,33],
    8: [111,222,333],
    9: [11.1,22.2,33.3],
    10: [-1.1111111111,-2.2222222222,-3.3333333333],
}

# Values sent back to the C++ side. ctypes wrappers (c_long/c_double) force the
# exact C width during serialization (see SerDes._pack / NpEncoder).
returned_data = {
    1: 12,
    2: ctypes.c_long(1234567891),
    3: 4.14,
    4: ctypes.c_double(1.123456789123456789),
    5: ord('b'),
    6: False,
    7: [12,23,34],
    8: [ctypes.c_long(123456780),ctypes.c_long(123456781),ctypes.c_long(123456782)],
    9: [1.11,2.22,-3.33,0],
    10: [ctypes.c_double(1.12345678912345670), ctypes.c_double(-1.12345678912345671)]
}

def run_pipe_communication(data_format, pipe_name):
    """Serve the scripted pipe test: one read/compare/reply per test index.

    For each round-trip, reads an observation from the C++ process over the
    named pipe at /tmp/<pipe_name>, compares it to expected_data[i], and sends
    back returned_data[i]. Errors re-initialize the pipe and continue.
    """
    serdes = SerDes.SerDes(data_format, "/tmp/" + pipe_name)
    print('Serdes init...')
    serdes.init()
    i = 0  # 1-based index into expected_type/expected_data/returned_data
    while True:
        i += 1
        try:
            data = serdes.readObservation()
            if data_format == "json":
                # JSON payloads arrive as a single-key dict; unwrap the value.
                key = list(data)[0]
                data = data[key]
            elif data_format == "bytes":
                # Bytes payloads arrive wrapped in a one-element outer list;
                # a single scalar is further unwrapped from its list.
                data = [x for x in data[0]]
                if len(data) == 1:
                    data = data[0]

            print(expected_type[i], "request:", data)

            # NOTE(review): 10e-6 is 1e-5, not 1e-6 — confirm intended tolerance.
            if isinstance(expected_data[i], list):
                for e,d in zip(expected_data[i],data):
                    if abs(e-d)>10e-6:
                        print(f"Mismatch in {expected_type[i]}")
                        # raise Exception(f"Mismatch in {expected_type[i]}")

            elif abs(data - expected_data[i]) >10e-6:
                print(f"Mismatch in {expected_type[i]}")
                # raise Exception(f"Mismatch in {expected_type[i]}")

            serdes.sendData(returned_data[i])
        except Exception as e:
            # NOTE(review): once i exceeds 10, expected_type[i] raises KeyError
            # and lands here, re-initializing the pipe; the loop never exits on
            # its own — verify this endless-serve behavior is intended.
            print("*******Exception*******", e)
            serdes.init()

class service_server(helloMLBridge_pb2_grpc.HelloMLBridgeService):
    """gRPC servicer for HelloMLBridgeService used by the bridge test.

    Logs each incoming request and replies with a fixed action=1; replies with
    action=-1 if handling the request fails.
    """

    def __init__(self, data_format, pipe_name):
        # Pipe-based SerDes is unused in gRPC mode; the arguments are kept so
        # the constructor matches the pipe-mode entry point's signature.
        # self.serdes = SerDes.SerDes(data_format, pipe_name)
        # self.serdes.init()
        pass

    def getAdvice(self, request, context):
        """Handle a getAdvice RPC: echo the request and return action=1.

        Returns an ActionRequest with action=-1 on any handler failure.
        """
        try:
            print(request)
            print("Entered getAdvice")
            print("Data: ", request.tensor)
            reply = helloMLBridge_pb2.ActionRequest(action=1)
            return reply
        except Exception as e:
            # Was a bare `except:`, which also swallowed KeyboardInterrupt /
            # SystemExit and hid the failure cause. Catch Exception only and
            # surface the error before replying with the failure sentinel.
            print("getAdvice failed:", e)
            reply = helloMLBridge_pb2.ActionRequest(action=-1)
            return reply

def test_func():
    """Round-trip a float through struct packing as a binary-encoding sanity check.

    Packs 3.24 as a 4-byte IEEE-754 float32, decodes it back, prints each step,
    and returns the decoded value (equal to 3.24 within float32 precision).
    Previously returned nothing; the return value is additive and backward-
    compatible.
    """
    data = 3.24
    import struct
    print(data, type(data))
    byte_data = struct.pack('f', data)
    print(byte_data, len(byte_data))

    print('decoding...')
    # BUG FIX: float(byte_data) raised ValueError — bytes of a packed float are
    # not a numeric string. Decode by unpacking the IEEE-754 float32 instead.
    decoded = struct.unpack('f', byte_data)[0]

    print(decoded, type(decoded))
    return decoded

if __name__ == "__main__":
    # test_func()
    # exit(0)
    if args.use_pipe:
        # Pipe mode: run the scripted request/response loop against the C++ side.
        run_pipe_communication(args.data_format, args.pipe_name)
    elif args.use_grpc:
        # gRPC mode: serve HelloMLBridgeService until externally terminated.
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=20),
            options=[
                # Raise both message-size caps to 200 MB for large tensors.
                ("grpc.max_send_message_length", 200 * 1024 * 1024),
                ("grpc.max_receive_message_length", 200 * 1024 * 1024),
            ],
        )
        helloMLBridge_pb2_grpc.add_HelloMLBridgeServiceServicer_to_server(
            service_server(args.data_format, args.pipe_name), server
        )
        server.add_insecure_port(f"localhost:{args.server_port}")
        server.start()
        print("Server Running")
        server.wait_for_termination()
15 changes: 15 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Builds LLVMMLBridgeTest, the object library backing the MLCompilerBridge
# test executable, linking in any AOT-compiled TF model objects.
find_package(LLVM 10.0.1 REQUIRED CONFIG)

add_library(LLVMMLBridgeTest OBJECT MLBridgeTest.cpp)

# Target-scoped flags instead of appending to CMAKE_CXX_FLAGS (which leaks to
# every target configured after this directory). LLVM is built without RTTI,
# and these objects may be placed into a shared library, so they need PIC.
target_compile_options(LLVMMLBridgeTest PRIVATE -fno-rtti)
set_target_properties(LLVMMLBridgeTest PROPERTIES POSITION_INDEPENDENT_CODE ON)

# LLVM headers as SYSTEM includes, scoped to this target only
# (replaces the directory-wide include_directories call).
target_include_directories(LLVMMLBridgeTest SYSTEM PRIVATE ${LLVM_INCLUDE_DIRS})

# NOTE(review): file(GLOB) will miss model objects added after configure time;
# re-run CMake (or add CONFIGURE_DEPENDS on CMake >= 3.12) when tf_models changes.
file(GLOB MODEL_OBJECTS ${CMAKE_CURRENT_SOURCE_DIR}/tf_models/*.o)
foreach(MODEL_OBJECT ${MODEL_OBJECTS})
  target_link_libraries(LLVMMLBridgeTest PRIVATE ${MODEL_OBJECT})
endforeach()

target_link_libraries(LLVMMLBridgeTest PRIVATE ModelRunnerUtils)
target_include_directories(LLVMMLBridgeTest PRIVATE ${CMAKE_BINARY_DIR}/include ${TENSORFLOW_AOT_PATH}/include)
Loading

0 comments on commit ebf37ec

Please sign in to comment.