Skip to content

Commit

Permalink
Add CUDA build support and some code refinements (#581)
Browse files Browse the repository at this point in the history
* the cuda kernel first example

* Update test_ortops.cc

* revert some unnecessary changes

* unix-like os build failure

* refactor header files

* fix python dll exporting error.
  • Loading branch information
wenbingl authored Oct 31, 2023
1 parent 2aeca72 commit a0c2625
Show file tree
Hide file tree
Showing 16 changed files with 464 additions and 309 deletions.
35 changes: 28 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ option(OCOS_BUILD_JAVA "Enable building the Java package" OFF)
option(OCOS_BUILD_ANDROID "Enable building the Android package" OFF)
option(OCOS_BUILD_APPLE_FRAMEWORK "Enable building of the MacOS/iOS framework" OFF)

# Opt-in switch (default OFF) for building the CUDA kernels. Other parts of this file test
# it before including the ext_cuda module and before applying CUDA-specific target settings.
option(OCOS_USE_CUDA "Build the CUDA kernels" OFF)

# Optional value. Some operators do not support old versions due to using the new custom operator interface
# and will be disabled if this value is set and the version is incompatible.
set(OCOS_ONNXRUNTIME_VERSION "" CACHE STRING
Expand Down Expand Up @@ -149,11 +151,11 @@ if (MSVC)

if (CMAKE_MSVC_RUNTIME_LIBRARY STREQUAL "MultiThreaded" OR
CMAKE_MSVC_RUNTIME_LIBRARY STREQUAL "MultiThreadedDebug")
set(OCOS_STATIC_MSVC_RUNTIME_LIBRARY ON)
set(_STATIC_MSVC_RUNTIME_LIBRARY ON)
else()
set(OCOS_STATIC_MSVC_RUNTIME_LIBRARY OFF)
set(_STATIC_MSVC_RUNTIME_LIBRARY OFF)
endif()
message(STATUS "OCOS_STATIC_MSVC_RUNTIME_LIBRARY: ${OCOS_STATIC_MSVC_RUNTIME_LIBRARY}")
message(STATUS "_STATIC_MSVC_RUNTIME_LIBRARY: ${_STATIC_MSVC_RUNTIME_LIBRARY}")

endif()

Expand Down Expand Up @@ -273,6 +275,13 @@ macro(standardize_output_folder bin_target)
PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
endmacro()

# Pull in the CUDA toolchain configuration only when the CUDA kernels were requested.
# NOTE(review): include(ext_cuda) is resolved through CMAKE_MODULE_PATH — presumably the
# project's cmake/ directory is already on that path; confirm.
if(OCOS_USE_CUDA)
include(ext_cuda)
endif()

#######################################################################################################################
# build the operator file list from the build flag.

if(OCOS_ENABLE_RE2_REGEX)
if(NOT TARGET re2::re2)
set(RE2_BUILD_TESTING OFF CACHE INTERNAL "")
Expand Down Expand Up @@ -324,6 +333,11 @@ if(OCOS_ENABLE_MATH)
endif()

# Collect the math operator sources (implementation files and headers).
file(GLOB TARGET_SRC_MATH "operators/math/*.cc" "operators/math/*.h*")

# Add the CUDA implementations of the math operators when the CUDA build is enabled.
# The guard has to test the OCOS_USE_CUDA option: a misspelt variable name is simply
# undefined, evaluates to false, and silently drops every CUDA source from the build.
if(OCOS_USE_CUDA)
  file(GLOB TARGET_SRC_MATH_CUDA "operators/math/cuda/*.*")
  list(APPEND TARGET_SRC_MATH ${TARGET_SRC_MATH_CUDA})
endif()

# Fold the math sources, together with the dlib and inverse source lists, into the
# overall operator source list.
list(APPEND TARGET_SRC ${TARGET_SRC_MATH} ${TARGET_SRC_DLIB} ${TARGET_SRC_INVERSE})
endif()

Expand Down Expand Up @@ -415,7 +429,7 @@ if(OCOS_ENABLE_AZURE)
if (APPLE OR
CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR
(MSVC AND CMAKE_GENERATOR_PLATFORM STREQUAL "ARM") OR
(MSVC AND OCOS_STATIC_MSVC_RUNTIME_LIBRARY) OR
(MSVC AND _STATIC_MSVC_RUNTIME_LIBRARY) OR
(ANDROID AND ANDROID_ABI STREQUAL "x86"))
message(STATUS "Excluding Azure custom operators as they are not currently supported in this type of build. ")
set(OCOS_ENABLE_AZURE OFF)
Expand Down Expand Up @@ -491,9 +505,10 @@ if(_HAS_TOKENIZER)
list(APPEND TARGET_SRC ${tokenizer_TARGET_SRC})
endif()

# ### make all compile options.
add_compile_options("$<$<C_COMPILER_ID:MSVC>:/utf-8>")
add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>")
### make all compile options.
# On MSVC, compile every source as UTF-8. C and C++ sources get /utf-8 directly; for CUDA
# sources the same switch is wrapped in --compiler-options (SHELL: keeps the two words
# together as one option group). The options are added at directory scope.
if(MSVC)
add_compile_options("$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>" "$<$<COMPILE_LANGUAGE:CXX,C>:/utf-8>")
endif()

# Library will not be built if the target src for the lib does not contain any valid *.cc files.
# Hence the placeholders `noexcep_operators_placeholder.cc` and `ocos_operators_placeholder.cc`
Expand Down Expand Up @@ -798,6 +813,12 @@ if(_BUILD_SHARED_LIBRARY)

endif()

# CUDA-specific target settings: expose the CUDA toolkit include directories, link the
# cudart, cublas and cufft libraries, and turn on CUDA separable compilation.
# NOTE(review): the three commands act on three different targets (ortcustomops,
# extensions_shared, ocos_operators) — confirm each of them exists in every configuration
# that can reach this block with OCOS_USE_CUDA=ON, and that the split is intentional.
if(OCOS_USE_CUDA)
target_include_directories(ortcustomops PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_link_libraries(extensions_shared PUBLIC cudart cublas cufft)
set_target_properties(ocos_operators PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
endif()

if(OCOS_BUILD_PYTHON)
message(STATUS "Python Build is enabled")
include(ext_python)
Expand Down
71 changes: 71 additions & 0 deletions cmake/ext_cuda.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

# Enable CUDA as a first-class language and set the project-wide CUDA defaults:
# a shared CUDA runtime library and the C++17 dialect for CUDA sources.
enable_language(CUDA)

set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
set(CMAKE_CUDA_STANDARD 17)

# When the user did not request specific architectures, build the list of compute
# capabilities to target and append one "-gencode" flag per entry to CMAKE_CUDA_FLAGS.
# NOTE(review): with CMake policies under which enable_language(CUDA) already initializes
# CMAKE_CUDA_ARCHITECTURES, this branch may never run — confirm against the project's
# minimum CMake version.
if(NOT CMAKE_CUDA_ARCHITECTURES)
  if(CMAKE_LIBRARY_ARCHITECTURE STREQUAL "aarch64-linux-gnu")
    # Support for Jetson/Tegra ARM devices:
    #   53 = TX1, Nano   62 = TX2   72 = AGX Xavier, NX Xavier
    set(_ocos_cuda_sm_list 53 62 72)
    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11)
      list(APPEND _ocos_cuda_sm_list 87) # AGX Orin, NX Orin
    endif()
  else()
    set(_ocos_cuda_sm_list "")
    # The following compute capabilities are removed in the CUDA 11 Toolkit.
    if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11)
      list(APPEND _ocos_cuda_sm_list 30) # K series
    endif()
    # 37 and 50 still work in CUDA 11 but are marked deprecated and will be
    # removed in a future CUDA version.
    if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12)
      list(APPEND _ocos_cuda_sm_list
        37 # K80
        50 # M series
      )
    endif()
    list(APPEND _ocos_cuda_sm_list
      52 # M60
      60 # P series
      70 # V series
      75 # T series
    )
    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11)
      list(APPEND _ocos_cuda_sm_list 80) # A series
    endif()
    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12)
      list(APPEND _ocos_cuda_sm_list 90) # H series
    endif()
  endif()

  # Emit the flags in list order so the resulting command line is stable.
  foreach(_ocos_cuda_sm IN LISTS _ocos_cuda_sm_list)
    string(APPEND CMAKE_CUDA_FLAGS
      " -gencode=arch=compute_${_ocos_cuda_sm},code=sm_${_ocos_cuda_sm}")
  endforeach()
  unset(_ocos_cuda_sm)
  unset(_ocos_cuda_sm_list)
endif()
# Flags added to every CUDA compilation: relaxed constexpr support always, and (from
# CUDA 11 on) the default-stream-launch diagnostic promoted to an error.
string(APPEND CMAKE_CUDA_FLAGS " --expt-relaxed-constexpr")
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11)
  string(APPEND CMAKE_CUDA_FLAGS " --Werror default-stream-launch")
endif()

# Ask the host compiler for position-independent code on non-Windows platforms.
# The flag has to go into CMAKE_CUDA_FLAGS: CUDA_NVCC_FLAGS is only consumed by the legacy
# FindCUDA macros (cuda_add_library and friends) and is ignored for targets built through
# enable_language(CUDA), so appending to it would never reach nvcc.
if(NOT WIN32)
  string(APPEND CMAKE_CUDA_FLAGS " --compiler-options -fPIC")
endif()

# Options passed to cudafe: suppress the listed front-end diagnostics, one
# -Xcudafe "--diag_suppress=<name>" pair per entry, in list order.
foreach(_ocos_cudafe_diag IN ITEMS
    bad_friend_decl
    unsigned_compare_with_zero
    expr_has_no_effect)
  string(APPEND CMAKE_CUDA_FLAGS " -Xcudafe \"--diag_suppress=${_ocos_cudafe_diag}\"")
endforeach()
unset(_ocos_cudafe_diag)

# Define USE_CUDA for sources compiled from this directory scope.
add_compile_definitions(USE_CUDA)
5 changes: 3 additions & 2 deletions cmake/ext_python.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ endblock()
if (NOT Python3_FOUND)
message(FATAL_ERROR "Python3 or NumPy not found!")
endif()

file(GLOB TARGET_SRC_PYOPS "pyop/*.cc" "pyop/*.h" "shared/*.cc")
if (WIN32)
list(APPEND shared_TARGET_SRC "${PROJECT_SOURCE_DIR}/pyop/extensions_pydll.def")
list(APPEND TARGET_SRC_PYOPS "pyop/extensions_pydll.def")
endif()

file(GLOB TARGET_SRC_PYOPS "pyop/*.cc" "pyop/*.h" "shared/*.cc")
add_library(extensions_pydll SHARED ${TARGET_SRC_PYOPS} ${shared_TARGET_LIB_SRC})
standardize_output_folder(extensions_pydll)
list(APPEND OCOS_COMPILE_DEFINITIONS PYTHON_OP_SUPPORT)
Expand Down
51 changes: 0 additions & 51 deletions includes/ocos.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,57 +18,6 @@ extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op);

constexpr const char* c_OpDomain = "ai.onnx.contrib";
constexpr const char* c_ComMsExtOpDomain = "com.microsoft.extensions";

// Common base for custom-op kernels: keeps references to the ORT API table and the
// kernel info it was constructed with, plus a CustomOpApi wrapper built from that API.
struct BaseKernel {
  // The initializer list follows the member declaration order (api_, ort_, info_),
  // which is the order the members are actually initialized in; ort_ is built from
  // the already-initialized api_.
  BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept
      : api_(api), ort_(api_), info_(info) {
  }

  // Tries to read the attribute `name` into `value`. Only declared here; the meaning of
  // the bool result is defined by the out-of-view specializations (presumably true on
  // success — confirm).
  template <class T>
  bool TryToGetAttribute(const char* name, T& value) const noexcept;

  // Returns a copy of `default_value` after giving TryToGetAttribute the chance to
  // overwrite it with the attribute `name`.
  template <class T>
  T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
    T result = default_value;
    TryToGetAttribute(name, result);
    return result;
  }

  // Declared only; defined elsewhere. Presumably writes `data` with shape `dim` to
  // output `output_idx` of `ctx` — confirm against the definition.
  void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
                 const std::vector<int64_t>& data);

 protected:
  // Declared only; defined elsewhere.
  OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;

  const OrtApi& api_;
  OrtW::CustomOpApi ort_;
  const OrtKernelInfo& info_;
};

// A tensor shape stored as a vector of int64_t extents, with a few convenience queries.
struct OrtTensorDimensions : std::vector<int64_t> {
  OrtTensorDimensions() = default;

  // Reads the shape of `value` through `ort` and stores it in this vector; the
  // temporary type-and-shape info object is released before returning.
  OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
    OrtTensorTypeAndShapeInfo* shape_info = ort.GetTensorTypeAndShape(value);
    static_cast<std::vector<int64_t>&>(*this) = ort.GetTensorShape(shape_info);
    ort.ReleaseTensorTypeAndShapeInfo(shape_info);
  }

  // Product of all extents; 1 when there are no dimensions.
  int64_t Size() const {
    int64_t product = 1;
    for (const int64_t extent : *this) {
      product *= extent;
    }
    return product;
  }

  // True when the shape has no dimensions.
  bool IsScalar() const {
    return size() == 0;
  }

  // True when the shape has exactly one dimension.
  bool IsVector() const {
    return size() == 1;
  }
};

template <typename... Args>
class CuopContainer {
public:
Expand Down
Loading

0 comments on commit a0c2625

Please sign in to comment.