Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMD] Hipify torchaudio_decoder #3843

Merged
merged 1 commit into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/libtorchaudio/cuctc/include/ctc_prefix_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#ifndef __ctc_prefix_decoder_h_
#define __ctc_prefix_decoder_h_

#include <cuda_runtime.h>
#include <cstdint>
#include <tuple>
#include <vector>
Expand Down
18 changes: 0 additions & 18 deletions src/libtorchaudio/cuctc/include/ctc_prefix_decoder_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,6 @@
#ifndef __ctc_prefix_decoder_host_h_
#define __ctc_prefix_decoder_host_h_

#include <cuda_runtime.h>

#define CUDA_CHECK(X) \
do { \
auto result = X; \
if (result != cudaSuccess) { \
const char* p_err_str = cudaGetErrorName(result); \
fprintf( \
stderr, \
"File %s Line %d %s returned %s.\n", \
__FILE__, \
__LINE__, \
#X, \
p_err_str); \
abort(); \
} \
} while (0)

#define CHECK(X, ERROR_INFO) \
do { \
auto result = (X); \
Expand Down
4 changes: 4 additions & 0 deletions src/libtorchaudio/cuctc/src/bitonic_topk/bitonic_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@ constexpr inline __host__ __device__ bool isPo2(IntType num) {
}

inline __device__ int laneId() {
#ifndef USE_ROCM
int id;
asm("mov.s32 %0, %%laneid;" : "=r"(id));
return id;
#else
return __lane_id();
#endif
}
/**
* @brief Shuffle the data inside a warp
Expand Down
3 changes: 2 additions & 1 deletion src/libtorchaudio/cuctc/src/bitonic_topk/pow2_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ namespace cu_ctc {
* @tparam IntType data type (checked only for integers)
*/
template <typename IntType>
constexpr __device__ IntType log2(IntType num, IntType ret = IntType(0)) {
constexpr __host__ __device__ IntType
log2(IntType num, IntType ret = IntType(0)) {
return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret);
}

Expand Down
8 changes: 4 additions & 4 deletions src/libtorchaudio/cuctc/src/bitonic_topk/warpsort_topk.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ class warp_sort_filtered : public warp_sort<Capacity, Ascending, T, IdxT> {

__device__ __forceinline__ void merge_buf_() {
topk::bitonic<kMaxBufLen>(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_);
this->merge_in<kMaxBufLen>(val_buf_, idx_buf_);
this->template merge_in<kMaxBufLen>(val_buf_, idx_buf_);
buf_len_ = 0;
set_k_th_(); // contains warp sync
#pragma unroll
Expand Down Expand Up @@ -385,7 +385,7 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> {
if (buf_len_ == kMaxArrLen) {
topk::bitonic<kMaxArrLen>(!Ascending, kWarpWidth)
.sort(val_buf_, idx_buf_);
this->merge_in<kMaxArrLen>(val_buf_, idx_buf_);
this->template merge_in<kMaxArrLen>(val_buf_, idx_buf_);
#pragma unroll
for (int i = 0; i < kMaxArrLen; i++) {
val_buf_[i] = kDummy;
Expand All @@ -398,7 +398,7 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> {
if (buf_len_ != 0) {
topk::bitonic<kMaxArrLen>(!Ascending, kWarpWidth)
.sort(val_buf_, idx_buf_);
this->merge_in<kMaxArrLen>(val_buf_, idx_buf_);
this->template merge_in<kMaxArrLen>(val_buf_, idx_buf_);
}
}

Expand All @@ -421,7 +421,7 @@ constexpr inline __host__ __device__ IntType ceildiv(IntType a, IntType b) {
return (a + b - 1) / b;
}
template <typename IntType>
constexpr inline __device__ IntType roundUp256(IntType num) {
constexpr inline __host__ __device__ IntType roundUp256(IntType num) {
// return (num + 255) / 256 * 256;
constexpr int MASK = 255;
return (num + MASK) & (~MASK);
Expand Down
4 changes: 2 additions & 2 deletions src/libtorchaudio/cuctc/src/ctc_prefix_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <cuda_runtime.h>

#include "include/ctc_prefix_decoder.h"
#include "include/ctc_prefix_decoder_host.h"
#include "../include/ctc_prefix_decoder.h"
#include "../include/ctc_prefix_decoder_host.h"

#include "device_data_wrap.h"
#include "device_log_prob.cuh"
Expand Down
9 changes: 6 additions & 3 deletions src/libtorchaudio/cuctc/src/ctc_prefix_decoder_kernel_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <float.h>
#include <algorithm>
#include "../include/ctc_prefix_decoder_host.h"
#include "ctc_fast_divmod.cuh"
#include "cub/cub.cuh"
#include "device_data_wrap.h"
#include "device_log_prob.cuh"
#include "include/ctc_prefix_decoder_host.h"

#include "bitonic_topk/warpsort_topk.cuh"

Expand Down Expand Up @@ -630,7 +631,8 @@ int CTC_prob_first_step_V2(
num_of_subwarp, beam));
int smem_size =
block_sort_smem_size + beam * sizeof(float) + beam * sizeof(int);
FirstMatrixFuns[fun_idx]<<<grid, threads_per_block, smem_size, stream>>>(
auto kernel = FirstMatrixFuns[fun_idx];
kernel<<<grid, threads_per_block, smem_size, stream>>>(
(*log_prob_struct),
step,
pprev,
Expand Down Expand Up @@ -766,7 +768,8 @@ int CTC_prob_topK_V2(
int num_of_subwarp = threads_per_block0 / std::min<int>(32, actual_capacity);
int smem_size = cu_ctc::topk::calc_smem_size_for_block_wide<float, int>(
num_of_subwarp, beam);
BitonicTopkFuns[fun_idx]<<<grid, block, smem_size, stream>>>(
auto kernel = BitonicTopkFuns[fun_idx];
kernel<<<grid, block, smem_size, stream>>>(
(*log_prob_struct),
step,
ptable,
Expand Down
19 changes: 18 additions & 1 deletion src/libtorchaudio/cuctc/src/device_data_wrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,26 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <cuda_runtime.h>
#include <iostream>
#include <vector>
#include "include/ctc_prefix_decoder_host.h"
#include "../include/ctc_prefix_decoder_host.h"

#define CUDA_CHECK(X) \
do { \
auto result = X; \
if (result != cudaSuccess) { \
const char* p_err_str = cudaGetErrorName(result); \
fprintf( \
stderr, \
"File %s Line %d %s returned %s.\n", \
__FILE__, \
__LINE__, \
#X, \
p_err_str); \
abort(); \
} \
} while (0)

namespace cu_ctc {
constexpr size_t ALIGN_BYTES = 128;
Expand Down
Loading