Skip to content

Commit

Permalink
Revert "Revert "Enable ROCm RNN-T Loss (pytorch#2485)" (pytorch#3586)"
Browse files Browse the repository at this point in the history
This reverts commit 49d3eec.
  • Loading branch information
pruthvistony committed Nov 4, 2023
1 parent 5784206 commit d3deb2e
Show file tree
Hide file tree
Showing 15 changed files with 157 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "third_party/hipify_torch"]
path = third_party/hipify_torch
url = https://github.com/ROCmSoftwarePlatform/hipify_torch
10 changes: 10 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ if(USE_ROCM)
if(NOT PYTORCH_FOUND_HIP)
set(USE_ROCM OFF)
endif()

# ROCm support relies on enable_language(HIP), which needs CMake >= 3.21.
# On older CMake, fall back to a non-ROCm build, and report it as a proper
# CMake warning so the downgrade stands out in the configure log.
if(CMAKE_VERSION VERSION_LESS 3.21.0)
  message(WARNING "Need at least CMake 3.21.0 to compile ROCm support.")
  set(USE_ROCM OFF)
endif()
endif()

if(USE_CUDA)
Expand All @@ -90,6 +95,11 @@ if(USE_CUDA)
)
endif()

# Turn on CMake's HIP language support, but only for ROCm builds.
if(USE_ROCM)
  enable_language(HIP)
endif()

find_package(Torch REQUIRED)
include(cmake/TorchAudioHelper.cmake)

# https://github.com/pytorch/pytorch/issues/54174
Expand Down
1 change: 1 addition & 0 deletions third_party/hipify_torch
Submodule hipify_torch added at 083ff9
55 changes: 54 additions & 1 deletion torchaudio/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,29 @@
################################################################################
# libtorchaudio
################################################################################

if(USE_ROCM)
  # Make the default ROCm install prefixes searchable so that
  # find_package(HIP) can locate the HIP package.
  list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm)
  find_package(HIP REQUIRED)
  # Report the result variable that find_package(HIP) actually sets.
  message(STATUS "hip found ${HIP_FOUND}")

  # Load Hipify.cmake from the hipify_torch submodule; presumably this is
  # what provides the hipify() command invoked further down in this file.
  list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/third_party/hipify_torch/cmake")
  include(Hipify)

  # Compile and link C++ with the hipcc executable reported by the HIP package.
  set(CMAKE_CXX_COMPILER ${HIP_HIPCC_EXECUTABLE})
  set(CMAKE_CXX_LINKER ${HIP_HIPCC_EXECUTABLE})

  # Keep linked library directories in the install RPATH and add the
  # directory of ROCm's bundled LLVM libraries.
  set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
  list(APPEND CMAKE_INSTALL_RPATH "/opt/rocm/llvm/lib")

  # Pre-seed OpenMP variables before find_package(OpenMP) so detection uses
  # the -fopenmp=libomp flag and the "omp" library.
  # NOTE(review): paths are hard-coded to /opt/rocm; confirm this matches the
  # supported install layouts.
  set(OPENMP_LIBRARIES "/opt/rocm/llvm/lib/")
  set(OpenMP_CXX "${CMAKE_CXX_COMPILER}")
  set(OpenMP_CXX_FLAGS "-fopenmp=libomp")
  set(OpenMP_omp_LIBRARY omp)
  find_package(OpenMP REQUIRED)
endif()


set(
sources
lfilter.cpp
Expand Down Expand Up @@ -39,6 +62,37 @@ if(BUILD_RNNT)
rnnt/gpu/compute.cu
)
endif()

if(USE_ROCM)
  # Run hipify() with the CUDA RNN-T directory as input and rnnt/hip as the
  # output directory; the .hip files listed below are expected to come from it.
  hipify(
    CUDA_SOURCE_DIR ${PROJECT_SOURCE_DIR}/torchaudio/csrc/rnnt/gpu
    HIP_SOURCE_DIR ${PROJECT_SOURCE_DIR}/torchaudio/csrc/rnnt/hip
  )

  # Locate HIP through ROCm's own CMake modules unless HIP_ADD_LIBRARY_FOUND
  # has already been set.
  if(NOT HIP_ADD_LIBRARY_FOUND)
    list(APPEND CMAKE_MODULE_PATH /opt/rocm/hip/cmake)
    find_package(HIP REQUIRED)
  endif()

  # Add the HIP RNN-T sources to the library's source list.
  list(APPEND sources
    rnnt/hip/compute_alphas.hip
    rnnt/hip/compute_betas.hip
    rnnt/hip/compute.hip
  )
endif()
endif()

# For ROCm builds, extend the `additional_libs` list with the HIP host/device
# targets and ROCm's libomp, and add USE_ROCM to `compile_definitions`.
if(USE_ROCM)
  list(APPEND additional_libs
    hip::host
    hip::device
    /opt/rocm/llvm/lib/libomp.so
  )
  list(APPEND compile_definitions USE_ROCM)
endif()

if(BUILD_RIR)
Expand Down Expand Up @@ -87,7 +141,6 @@ endif()
#------------------------------------------------------------------------------#
# END OF CUSTOMIZATION LOGICS
#------------------------------------------------------------------------------#

torchaudio_library(
libtorchaudio
"${sources}"
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/rnnt/gpu/compute.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include <c10/cuda/CUDAStream.h>
#include <torch/types.h>
#ifdef __HIP_PLATFORM_AMD__
#include <torchaudio/csrc/rnnt/hip/gpu_transducer_hip.h>
#else
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/rnnt/gpu/compute_alphas.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include <c10/cuda/CUDAStream.h>
#include <torch/types.h>
#ifdef __HIP_PLATFORM_AMD__
#include <torchaudio/csrc/rnnt/hip/gpu_transducer_hip.h>
#else
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/rnnt/gpu/compute_betas.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
#include <c10/cuda/CUDAStream.h>
#include <torch/types.h>
#ifdef __HIP_PLATFORM_AMD__
#include <torchaudio/csrc/rnnt/hip/gpu_transducer_hip.h>
#else
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
12 changes: 12 additions & 0 deletions torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

#ifdef USE_CUDA

#ifdef __HIP_PLATFORM_AMD__
#include <torchaudio/csrc/rnnt/hip/math_hip.cuh>
#else
#include <torchaudio/csrc/rnnt/gpu/math.cuh>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down Expand Up @@ -39,7 +43,11 @@ __global__ void ReduceMax2D(

CAST_DTYPE shf;
for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
#ifdef __HIP_PLATFORM_AMD__
shf = __shfl_down(val, stride);
#else
shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
#endif
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
if (shf > val) {
val = shf;
Expand Down Expand Up @@ -81,7 +89,11 @@ __global__ void ReduceLogSumExpGivenMax2D(

CAST_DTYPE shf;
for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
#ifdef __HIP_PLATFORM_AMD__
shf = __shfl_down(val, stride);
#else
shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
#endif
if (threadIdx.x < stride && threadIdx.x + stride < dim) {
val = val + shf;
}
Expand Down
22 changes: 22 additions & 0 deletions torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@

#include <cassert>

#ifdef __HIP_PLATFORM_AMD__
#include <torchaudio/csrc/rnnt/hip/kernel_utils.h>
#include <torchaudio/csrc/rnnt/hip/kernels.h>
#include <torchaudio/csrc/rnnt/hip/math_hip.cuh>
#else
#include <torchaudio/csrc/rnnt/gpu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/gpu/kernels.h>
#include <torchaudio/csrc/rnnt/gpu/math.cuh>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down Expand Up @@ -126,7 +132,11 @@ __device__ void ComputeAlphas(

#pragma unroll
for (int i = 1; i < warpSize; i <<= 1) {
#ifdef __HIP_PLATFORM_AMD__
val = __shfl_up(skip_prob, i);
#else
val = __shfl_up_sync(0xffffffff, skip_prob, i);
#endif
if (i <= threadIdx.x) {
skip_prob = skip_prob + val;
}
Expand All @@ -150,7 +160,11 @@ __device__ void ComputeAlphas(
CAST_DTYPE out = val;

for (int i = 1; i < warpSize; ++i) {
#ifdef __HIP_PLATFORM_AMD__
val = __shfl_up(val, 1);
#else
val = __shfl_up_sync(0xffffffff, val, 1);
#endif
if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit);
out = val;
Expand Down Expand Up @@ -225,7 +239,11 @@ __device__ void ComputeBetasCosts(

#pragma unroll
for (int i = 1; i < warpSize; i <<= 1) {
#ifdef __HIP_PLATFORM_AMD__
val = __shfl_up(skip_prob, i);
#else
val = __shfl_up_sync(0xffffffff, skip_prob, i);
#endif
if (i <= threadIdx.x) {
skip_prob = skip_prob + val;
}
Expand All @@ -248,7 +266,11 @@ __device__ void ComputeBetasCosts(
CAST_DTYPE out = val;

for (int i = 1; i < warpSize; ++i) {
#ifdef __HIP_PLATFORM_AMD__
val = __shfl_up(val, 1);
#else
val = __shfl_up_sync(0xffffffff, val, 1);
#endif
if (i == threadIdx.x) {
val = math::lse(val + skip_prob, emit);
out = val;
Expand Down
5 changes: 5 additions & 0 deletions torchaudio/csrc/rnnt/gpu/gpu_transducer.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
#ifdef USE_CUDA

#include <torchaudio/csrc/rnnt/workspace.h>
#ifdef __HIP_PLATFORM_AMD__
#include <torchaudio/csrc/rnnt/hip/gpu_kernel_utils_hip.cuh>
#include <torchaudio/csrc/rnnt/hip/gpu_kernels_hip.cuh>
#else
#include <torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh>
#include <torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/rnnt/gpu/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

#include <cassert>

#ifdef __HIP_PLATFORM_AMD__
#include <torchaudio/csrc/rnnt/hip/math_hip.cuh>
#else
#include <torchaudio/csrc/rnnt/gpu/math.cuh>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
5 changes: 5 additions & 0 deletions torchaudio/csrc/rnnt/gpu/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@

#include <cassert>

#ifdef __HIP_PLATFORM_AMD__
#include <torchaudio/csrc/rnnt/hip/kernel_utils.h>
#include <torchaudio/csrc/rnnt/hip/math_hip.cuh>
#else
#include <torchaudio/csrc/rnnt/gpu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/gpu/math.cuh>
#endif

namespace torchaudio {
namespace rnnt {
Expand Down
8 changes: 8 additions & 0 deletions torchaudio/csrc/rnnt/macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@
#define FORCE_INLINE __forceinline__
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#elif defined(USE_ROCM)
// ROCm/HIP build: defines the sizing constants and function qualifiers used
// by the RNN-T code, and pulls in the HIP half-precision and runtime headers.
// The guard tests defined(USE_ROCM) so that it also works when the macro is
// defined without a value, matching the other USE_ROCM guards in this code.
// NOTE(review): WARP_SIZE is fixed at 32 here; some AMD GPUs execute 64-wide
// wavefronts — confirm the kernels that rely on WARP_SIZE are valid there.
#define WARP_SIZE 32
#define MAX_THREADS_PER_BLOCK 1024
#define REDUCE_THREADS 256
#define HOST_AND_DEVICE __host__ __device__
#define FORCE_INLINE __forceinline__
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#else
#define HOST_AND_DEVICE
#define FORCE_INLINE inline
Expand Down
9 changes: 7 additions & 2 deletions torchaudio/csrc/rnnt/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

// Include the runtime header of the enabled GPU backend and expose its
// stream handle type under the backend-neutral name gpuStream_t.
#if defined(USE_CUDA)
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;
#endif // USE_CUDA
#if defined(USE_ROCM)
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;
#endif // USE_ROCM

#include <torchaudio/csrc/rnnt/macros.h>
#include <torchaudio/csrc/rnnt/types.h>
Expand All @@ -15,9 +20,9 @@ namespace rnnt {
typedef struct Options {
// the device to compute transducer loss.
device_t device_;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
// the stream to launch kernels in when using GPU.
cudaStream_t stream_;
gpuStream_t stream_;
#endif
// The maximum number of threads that can be used.
int numThreads_;
Expand Down
16 changes: 14 additions & 2 deletions torchaudio/csrc/rnnt/workspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,22 @@ class IntWorkspace {
ComputeSizeForBetaCounters(options_) * sizeof(int));
}
#endif // USE_CUDA
#ifdef USE_ROCM
if (data_ != nullptr && options_.device_ == GPU) {
hipMemset(
GetPointerToAlphaCounters(),
0,
ComputeSizeForAlphaCounters(options_) * sizeof(int));
hipMemset(
GetPointerToBetaCounters(),
0,
ComputeSizeForBetaCounters(options_) * sizeof(int));
}
#endif // USE_ROCM
}

static int ComputeSizeForAlphaCounters(const Options& options) { // B * U
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
if (options.device_ == GPU) {
return options.BU();
} else {
Expand All @@ -145,7 +157,7 @@ class IntWorkspace {
#endif // USE_CUDA
}
static int ComputeSizeForBetaCounters(const Options& options) { // B * U
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
if (options.device_ == GPU) {
return options.BU();
} else {
Expand Down

0 comments on commit d3deb2e

Please sign in to comment.