Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions csrc/sampler.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"

#include <torch/cuda.h>
Expand Down Expand Up @@ -97,7 +98,9 @@ static inline __device__ bool isPartialMatch(float x, uint32_t pattern) {
template <typename T, typename idxT, typename Func>
__device__ void vectorized_process(size_t thread_rank, size_t num_threads,
const T* in, idxT len, Func f) {
constexpr int WARP_SIZE = 32;
// Use dynamic WARP_SIZE from cuda_compat.h to support both
// Wave64 (MI300X/gfx942) and Wave32 (Strix Halo/gfx1151) architectures
constexpr int kWarpSize = WARP_SIZE;
using WideT = float4;
if constexpr (sizeof(T) >= sizeof(WideT)) {
for (idxT i = thread_rank; i < len; i += num_threads) {
Expand Down Expand Up @@ -132,8 +135,8 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads,
}
}

static_assert(WARP_SIZE >= items_per_scalar);
// and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt
static_assert(kWarpSize >= items_per_scalar);
// and because items_per_scalar > skip_cnt, kWarpSize > skip_cnt
// no need to use loop
if (thread_rank < skip_cnt) {
f(in[thread_rank], thread_rank);
Expand All @@ -142,7 +145,7 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads,
// len_cast * items_per_scalar + items_per_scalar > len - skip_cnt;
// and so
// len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <=
// WARP_SIZE no need to use loop
// kWarpSize no need to use loop
const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank;
if (remain_i < len) {
f(in[remain_i], remain_i);
Expand Down