-
-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[ROCm] Optimize concat_mla_q for CDNA3 (MI300X) and CDNA4 (MI355X) #36743
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,14 +17,30 @@ | |
|
|
||
| // Device-side: SM100+ architecture with CUDA 12.9+ toolkit, which | ||
| // together enable 256-bit (v8.u32) PTX load/store instructions. | ||
| // Use for PTX instruction selection with architecture fallback paths. | ||
| // ROCm gfx950+ (CDNA4/MI350X/MI355X): 256-bit logical width via 2× dwordx4. | ||
| // Use for PTX/ISA instruction selection with architecture fallback paths. | ||
| #if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \ | ||
| defined(CUDA_VERSION) && CUDA_VERSION >= 12090 | ||
| #define VLLM_256B_PTX_ENABLED 1 | ||
| // gfx950 (CDNA4): 256-bit path disabled for now. The gfx950 ISA (ROCm 7.2, | ||
| // LLVM 22.0) lacks vector dwordx8 load/store instructions, so the compiler | ||
| // decomposes 256-bit ops into 2× dwordx4 issued in separate cycles, leading | ||
| // to uncoalesced memory accesses within a warp. Re-enable once a future | ||
| // ROCm version adds native global_load/store_dwordx8 support. | ||
| // See: https://github.com/vllm-project/vllm/pull/36743#discussion_r2048743213 | ||
| // #elif defined(USE_ROCM) && defined(__gfx950__) | ||
| // #define VLLM_256B_PTX_ENABLED 1 | ||
| #else | ||
| #define VLLM_256B_PTX_ENABLED 0 | ||
| #endif | ||
|
|
||
| // ROCm gfx942+ (CDNA3/MI300X): non-temporal hints for cache bypass. | ||
| #if defined(USE_ROCM) && (defined(__gfx942__) || defined(__gfx950__)) | ||
| #define VLLM_ROCM_USE_NT_HINTS 1 | ||
| #else | ||
| #define VLLM_ROCM_USE_NT_HINTS 0 | ||
| #endif | ||
|
|
||
| namespace vllm { | ||
|
|
||
| // ============================================================ | ||
|
|
@@ -133,28 +149,59 @@ struct alignas(VecTraits<use_256b>::ARCH_MAX_VEC_SIZE) PackedVec { | |
| // Load / store primitives | ||
| // ============================================================ | ||
|
|
||
| // 256-bit load / store — SM100+ only (PTX v8 instructions). | ||
| // 256-bit load / store — SM100+ (PTX v8) or ROCm gfx950+ (2× dwordx4). | ||
| __device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) { | ||
| #if VLLM_256B_PTX_ENABLED | ||
| #if defined(USE_ROCM) && defined(__gfx950__) | ||
| // gfx950 (CDNA4): 256-bit logical load. The gfx950 ISA does not have | ||
| // global_load_dwordx8 — the compiler emits 2× global_load_dwordx4 with | ||
| // adjacent offsets (off + off:16). This still halves loop iterations vs | ||
| // the 128-bit path. Verified on MI350X, ROCm 7.2 / LLVM 22.0. | ||
| // If a future ROCm adds vector dwordx8, this code benefits automatically. | ||
| const uint32_t* src = reinterpret_cast<const uint32_t*>(ptr); | ||
| val.d[0] = src[0]; | ||
| val.d[1] = src[1]; | ||
| val.d[2] = src[2]; | ||
| val.d[3] = src[3]; | ||
| val.d[4] = src[4]; | ||
| val.d[5] = src[5]; | ||
| val.d[6] = src[6]; | ||
| val.d[7] = src[7]; | ||
| #else | ||
| asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n" | ||
| : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), | ||
| "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) | ||
| : "l"(ptr)); | ||
| #endif | ||
| #else | ||
| assert(false && "ld256 requires SM100+ with CUDA 12.9+"); | ||
| assert(false && "ld256 requires SM100+ with CUDA 12.9+ or ROCm gfx950+"); | ||
| #endif | ||
| } | ||
|
|
||
| __device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) { | ||
| #if VLLM_256B_PTX_ENABLED | ||
| #if defined(USE_ROCM) && defined(__gfx950__) | ||
| // gfx950 (CDNA4): 256-bit logical store → 2× global_store_dwordx4. | ||
| // See ld256 comment for ISA details (verified on MI350X, ROCm 7.2). | ||
| uint32_t* dst = reinterpret_cast<uint32_t*>(ptr); | ||
| dst[0] = val.d[0]; | ||
| dst[1] = val.d[1]; | ||
| dst[2] = val.d[2]; | ||
| dst[3] = val.d[3]; | ||
| dst[4] = val.d[4]; | ||
| dst[5] = val.d[5]; | ||
| dst[6] = val.d[6]; | ||
| dst[7] = val.d[7]; | ||
| #else | ||
| asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n" | ||
| : | ||
| : "l"(ptr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), | ||
| "r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), | ||
| "r"(val.d[7]) | ||
| : "memory"); | ||
| #endif | ||
| #else | ||
| assert(false && "st256 requires SM100+ with CUDA 12.9+"); | ||
| assert(false && "st256 requires SM100+ with CUDA 12.9+ or ROCm gfx950+"); | ||
| #endif | ||
| } | ||
|
|
||
|
|
@@ -185,29 +232,58 @@ __device__ __forceinline__ void st128(T& val, T* ptr) { | |
| *reinterpret_cast<int4*>(ptr) = *reinterpret_cast<int4*>(&val); | ||
| } | ||
|
|
||
| // 256-bit cache-streaming (.cs) load / store — SM100+ only. | ||
| // 256-bit cache-streaming (.cs) load / store. | ||
| // SM100+: PTX .cs hint. ROCm gfx950+: slc (system-level coherent) hint. | ||
| __forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) { | ||
| #if VLLM_256B_PTX_ENABLED | ||
| u32x8_t val; | ||
| #if defined(USE_ROCM) && defined(__gfx950__) | ||
| // gfx950 (CDNA4): 256-bit non-temporal load for streaming. | ||
| // The compiler coalesces sequential __builtin_nontemporal_load calls | ||
| // into vectorized global_load instructions at -O3 (verified on gfx942 | ||
| // where 4× scalar NT → single global_load_dwordx4 nt). | ||
| // On gfx950: emits 2× global_load_dwordx4 nt (no vector dwordx8 in ISA). | ||
| // Verified on MI350X, ROCm 7.2 / LLVM 22.0. | ||
| // TODO: Check if future ROCm versions add a vector dwordx8 instruction to the | ||
| // gfx950 ISA. | ||
| const uint32_t* src = reinterpret_cast<const uint32_t*>(addr); | ||
| for (int i = 0; i < 8; i++) { | ||
| val.d[i] = | ||
| __builtin_nontemporal_load(reinterpret_cast<const int*>(src) + i); | ||
| } | ||
|
Comment on lines
+250
to
+253
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment here states the goal is to Since this path is not yet hardware-tested, it's critical to ensure it's implemented efficiently. A true 256-bit non-temporal load, likely requiring inline GCN assembly to emit a
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point — the gfx950 path is not yet hardware-tested. However, based on ISA verification on gfx942, the compiler does coalesce sequential `__builtin_nontemporal_load` calls into vectorized loads at -O3. For the gfx950 256-bit path specifically: we expect the compiler to coalesce the 8 × 32-bit non-temporal loads into vectorized instructions. If the compiler does not coalesce to a vectorized form, we can fall back to inline GCN assembly.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Just out of curiosity: any reason why this is not automatically translated to one
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great question! The reason is hardware ISA generation differences:
This 128-bit vs 256-bit vector width is also the main reason MI300X peaks at ~4.5x speedup vs B300's 11.8x for this kernel. MI355X should close that gap significantly. Reference: AMD CDNA4 ISA (gfx950) — see
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now we have the answer from hardware testing! 🔬 The reason is simple: Only So the compiler is doing the best it can: decomposing 256-bit into 2× |
||
| #else | ||
| asm volatile("ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];" | ||
| : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), | ||
| "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) | ||
| : "l"(addr)); | ||
| #endif | ||
| return val; | ||
| #else | ||
| assert(false && "ld256_cs requires SM100+ with CUDA 12.9+"); | ||
| assert(false && "ld256_cs requires SM100+ with CUDA 12.9+ or ROCm gfx950+"); | ||
| return u32x8_t{}; | ||
| #endif | ||
| } | ||
|
|
||
| __forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) { | ||
| #if VLLM_256B_PTX_ENABLED | ||
| #if defined(USE_ROCM) && defined(__gfx950__) | ||
| // gfx950 (CDNA4): 256-bit non-temporal store for streaming write. | ||
| // Emits 2× global_store_dwordx4 nt (no vector dwordx8 in gfx950 ISA). | ||
| // Verified on MI350X, ROCm 7.2 / LLVM 22.0. | ||
| // TODO: Check if future ROCm versions add a vector dwordx8 instruction to the | ||
| // gfx950 ISA. | ||
| int* dst = reinterpret_cast<int*>(addr); | ||
| for (int i = 0; i < 8; i++) { | ||
| __builtin_nontemporal_store(static_cast<int>(val.d[i]), dst + i); | ||
| } | ||
|
Comment on lines
+276
to
+278
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to To achieve the expected performance from doubling the vector width, a true vectorized non-temporal store should be implemented, which may require using inline GCN assembly.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same reasoning as above — on gfx942, we verified that Will add a TODO comment for gfx950 ISA verification once hardware is available. If |
||
| #else | ||
| asm volatile( | ||
| "st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" ::"l"(addr), | ||
| "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), "r"(val.d[4]), | ||
| "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7])); | ||
| #endif | ||
| #else | ||
| assert(false && "st256_cs requires SM100+ with CUDA 12.9+"); | ||
| assert(false && "st256_cs requires SM100+ with CUDA 12.9+ or ROCm gfx950+"); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: remove/improve "or ROCm gfx950+" message in all the |
||
| #endif | ||
| } | ||
|
|
||
|
|
@@ -217,7 +293,8 @@ __device__ __forceinline__ int ld32(const int* addr) { return __ldg(addr); } | |
| __device__ __forceinline__ void st32(int* addr, int val) { *addr = val; } | ||
|
|
||
| // 32-bit cache-streaming (.cs) load / store. | ||
| // Falls back to ld32/st32 on ROCm (no .cs hint). | ||
| // ROCm gfx942+: uses __builtin_nontemporal_store for write-side bypass. | ||
| // Falls back to ld32/st32 on other ROCm targets. | ||
| __forceinline__ __device__ int ld32_cs(const int* addr) { | ||
| int val; | ||
| #ifndef USE_ROCM | ||
|
|
@@ -231,13 +308,18 @@ __forceinline__ __device__ int ld32_cs(const int* addr) { | |
| __forceinline__ __device__ void st32_cs(int* addr, int val) { | ||
| #ifndef USE_ROCM | ||
| asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val)); | ||
| #elif VLLM_ROCM_USE_NT_HINTS | ||
| __builtin_nontemporal_store(val, addr); | ||
| #else | ||
| st32(addr, val); | ||
| #endif | ||
| } | ||
|
|
||
| // 128-bit cache-streaming (.cs) load / store. | ||
| // Falls back to ld128/st128 on ROCm (no .cs hint). | ||
| // ROCm gfx942+: uses __builtin_nontemporal_store for write-side cache | ||
| // bypass. Reads use normal loads to benefit from L2 cache on small/medium | ||
| // token counts (NT load hurts latency for batches < 2048). | ||
| // Falls back to ld128/st128 on other ROCm targets. | ||
| __forceinline__ __device__ int4 ld128_cs(const int4* addr) { | ||
| int4 val; | ||
| #ifndef USE_ROCM | ||
|
|
@@ -254,6 +336,12 @@ __forceinline__ __device__ void st128_cs(int4* addr, int4 val) { | |
| #ifndef USE_ROCM | ||
| asm volatile("st.global.cs.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(addr), | ||
| "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)); | ||
| #elif VLLM_ROCM_USE_NT_HINTS | ||
| int* dst = reinterpret_cast<int*>(addr); | ||
| __builtin_nontemporal_store(val.x, dst); | ||
| __builtin_nontemporal_store(val.y, dst + 1); | ||
| __builtin_nontemporal_store(val.z, dst + 2); | ||
| __builtin_nontemporal_store(val.w, dst + 3); | ||
|
Comment on lines
+340
to
+344
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This implementation uses four separate 32-bit non-temporal stores for a 128-bit operation. While the benchmarks show a performance gain from cache bypassing, this approach de-vectorizes the store operation, which may limit the performance uplift. A single 128-bit vectorized non-temporal store ( For even better performance, consider using inline GCN assembly to emit a true vectorized non-temporal store. This would combine the benefits of both vectorization and cache bypassing.
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Verified on MI300X — the compiler auto-coalesces the four scalar `__builtin_nontemporal_store` calls into a single vectorized `global_store_dwordx4 nt`. Actual ISA from the compiled kernel:
global_load_dwordx4 v[0:3], v[12:13], off
global_load_dwordx4 v[4:7], v[12:13], off offset:512
global_store_dwordx4 v[10:11], v[0:3], off nt ; single 128-bit NT store
global_store_dwordx4 v[10:11], v[4:7], off offset:512 nt ; single 128-bit NT store
global_load_dword v12, v[8:9], off
global_store_dword v[0:1], v12, off offset:1024 nt ; 32-bit NT store
No de-vectorization occurs — the compiler generates the same instruction count as the plain path, just with the `nt` hint added. |
||
| #else | ||
| st128(val, addr); | ||
| #endif | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe it’s also worth leaving a TODO to double-check, once the MI355X hardware is available, that the compiler automatically translates this into the corresponding
`global_store_dwordx8` instruction, and not two sequential `global_store_dwordx4`? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good suggestion — I'll add a TODO comment in the code.
To verify on MI355X once hardware is available:
If the compiler doesn't emit
`dwordx8`, we can fall back to inline ASM with `global_store_dwordx8` directly. Pushing a commit with the TODO now. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update: verified on MI350X (gfx950, ROCm 7.2)!
The compiler generates 2×
`global_store_dwordx4`, not a single `global_store_dwordx8`. The gfx950 ISA in ROCm 7.2 / LLVM 22.0 does not include vector `dwordx8` instructions at all (`llvm-mc` rejects them as invalid). The 256-bit path still halves loop iterations, so the optimization is valid — just not via the single-instruction path we hoped for. Will add a TODO comment noting this for future ROCm versions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for looking into this! Would it be better to disable the 256b memory instructions for now, then? Maybe with an assert "Not supported yet"?
Two consecutive 128b operations issued in different cycles would lead to uncoalesced memory accesses within a warp.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call — you're right that two consecutive 128b ops in different cycles would lead to uncoalesced accesses. I've disabled the gfx950 256-bit path for now (commented out with explanation).
The gfx942 (MI300X) non-temporal hints path remains unchanged — that one is hardware-verified and the compiler auto-coalesces into single
`global_store_dwordx4 nt` instructions. Also rebased on upstream/main to resolve the merge conflict.
When a future ROCm version adds native
`global_load/store_dwordx8`, we can re-enable the 256-bit path with a simple uncomment and re-verification.