Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,27 @@

#include "gpu_iface/platform.hpp"

// Define platform-specific full mask for warp/wavefront operations
#if defined(PLATFORM_CUDA_DEVICE)
constexpr uint32_t WARP_FULL_MASK = 0xffffffff; // 32-bit mask for CUDA
#elif defined(PLATFORM_HIP_DEVICE)
constexpr uint64_t WARP_FULL_MASK = 0xffffffffffffffffULL; // 64-bit mask for HIP
#endif

__device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b(uint32_t x) {
uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x1);
uint32_t tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x1);
x = __byte_perm(x, tmp, ((threadIdx.x & 0x1) == 0) ? 0x5410 : 0x3276);
tmp = __shfl_xor_sync(0xffffffff, x, 0x2);
tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x2);
x = __byte_perm(x, tmp, ((threadIdx.x & 0x2) == 0) ? 0x5410 : 0x3276);
return x;
}

__device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t x) {
uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x4);
uint32_t tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x4);
x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x6420 : 0x3175);
tmp = __shfl_xor_sync(0xffffffff, x, 0x8);
tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x8);
x = __byte_perm(x, tmp, ((threadIdx.x & 0x8) == 0) ? 0x5410 : 0x3276);
tmp = __shfl_xor_sync(0xffffffff, x, 0x10);
tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x10);
x = __byte_perm(x, tmp, ((threadIdx.x & 0x10) == 0) ? 0x5410 : 0x3276);
return x;
}
Expand Down