diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh
index 5e2f51f933c6..4e48ed5f77fb 100644
--- a/csrc/cuda_vec_utils.cuh
+++ b/csrc/cuda_vec_utils.cuh
@@ -17,14 +17,30 @@
 // Device-side: SM100+ architecture with CUDA 12.9+ toolkit, which
 // together enable 256-bit (v8.u32) PTX load/store instructions.
-// Use for PTX instruction selection with architecture fallback paths.
+// ROCm gfx950+ (CDNA4/MI350X/MI355X): 256-bit logical width via 2× dwordx4.
+// Use for PTX/ISA instruction selection with architecture fallback paths.
 #if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \
     defined(CUDA_VERSION) && CUDA_VERSION >= 12090
   #define VLLM_256B_PTX_ENABLED 1
+// gfx950 (CDNA4): 256-bit path disabled for now. The gfx950 ISA (ROCm 7.2,
+// LLVM 22.0) lacks vector dwordx8 load/store instructions, so the compiler
+// decomposes 256-bit ops into 2× dwordx4 issued in separate cycles, leading
+// to uncoalesced memory accesses within a warp. Re-enable once a future
+// ROCm version adds native global_load/store_dwordx8 support.
+// See: https://github.com/vllm-project/vllm/pull/36743#discussion_r2048743213
+// #elif defined(USE_ROCM) && defined(__gfx950__)
+//   #define VLLM_256B_PTX_ENABLED 1
 #else
   #define VLLM_256B_PTX_ENABLED 0
 #endif
 
+// ROCm gfx942+ (CDNA3/MI300X): non-temporal hints for cache bypass.
+#if defined(USE_ROCM) && (defined(__gfx942__) || defined(__gfx950__))
+  #define VLLM_ROCM_USE_NT_HINTS 1
+#else
+  #define VLLM_ROCM_USE_NT_HINTS 0
+#endif
+
 namespace vllm {
 
 // ============================================================
@@ -133,28 +149,59 @@ struct alignas(VecTraits<T>::ARCH_MAX_VEC_SIZE) PackedVec {
 // Load / store primitives
 // ============================================================
 
-// 256-bit load / store — SM100+ only (PTX v8 instructions).
+// 256-bit load / store — SM100+ (PTX v8) or ROCm gfx950+ (2× dwordx4).
 __device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) {
 #if VLLM_256B_PTX_ENABLED
+  #if defined(USE_ROCM) && defined(__gfx950__)
+  // gfx950 (CDNA4): 256-bit logical load. The gfx950 ISA does not have
+  // global_load_dwordx8 — the compiler emits 2× global_load_dwordx4 with
+  // adjacent offsets (off + off:16). This still halves loop iterations vs
+  // the 128-bit path. Verified on MI350X, ROCm 7.2 / LLVM 22.0.
+  // If a future ROCm adds vector dwordx8, this code benefits automatically.
+  const uint32_t* src = reinterpret_cast<const uint32_t*>(ptr);
+  val.d[0] = src[0];
+  val.d[1] = src[1];
+  val.d[2] = src[2];
+  val.d[3] = src[3];
+  val.d[4] = src[4];
+  val.d[5] = src[5];
+  val.d[6] = src[6];
+  val.d[7] = src[7];
+  #else
   asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n"
                : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]),
                  "=r"(val.d[3]), "=r"(val.d[4]), "=r"(val.d[5]),
                  "=r"(val.d[6]), "=r"(val.d[7])
                : "l"(ptr));
+  #endif
 #else
-  assert(false && "ld256 requires SM100+ with CUDA 12.9+");
+  assert(false && "ld256 requires SM100+ with CUDA 12.9+ or ROCm gfx950+");
 #endif
 }
 
 __device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) {
 #if VLLM_256B_PTX_ENABLED
+  #if defined(USE_ROCM) && defined(__gfx950__)
+  // gfx950 (CDNA4): 256-bit logical store → 2× global_store_dwordx4.
+  // See ld256 comment for ISA details (verified on MI350X, ROCm 7.2).
+  uint32_t* dst = reinterpret_cast<uint32_t*>(ptr);
+  dst[0] = val.d[0];
+  dst[1] = val.d[1];
+  dst[2] = val.d[2];
+  dst[3] = val.d[3];
+  dst[4] = val.d[4];
+  dst[5] = val.d[5];
+  dst[6] = val.d[6];
+  dst[7] = val.d[7];
+  #else
   asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n"
                :
                : "l"(ptr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]),
                  "r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]),
                  "r"(val.d[6]), "r"(val.d[7])
                : "memory");
+  #endif
 #else
-  assert(false && "st256 requires SM100+ with CUDA 12.9+");
+  assert(false && "st256 requires SM100+ with CUDA 12.9+ or ROCm gfx950+");
 #endif
 }
 
@@ -185,29 +232,58 @@ __device__ __forceinline__ void st128(T& val, T* ptr) {
   *reinterpret_cast<int4*>(ptr) = *reinterpret_cast<int4*>(&val);
 }
 
-// 256-bit cache-streaming (.cs) load / store — SM100+ only.
+// 256-bit cache-streaming (.cs) load / store.
+// SM100+: PTX .cs hint. ROCm gfx950+: slc (system-level coherent) hint.
 __forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) {
 #if VLLM_256B_PTX_ENABLED
   u32x8_t val;
+  #if defined(USE_ROCM) && defined(__gfx950__)
+  // gfx950 (CDNA4): 256-bit non-temporal load for streaming.
+  // The compiler coalesces sequential __builtin_nontemporal_load calls
+  // into vectorized global_load instructions at -O3 (verified on gfx942
+  // where 4× scalar NT → single global_load_dwordx4 nt).
+  // On gfx950: emits 2× global_load_dwordx4 nt (no vector dwordx8 in ISA).
+  // Verified on MI350X, ROCm 7.2 / LLVM 22.0.
+  // TODO: Check if future ROCm versions add a vector dwordx8 instruction to the
+  // gfx950 ISA.
+  const uint32_t* src = reinterpret_cast<const uint32_t*>(addr);
+  for (int i = 0; i < 8; i++) {
+    val.d[i] =
+        __builtin_nontemporal_load(reinterpret_cast<const uint32_t*>(src) + i);
+  }
+  #else
   asm volatile("ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];"
                : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]),
                  "=r"(val.d[3]), "=r"(val.d[4]), "=r"(val.d[5]),
                  "=r"(val.d[6]), "=r"(val.d[7])
                : "l"(addr));
+  #endif
   return val;
 #else
-  assert(false && "ld256_cs requires SM100+ with CUDA 12.9+");
+  assert(false && "ld256_cs requires SM100+ with CUDA 12.9+ or ROCm gfx950+");
   return u32x8_t{};
 #endif
 }
 
 __forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) {
 #if VLLM_256B_PTX_ENABLED
+  #if defined(USE_ROCM) && defined(__gfx950__)
+  // gfx950 (CDNA4): 256-bit non-temporal store for streaming write.
+  // Emits 2× global_store_dwordx4 nt (no vector dwordx8 in gfx950 ISA).
+  // Verified on MI350X, ROCm 7.2 / LLVM 22.0.
+  // TODO: Check if future ROCm versions add a vector dwordx8 instruction to the
+  // gfx950 ISA.
+  int* dst = reinterpret_cast<int*>(addr);
+  for (int i = 0; i < 8; i++) {
+    __builtin_nontemporal_store(static_cast<int>(val.d[i]), dst + i);
+  }
+  #else
   asm volatile(
       "st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" ::"l"(addr),
       "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]),
       "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7]));
+  #endif
 #else
-  assert(false && "st256_cs requires SM100+ with CUDA 12.9+");
+  assert(false && "st256_cs requires SM100+ with CUDA 12.9+ or ROCm gfx950+");
 #endif
 }
 
@@ -217,7 +293,8 @@ __device__ __forceinline__ int ld32(const int* addr) { return __ldg(addr); }
 __device__ __forceinline__ void st32(int* addr, int val) { *addr = val; }
 
 // 32-bit cache-streaming (.cs) load / store.
-// Falls back to ld32/st32 on ROCm (no .cs hint).
+// ROCm gfx942+: uses __builtin_nontemporal_store for write-side bypass.
+// Falls back to ld32/st32 on other ROCm targets.
 __forceinline__ __device__ int ld32_cs(const int* addr) {
   int val;
 #ifndef USE_ROCM
@@ -231,13 +308,18 @@ __forceinline__ __device__ int ld32_cs(const int* addr) {
 __forceinline__ __device__ void st32_cs(int* addr, int val) {
 #ifndef USE_ROCM
   asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val));
+#elif VLLM_ROCM_USE_NT_HINTS
+  __builtin_nontemporal_store(val, addr);
 #else
   st32(addr, val);
 #endif
 }
 
 // 128-bit cache-streaming (.cs) load / store.
-// Falls back to ld128/st128 on ROCm (no .cs hint).
+// ROCm gfx942+: uses __builtin_nontemporal_store for write-side cache
+// bypass. Reads use normal loads to benefit from L2 cache on small/medium
+// token counts (NT load hurts latency for batches < 2048).
+// Falls back to ld128/st128 on other ROCm targets.
 __forceinline__ __device__ int4 ld128_cs(const int4* addr) {
   int4 val;
 #ifndef USE_ROCM
@@ -254,6 +336,12 @@ __forceinline__ __device__ void st128_cs(int4* addr, int4 val) {
 #ifndef USE_ROCM
   asm volatile("st.global.cs.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(addr),
                "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
+#elif VLLM_ROCM_USE_NT_HINTS
+  int* dst = reinterpret_cast<int*>(addr);
+  __builtin_nontemporal_store(val.x, dst);
+  __builtin_nontemporal_store(val.y, dst + 1);
+  __builtin_nontemporal_store(val.z, dst + 2);
+  __builtin_nontemporal_store(val.w, dst + 3);
 #else
   st128(val, addr);
 #endif