Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 97 additions & 9 deletions csrc/cuda_vec_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,30 @@

// Device-side: SM100+ architecture with CUDA 12.9+ toolkit, which
// together enable 256-bit (v8.u32) PTX load/store instructions.
// Use for PTX instruction selection with architecture fallback paths.
// ROCm gfx950+ (CDNA4/MI350X/MI355X): 256-bit logical width via 2× dwordx4.
// Use for PTX/ISA instruction selection with architecture fallback paths.
#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \
defined(CUDA_VERSION) && CUDA_VERSION >= 12090
#define VLLM_256B_PTX_ENABLED 1
// gfx950 (CDNA4): 256-bit path disabled for now. The gfx950 ISA (ROCm 7.2,
// LLVM 22.0) lacks vector dwordx8 load/store instructions, so the compiler
// decomposes 256-bit ops into 2× dwordx4 issued in separate cycles, leading
// to uncoalesced memory accesses within a warp. Re-enable once a future
// ROCm version adds native global_load/store_dwordx8 support.
// See: https://github.com/vllm-project/vllm/pull/36743#discussion_r2048743213
// #elif defined(USE_ROCM) && defined(__gfx950__)
// #define VLLM_256B_PTX_ENABLED 1
#else
#define VLLM_256B_PTX_ENABLED 0
#endif

// ROCm gfx942+ (CDNA3/MI300X): non-temporal hints for cache bypass.
#if defined(USE_ROCM) && (defined(__gfx942__) || defined(__gfx950__))
#define VLLM_ROCM_USE_NT_HINTS 1
#else
#define VLLM_ROCM_USE_NT_HINTS 0
#endif

namespace vllm {

// ============================================================
Expand Down Expand Up @@ -133,28 +149,59 @@ struct alignas(VecTraits<use_256b>::ARCH_MAX_VEC_SIZE) PackedVec {
// Load / store primitives
// ============================================================

// 256-bit load / store — SM100+ only (PTX v8 instructions).
// 256-bit load / store — SM100+ (PTX v8) or ROCm gfx950+ (2× dwordx4).
__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) {
#if VLLM_256B_PTX_ENABLED
#if defined(USE_ROCM) && defined(__gfx950__)
// gfx950 (CDNA4): 256-bit logical load. The gfx950 ISA does not have
// global_load_dwordx8 — the compiler emits 2× global_load_dwordx4 with
// adjacent offsets (off + off:16). This still halves loop iterations vs
// the 128-bit path. Verified on MI350X, ROCm 7.2 / LLVM 22.0.
// If a future ROCm adds vector dwordx8, this code benefits automatically.
const uint32_t* src = reinterpret_cast<const uint32_t*>(ptr);
val.d[0] = src[0];
val.d[1] = src[1];
val.d[2] = src[2];
val.d[3] = src[3];
val.d[4] = src[4];
val.d[5] = src[5];
val.d[6] = src[6];
val.d[7] = src[7];
#else
asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n"
: "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]),
"=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7])
: "l"(ptr));
#endif
#else
assert(false && "ld256 requires SM100+ with CUDA 12.9+");
assert(false && "ld256 requires SM100+ with CUDA 12.9+ or ROCm gfx950+");
#endif
}

__device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) {
#if VLLM_256B_PTX_ENABLED
#if defined(USE_ROCM) && defined(__gfx950__)
// gfx950 (CDNA4): 256-bit logical store → 2× global_store_dwordx4.
// See ld256 comment for ISA details (verified on MI350X, ROCm 7.2).
uint32_t* dst = reinterpret_cast<uint32_t*>(ptr);
dst[0] = val.d[0];
dst[1] = val.d[1];
dst[2] = val.d[2];
dst[3] = val.d[3];
dst[4] = val.d[4];
dst[5] = val.d[5];
dst[6] = val.d[6];
dst[7] = val.d[7];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it’s also worth leaving a TODO to double-check, once the MI355X hardware is available, that the compiler automatically translates this into the corresponding global_store_dwordx8 instruction, and not two sequential global_store_dwordx4?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good suggestion — I'll add a TODO comment in the code.

To verify on MI355X once hardware is available:

hipcc -O3 --offload-arch=gfx950 --save-temps concat_mla_q.cu
# Check the .s file for:
#   global_store_dwordx8  (single 256-bit NT store)
# vs:
#   global_store_dwordx4 × 2  (two 128-bit NT stores)

If the compiler doesn't emit dwordx8, we can fall back to inline ASM with global_store_dwordx8 directly. Pushing a commit with the TODO now.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update: verified on MI350X (gfx950, ROCm 7.2)!

The compiler generates global_store_dwordx4, not a single global_store_dwordx8. The gfx950 ISA in ROCm 7.2 / LLVM 22.0 does not include vector dwordx8 instructions at all (llvm-mc rejects them as invalid).

The 256-bit path still halves loop iterations, so the optimization is valid — just not via the single-instruction path we hoped for. Will add a TODO comment noting this for future ROCm versions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for looking into this! Would it be better to disable the 256b memory instructions for now, then? Maybe with an assert "Not supported yet"?

Two consecutive 128b operations issued in different cycles would lead to uncoalesced memory accesses within a warp.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call — you're right that two consecutive 128b ops in different cycles would lead to uncoalesced accesses. I've disabled the gfx950 256-bit path for now (commented out with explanation).

The gfx942 (MI300X) non-temporal hints path remains unchanged — that one is hardware-verified and the compiler auto-coalesces into single global_store_dwordx4 nt instructions.

Also rebased on upstream/main to resolve the merge conflict.

When a future ROCm version adds native global_load/store_dwordx8, we can re-enable the 256-bit path with a simple uncomment + re-verify.

#else
asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n"
:
: "l"(ptr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]),
"r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]),
"r"(val.d[7])
: "memory");
#endif
#else
assert(false && "st256 requires SM100+ with CUDA 12.9+");
assert(false && "st256 requires SM100+ with CUDA 12.9+ or ROCm gfx950+");
#endif
}

Expand Down Expand Up @@ -185,29 +232,58 @@ __device__ __forceinline__ void st128(T& val, T* ptr) {
*reinterpret_cast<int4*>(ptr) = *reinterpret_cast<int4*>(&val);
}

// 256-bit cache-streaming (.cs) load / store — SM100+ only.
// 256-bit cache-streaming (.cs) load / store.
// SM100+: PTX .cs hint. ROCm gfx950+: slc (system-level coherent) hint.
__forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) {
#if VLLM_256B_PTX_ENABLED
u32x8_t val;
#if defined(USE_ROCM) && defined(__gfx950__)
// gfx950 (CDNA4): 256-bit non-temporal load for streaming.
// The compiler coalesces sequential __builtin_nontemporal_load calls
// into vectorized global_load instructions at -O3 (verified on gfx942
// where 4× scalar NT → single global_load_dwordx4 nt).
// On gfx950: emits 2× global_load_dwordx4 nt (no vector dwordx8 in ISA).
// Verified on MI350X, ROCm 7.2 / LLVM 22.0.
// TODO: Check if future ROCm versions add a vector dwordx8 instruction to the
// gfx950 ISA.
const uint32_t* src = reinterpret_cast<const uint32_t*>(addr);
for (int i = 0; i < 8; i++) {
val.d[i] =
__builtin_nontemporal_load(reinterpret_cast<const int*>(src) + i);
}
Comment on lines +250 to +253
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The comment here states the goal is to use global_load_dwordx8, but the implementation uses a loop of 8 scalar __builtin_nontemporal_load calls. This will likely prevent the compiler from generating a single 256-bit vector load instruction, resulting in 8 separate 32-bit memory transactions instead of one. This de-vectorization will likely cause a significant performance degradation on gfx950, negating the benefit of the wider vector path.

Since this path is not yet hardware-tested, it's critical to ensure it's implemented efficiently. A true 256-bit non-temporal load, likely requiring inline GCN assembly to emit a global_load_dwordx8 with the slc or nt flag, should be used to achieve the expected performance.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point — the gfx950 path is not yet hardware-tested. However, based on ISA verification on gfx942, the compiler does coalesce sequential __builtin_nontemporal_load/store calls into vectorized instructions at -O3.

For the gfx950 256-bit path specifically: we expect the compiler to coalesce 8 × __builtin_nontemporal_load(int) into global_load_dwordx4 pairs (or ideally global_load_dwordx8 if the compiler recognizes the pattern). Since we cannot verify the gfx950 ISA output without MI355X hardware or a ROCm toolchain with gfx950 support, I will add a TODO comment noting that this should be verified when gfx950 hardware becomes available.

If the compiler does not coalesce to dwordx8 on gfx950, we can switch to inline ASM at that point. The __builtin_nontemporal approach was chosen because inline ASM with vector register constraints (v4i32/v8i32) causes invalid operand for instruction errors in hipcc when used across multiple template instantiations.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we expect the compiler to coalesce 8 × __builtin_nontemporal_load(int) into global_load_dwordx4 pairs

Just out of curiosity: any reason why this is not automatically translated to one global_load_dwordx8 if the hardware supports it?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great question! The reason is hardware ISA generation differences:

  • gfx942 (CDNA3 / MI300X): The maximum vector load width is 128-bit (global_load_dwordx4). There is no global_load_dwordx8 instruction in the gfx942 ISA. So 8× scalar loads → 2× global_load_dwordx4 is the best the compiler can do.

  • gfx950 (CDNA4 / MI355X): The ISA does include global_load_dwordx8 (256-bit vector load). We expect the compiler to coalesce 8× scalar loads into a single global_load_dwordx8 on this target — but this needs hardware verification (added TODO).

This 128-bit vs 256-bit vector width is also the main reason MI300X peaks at ~4.5x speedup vs B300's 11.8x for this kernel. MI355X should close that gap significantly.

Reference: AMD CDNA4 ISA (gfx950) — see GLOBAL_LOAD_DWORDX8 instruction.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now we have the answer from hardware testing! 🔬

The reason is simple: global_load/store_dwordx8 does not exist in the gfx950 ISA (as of ROCm 7.2 / AMD LLVM 22.0). The maximum vector load width is still 128-bit (dwordx4), same as gfx942.

Only s_buffer_load_dwordx8 (scalar buffer load) exists — there are no 256-bit vector load/store instructions.

So the compiler is doing the best it can: decomposing 256-bit into 2× dwordx4 with adjacent offsets (off and off offset:16). It's possible a future ROCm release could add vector dwordx8 support, but for now this is a hardware/ISA limitation, not a compiler limitation.

#else
asm volatile("ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];"
: "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]),
"=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7])
: "l"(addr));
#endif
return val;
#else
assert(false && "ld256_cs requires SM100+ with CUDA 12.9+");
assert(false && "ld256_cs requires SM100+ with CUDA 12.9+ or ROCm gfx950+");
return u32x8_t{};
#endif
}

__forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) {
#if VLLM_256B_PTX_ENABLED
#if defined(USE_ROCM) && defined(__gfx950__)
// gfx950 (CDNA4): 256-bit non-temporal store for streaming write.
// Emits 2× global_store_dwordx4 nt (no vector dwordx8 in gfx950 ISA).
// Verified on MI350X, ROCm 7.2 / LLVM 22.0.
// TODO: Check if future ROCm versions add a vector dwordx8 instruction to the
// gfx950 ISA.
int* dst = reinterpret_cast<int*>(addr);
for (int i = 0; i < 8; i++) {
__builtin_nontemporal_store(static_cast<int>(val.d[i]), dst + i);
}
Comment on lines +276 to +278
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Similar to ld256_cs, this implementation uses a loop of 8 scalar __builtin_nontemporal_store calls instead of a single vectorized 256-bit store. This is unlikely to be compiled into a global_store_dwordx8 instruction and will likely result in 8 separate 32-bit memory transactions. This approach negates the performance benefit of using 256-bit vectors on gfx950.

To achieve the expected performance from doubling the vector width, a true vectorized non-temporal store should be implemented, which may require using inline GCN assembly.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same reasoning as above — on gfx942, we verified that __builtin_nontemporal_store is coalesced into global_store_dwordx4 nt by the compiler. The gfx950 path should similarly benefit from compiler vectorization.

Will add a TODO comment for gfx950 ISA verification once hardware is available. If dwordx8 coalescing does not happen automatically, inline ASM can be added as a follow-up.

#else
asm volatile(
"st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" ::"l"(addr),
"r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), "r"(val.d[4]),
"r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7]));
#endif
#else
assert(false && "st256_cs requires SM100+ with CUDA 12.9+");
assert(false && "st256_cs requires SM100+ with CUDA 12.9+ or ROCm gfx950+");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove/improve "or ROCm gfx950+" message in all the asserts, since 256-bit path is disabled for now on gfx950

#endif
}

Expand All @@ -217,7 +293,8 @@ __device__ __forceinline__ int ld32(const int* addr) { return __ldg(addr); }
__device__ __forceinline__ void st32(int* addr, int val) { *addr = val; }

// 32-bit cache-streaming (.cs) load / store.
// Falls back to ld32/st32 on ROCm (no .cs hint).
// ROCm gfx942+: uses __builtin_nontemporal_store for write-side bypass.
// Falls back to ld32/st32 on other ROCm targets.
__forceinline__ __device__ int ld32_cs(const int* addr) {
int val;
#ifndef USE_ROCM
Expand All @@ -231,13 +308,18 @@ __forceinline__ __device__ int ld32_cs(const int* addr) {
__forceinline__ __device__ void st32_cs(int* addr, int val) {
#ifndef USE_ROCM
asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val));
#elif VLLM_ROCM_USE_NT_HINTS
__builtin_nontemporal_store(val, addr);
#else
st32(addr, val);
#endif
}

// 128-bit cache-streaming (.cs) load / store.
// Falls back to ld128/st128 on ROCm (no .cs hint).
// ROCm gfx942+: uses __builtin_nontemporal_store for write-side cache
// bypass. Reads use normal loads to benefit from L2 cache on small/medium
// token counts (NT load hurts latency for batches < 2048).
// Falls back to ld128/st128 on other ROCm targets.
__forceinline__ __device__ int4 ld128_cs(const int4* addr) {
int4 val;
#ifndef USE_ROCM
Expand All @@ -254,6 +336,12 @@ __forceinline__ __device__ void st128_cs(int4* addr, int4 val) {
#ifndef USE_ROCM
asm volatile("st.global.cs.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(addr),
"r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
#elif VLLM_ROCM_USE_NT_HINTS
int* dst = reinterpret_cast<int*>(addr);
__builtin_nontemporal_store(val.x, dst);
__builtin_nontemporal_store(val.y, dst + 1);
__builtin_nontemporal_store(val.z, dst + 2);
__builtin_nontemporal_store(val.w, dst + 3);
Comment on lines +340 to +344
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This implementation uses four separate 32-bit non-temporal stores for a 128-bit operation. While the benchmarks show a performance gain from cache bypassing, this approach de-vectorizes the store operation, which may limit the performance uplift. A single 128-bit vectorized non-temporal store (global_store_dwordx4 with the nt flag) would be more efficient as it would avoid issuing four separate memory transactions.

For even better performance, consider using inline GCN assembly to emit a true vectorized non-temporal store. This would combine the benefits of both vectorization and cache bypassing.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Verified on MI300X — the compiler auto-coalesces the four scalar __builtin_nontemporal_store into a single global_store_dwordx4 ... nt instruction.

Actual ISA from hipcc -O3 --offload-arch=gfx942 --save-temps for ConcatMLAQKernel<bf16, 512>:

global_load_dwordx4  v[0:3],   v[12:13], off
global_load_dwordx4  v[4:7],   v[12:13], off offset:512
global_store_dwordx4 v[10:11], v[0:3],   off nt              ; single 128-bit NT store
global_store_dwordx4 v[10:11], v[4:7],   off offset:512 nt   ; single 128-bit NT store
global_load_dword    v12,      v[8:9],   off
global_store_dword   v[0:1],   v12,      off offset:1024 nt  ; 32-bit NT store

No de-vectorization occurs — the compiler generates the same instruction count as the plain path, just with the nt flag added. We also tested inline ASM with global_store_dwordx4 + v4i32 vector register constraints, but hipcc produces invalid operand for instruction errors in multi-template contexts. The builtin approach is both correct and portable.

#else
st128(val, addr);
#endif
Expand Down
Loading