diff --git a/libflashinfer/include/flashinfer/attention/generic/frag_layout_swizzle.cuh b/libflashinfer/include/flashinfer/attention/generic/frag_layout_swizzle.cuh
index 0b74f89550..03bc2a3cb1 100644
--- a/libflashinfer/include/flashinfer/attention/generic/frag_layout_swizzle.cuh
+++ b/libflashinfer/include/flashinfer/attention/generic/frag_layout_swizzle.cuh
@@ -10,20 +10,42 @@
 
 #include "gpu_iface/platform.hpp"
 
+// Full participation mask handed to __shfl_xor_sync below. The mask width is
+// platform-specific: 32 bits on CUDA, 64 bits on HIP.
+#if defined(PLATFORM_CUDA_DEVICE)
+constexpr uint32_t WARP_FULL_MASK = 0xffffffff; // 32-bit mask for CUDA
+#elif defined(PLATFORM_HIP_DEVICE)
+constexpr uint64_t WARP_FULL_MASK = 0xffffffffffffffffULL; // 64-bit mask for HIP
+#else
+// Fail with one clear message instead of an "undeclared identifier" error at
+// every use of WARP_FULL_MASK further down.
+#error "frag_layout_swizzle.cuh: neither PLATFORM_CUDA_DEVICE nor PLATFORM_HIP_DEVICE is defined"
+#endif
+
+// Two butterfly steps (lane XOR 0x1, then 0x2): each lane trades its 32-bit
+// word with the partner lane and __byte_perm merges the two words. A lane whose
+// tested threadIdx.x bit is clear uses selector 0x5410 (low 16-bit halves of
+// both words); otherwise it uses 0x3276 (high 16-bit halves of both words).
+// NOTE(review): threadIdx.x is used directly as the lane index - confirm the
+// callers' block layout makes its low bits equal to the lane id.
 __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b(uint32_t x) {
-  uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x1);
+  uint32_t tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x1);
   x = __byte_perm(x, tmp, ((threadIdx.x & 0x1) == 0) ? 0x5410 : 0x3276);
-  tmp = __shfl_xor_sync(0xffffffff, x, 0x2);
+  tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x2);
   x = __byte_perm(x, tmp, ((threadIdx.x & 0x2) == 0) ? 0x5410 : 0x3276);
   return x;
 }
 
+// Three butterfly steps (lane XOR 0x4, 0x8, 0x10) built the same way. The first
+// step uses byte-granular selectors (0x6420 / 0x3175: even-indexed or
+// odd-indexed bytes of both words); the remaining two use the 16-bit selectors
+// 0x5410 / 0x3276.
 __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t x) {
-  uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x4);
+  uint32_t tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x4);
   x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x6420 : 0x3175);
-  tmp = __shfl_xor_sync(0xffffffff, x, 0x8);
+  tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x8);
   x = __byte_perm(x, tmp, ((threadIdx.x & 0x8) == 0) ? 0x5410 : 0x3276);
-  tmp = __shfl_xor_sync(0xffffffff, x, 0x10);
+  tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x10);
   x = __byte_perm(x, tmp, ((threadIdx.x & 0x10) == 0) ? 0x5410 : 0x3276);
   return x;
 }