From d3deb2e1026b55477c5b972f54e138f7f468ba7a Mon Sep 17 00:00:00 2001 From: Pruthvi Madugundu Date: Fri, 3 Nov 2023 17:09:27 -0700 Subject: [PATCH] Revert "Revert "Enable ROCm RNN-T Loss (#2485)" (#3586)" This reverts commit 49d3eeca9d8fc6709fc3ad42415fbadfb02c0332. --- .gitmodules | 3 + CMakeLists.txt | 10 ++++ third_party/hipify_torch | 1 + torchaudio/csrc/CMakeLists.txt | 55 ++++++++++++++++++- torchaudio/csrc/rnnt/gpu/compute.cu | 4 ++ torchaudio/csrc/rnnt/gpu/compute_alphas.cu | 4 ++ torchaudio/csrc/rnnt/gpu/compute_betas.cu | 4 ++ torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh | 12 ++++ torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh | 22 ++++++++ torchaudio/csrc/rnnt/gpu/gpu_transducer.h | 5 ++ torchaudio/csrc/rnnt/gpu/kernel_utils.h | 4 ++ torchaudio/csrc/rnnt/gpu/kernels.h | 5 ++ torchaudio/csrc/rnnt/macros.h | 8 +++ torchaudio/csrc/rnnt/options.h | 9 ++- torchaudio/csrc/rnnt/workspace.h | 16 +++++- 15 files changed, 157 insertions(+), 5 deletions(-) create mode 160000 third_party/hipify_torch diff --git a/.gitmodules b/.gitmodules index e69de29bb2d..25d307cea8d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "third_party/hipify_torch"] + path = third_party/hipify_torch + url = https://github.com/ROCmSoftwarePlatform/hipify_torch diff --git a/CMakeLists.txt b/CMakeLists.txt index f0195c87b43..d955c5da589 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,6 +76,11 @@ if(USE_ROCM) if(NOT PYTORCH_FOUND_HIP) set(USE_ROCM OFF) endif() + + if(CMAKE_VERSION VERSION_LESS 3.21.0) + message("Need at least CMake 3.21.0 to compile ROCm support.") + set(USE_ROCM OFF) + endif() endif() if(USE_CUDA) @@ -90,6 +95,11 @@ if(USE_CUDA) ) endif() +if(USE_ROCM) + enable_language(HIP) +endif() + +find_package(Torch REQUIRED) include(cmake/TorchAudioHelper.cmake) # https://github.com/pytorch/pytorch/issues/54174 diff --git a/third_party/hipify_torch b/third_party/hipify_torch new file mode 160000 index 00000000000..083ff9b50c7 --- /dev/null +++ b/third_party/hipify_torch @@ -0,0 +1 @@ +Subproject commit 083ff9b50c7ed861f7f6eddd983cdedb72e8b964 diff --git a/torchaudio/csrc/CMakeLists.txt b/torchaudio/csrc/CMakeLists.txt index fc0c549493d..f05534f647b 100644 --- a/torchaudio/csrc/CMakeLists.txt +++ b/torchaudio/csrc/CMakeLists.txt @@ -1,6 +1,29 @@ ################################################################################ # libtorchaudio ################################################################################ + +if(USE_ROCM) + list (APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm) + FIND_PACKAGE(HIP REQUIRED) + MESSAGE(STATUS "hip found ${ROCM_FOUND}") + + list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/third_party/hipify_torch/cmake") + include(Hipify) + + set(CMAKE_CXX_COMPILER ${HIP_HIPCC_EXECUTABLE}) + set(CMAKE_CXX_LINKER ${HIP_HIPCC_EXECUTABLE}) + set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) + list( APPEND CMAKE_INSTALL_RPATH "/opt/rocm/llvm/lib" ) + set(OPENMP_LIBRARIES "/opt/rocm/llvm/lib/") + set(OpenMP_CXX "${CMAKE_CXX_COMPILER}") + set(OpenMP_CXX_FLAGS "-fopenmp=libomp") + #set(OpenMP_CXX_LIB_NAMES "omp") + set(OpenMP_omp_LIBRARY omp) + find_package(OpenMP REQUIRED) + +endif() + + set( sources lfilter.cpp @@ -39,6 +62,37 @@ if(BUILD_RNNT) rnnt/gpu/compute.cu ) endif() + + if (USE_ROCM) + hipify(CUDA_SOURCE_DIR ${PROJECT_SOURCE_DIR}/torchaudio/csrc/rnnt/gpu HIP_SOURCE_DIR ${PROJECT_SOURCE_DIR}/torchaudio/csrc/rnnt/hip) + if ( NOT HIP_ADD_LIBRARY_FOUND ) + list(APPEND CMAKE_MODULE_PATH /opt/rocm/hip/cmake) + find_package(HIP REQUIRED) + endif() + + list( + APPEND + sources + rnnt/hip/compute_alphas.hip + rnnt/hip/compute_betas.hip + rnnt/hip/compute.hip + ) + endif() +endif() + +if(USE_ROCM) + list( + APPEND + additional_libs + hip::host + hip::device + /opt/rocm/llvm/lib/libomp.so + ) + list( + APPEND + compile_definitions + USE_ROCM + ) endif() if(BUILD_RIR) @@ -87,7 +141,6 @@ endif() #------------------------------------------------------------------------------# # END OF CUSTOMIZATION LOGICS #------------------------------------------------------------------------------# - torchaudio_library( libtorchaudio "${sources}" diff --git a/torchaudio/csrc/rnnt/gpu/compute.cu b/torchaudio/csrc/rnnt/gpu/compute.cu index a6d389bf0bc..33030538220 100644 --- a/torchaudio/csrc/rnnt/gpu/compute.cu +++ b/torchaudio/csrc/rnnt/gpu/compute.cu @@ -1,6 +1,10 @@ #include #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#else #include +#endif namespace torchaudio { namespace rnnt { diff --git a/torchaudio/csrc/rnnt/gpu/compute_alphas.cu b/torchaudio/csrc/rnnt/gpu/compute_alphas.cu index 918d442bf04..22706f670db 100644 --- a/torchaudio/csrc/rnnt/gpu/compute_alphas.cu +++ b/torchaudio/csrc/rnnt/gpu/compute_alphas.cu @@ -1,6 +1,10 @@ #include #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#else #include +#endif namespace torchaudio { namespace rnnt { diff --git a/torchaudio/csrc/rnnt/gpu/compute_betas.cu b/torchaudio/csrc/rnnt/gpu/compute_betas.cu index e1e4c1d90eb..d2a6134181d 100644 --- a/torchaudio/csrc/rnnt/gpu/compute_betas.cu +++ b/torchaudio/csrc/rnnt/gpu/compute_betas.cu @@ -1,6 +1,10 @@ #include #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#else #include +#endif namespace torchaudio { namespace rnnt { diff --git a/torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh b/torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh index e5f1cfc2df3..cb3c6157704 100644 --- a/torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh +++ b/torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh @@ -2,7 +2,11 @@ #ifdef USE_CUDA +#ifdef __HIP_PLATFORM_AMD__ +#include +#else #include +#endif namespace torchaudio { namespace rnnt { @@ -39,7 +43,11 @@ __global__ void ReduceMax2D( CAST_DTYPE shf; for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) { +#ifdef __HIP_PLATFORM_AMD__ + shf = __shfl_down(val, stride); +#else shf = __shfl_down_sync(0xFFFFFFFF, val, stride); +#endif if (threadIdx.x < stride && threadIdx.x + stride < dim) { if (shf > val) { val = shf; @@ -81,7 +89,11 @@ __global__ void ReduceLogSumExpGivenMax2D( CAST_DTYPE shf; for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) { +#ifdef __HIP_PLATFORM_AMD__ + shf = __shfl_down(val, stride); +#else shf = __shfl_down_sync(0xFFFFFFFF, val, stride); +#endif if (threadIdx.x < stride && threadIdx.x + stride < dim) { val = val + shf; } diff --git a/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh b/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh index 4ba04b68fca..2b7ef45df3b 100644 --- a/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh +++ b/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh @@ -4,9 +4,15 @@ #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#include +#include +#else #include #include #include +#endif namespace torchaudio { namespace rnnt { @@ -126,7 +132,11 @@ __device__ void ComputeAlphas( #pragma unroll for (int i = 1; i < warpSize; i <<= 1) { +#ifdef __HIP_PLATFORM_AMD__ + val = __shfl_up(skip_prob, i); +#else val = __shfl_up_sync(0xffffffff, skip_prob, i); +#endif if (i <= threadIdx.x) { skip_prob = skip_prob + val; } @@ -150,7 +160,11 @@ __device__ void ComputeAlphas( CAST_DTYPE out = val; for (int i = 1; i < warpSize; ++i) { +#ifdef __HIP_PLATFORM_AMD__ + val = __shfl_up(val, 1); +#else val = __shfl_up_sync(0xffffffff, val, 1); +#endif if (i == threadIdx.x) { val = math::lse(val + skip_prob, emit); out = val; @@ -225,7 +239,11 @@ __device__ void ComputeBetasCosts( #pragma unroll for (int i = 1; i < warpSize; i <<= 1) { +#ifdef __HIP_PLATFORM_AMD__ + val = __shfl_up(skip_prob, i); +#else val = __shfl_up_sync(0xffffffff, skip_prob, i); +#endif if (i <= threadIdx.x) { skip_prob = skip_prob + val; } @@ -248,7 +266,11 @@ __device__ void ComputeBetasCosts( CAST_DTYPE out = val; for (int i = 1; i < warpSize; ++i) { +#ifdef __HIP_PLATFORM_AMD__ + val = __shfl_up(val, 1); +#else val = __shfl_up_sync(0xffffffff, val, 1); +#endif if (i == threadIdx.x) { val = math::lse(val + skip_prob, emit); out = val; diff --git a/torchaudio/csrc/rnnt/gpu/gpu_transducer.h b/torchaudio/csrc/rnnt/gpu/gpu_transducer.h index 72759b39f41..32a731bafd7 100644 --- a/torchaudio/csrc/rnnt/gpu/gpu_transducer.h +++ b/torchaudio/csrc/rnnt/gpu/gpu_transducer.h @@ -3,8 +3,13 @@ #ifdef USE_CUDA #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#include +#else #include #include +#endif namespace torchaudio { namespace rnnt { diff --git a/torchaudio/csrc/rnnt/gpu/kernel_utils.h b/torchaudio/csrc/rnnt/gpu/kernel_utils.h index 3b2989b0737..68136fcfa3b 100644 --- a/torchaudio/csrc/rnnt/gpu/kernel_utils.h +++ b/torchaudio/csrc/rnnt/gpu/kernel_utils.h @@ -2,7 +2,11 @@ #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#else #include +#endif namespace torchaudio { namespace rnnt { diff --git a/torchaudio/csrc/rnnt/gpu/kernels.h b/torchaudio/csrc/rnnt/gpu/kernels.h index db8bb5092b5..d22443fecb5 100644 --- a/torchaudio/csrc/rnnt/gpu/kernels.h +++ b/torchaudio/csrc/rnnt/gpu/kernels.h @@ -2,8 +2,13 @@ #include +#ifdef __HIP_PLATFORM_AMD__ +#include +#include +#else #include #include +#endif namespace torchaudio { namespace rnnt { diff --git a/torchaudio/csrc/rnnt/macros.h b/torchaudio/csrc/rnnt/macros.h index abcbc399664..e569d262413 100644 --- a/torchaudio/csrc/rnnt/macros.h +++ b/torchaudio/csrc/rnnt/macros.h @@ -8,6 +8,14 @@ #define FORCE_INLINE __forceinline__ #include #include +#elif USE_ROCM +#define WARP_SIZE 32 +#define MAX_THREADS_PER_BLOCK 1024 +#define REDUCE_THREADS 256 +#define HOST_AND_DEVICE __host__ __device__ +#define FORCE_INLINE __forceinline__ +#include +#include #else #define HOST_AND_DEVICE #define FORCE_INLINE inline diff --git a/torchaudio/csrc/rnnt/options.h b/torchaudio/csrc/rnnt/options.h index f70a3c8c07b..ecf0714a3c3 100644 --- a/torchaudio/csrc/rnnt/options.h +++ b/torchaudio/csrc/rnnt/options.h @@ -4,7 +4,12 @@ #ifdef USE_CUDA #include +typedef cudaStream_t gpuStream_t; #endif // USE_CUDA +#ifdef USE_ROCM +#include +typedef hipStream_t gpuStream_t; +#endif // USE_ROCM #include #include @@ -15,9 +20,9 @@ namespace rnnt { typedef struct Options { // the device to compute transducer loss. device_t device_; -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) // the stream to launch kernels in when using GPU. - cudaStream_t stream_; + gpuStream_t stream_; #endif // The maximum number of threads that can be used. int numThreads_; diff --git a/torchaudio/csrc/rnnt/workspace.h b/torchaudio/csrc/rnnt/workspace.h index e833ef1cdff..14ae0047ba0 100644 --- a/torchaudio/csrc/rnnt/workspace.h +++ b/torchaudio/csrc/rnnt/workspace.h @@ -131,10 +131,22 @@ class IntWorkspace { ComputeSizeForBetaCounters(options_) * sizeof(int)); } #endif // USE_CUDA +#ifdef USE_ROCM + if (data_ != nullptr && options_.device_ == GPU) { + hipMemset( + GetPointerToAlphaCounters(), + 0, + ComputeSizeForAlphaCounters(options_) * sizeof(int)); + hipMemset( + GetPointerToBetaCounters(), + 0, + ComputeSizeForBetaCounters(options_) * sizeof(int)); + } +#endif // USE_ROCM } static int ComputeSizeForAlphaCounters(const Options& options) { // B * U -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) if (options.device_ == GPU) { return options.BU(); } else { @@ -145,7 +157,7 @@ class IntWorkspace { #endif // USE_CUDA } static int ComputeSizeForBetaCounters(const Options& options) { // B * U -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) if (options.device_ == GPU) { return options.BU(); } else {