
feat: AMD Instinct MI300X + MI355X (gfx942/gfx950) ROCm support #61

Merged
TheTom merged 2 commits into TheTom:feature/turboquant-kv-cache from andyluo7:add-rocm-mi300x-support on Apr 9, 2026

Conversation


andyluo7 commented Apr 7, 2026

Summary

TurboQuant KV cache compression works on AMD Instinct datacenter GPUs:

  • MI300X (gfx942): Zero code changes — compiles and runs via HIP translation
  • MI355X (gfx950): Adds CDNA4 arch defines (5 files, ~20 lines changed)

Code Changes (gfx950 only)

| File | Change |
|---|---|
| vendors/hip.h | Add CDNA4 define for `__gfx950__`, include in CDNA family |
| common.cuh | Add `GGML_CUDA_CC_CDNA4` constant + `IS_CDNA4` macro |
| mma.cuh | Route CDNA4 to compatible MFMA (bf16_1k, i32x32, f32x4; NOT xf32) |
| mmq.cuh | Include CDNA4 in stream-k dispatch |
| common.cuh | Exclude CDNA4 from CDNA3-specific e4m3_fnuz FP8 path |

Test Results

MI300X (gfx942, ROCm 7.0.2)

| KV Cache | pp512 (tok/s) | tg128 (tok/s) | vs f16 prefill | vs f16 decode |
|---|---|---|---|---|
| f16 | 24,453 ± 230 | 181.2 ± 2.0 | baseline | baseline |
| turbo3 | ~25,200 | ~160 | +3% | 88% |
| turbo4 | 25,427 ± 17 | 161.1 ± 0.2 | +4% | 89% |

MI355X (gfx950, ROCm 7.0.1)

| KV Cache | pp512 (tok/s) | tg128 (tok/s) | vs f16 prefill | vs f16 decode |
|---|---|---|---|---|
| f16+FA | 40,013 ± 902 | 254.5 ± 1.0 | baseline | baseline |
| turbo3 | 39,140 ± 475 | 162.3 ± 0.1 | 98% | 64% |
| turbo4 | 39,232 ± 508 | 214.1 ± 0.7 | 98% | 84% |

WHT Kernel Correctness

Roundtrip max error: 2.980232e-07 — PASS ✅

Known Issues

  • MI355X non-FA MMQ crashes (xf32 MFMA unsupported on gfx950). TurboQuant types force FA and work correctly.
  • This is a pre-existing upstream llama.cpp issue, not TurboQuant-specific.

Environment

  • MI300X: gfx942, 192 GB HBM3, ROCm 7.0.2
  • MI355X: gfx950, 288 GB HBM3e, ROCm 7.0.1
  • Model: Qwen2.5-1.5B-Instruct Q4_K_M

TurboQuant KV cache compression (turbo2/turbo3/turbo4) builds and runs
correctly on AMD Instinct MI300X with ROCm 7.0.2. Zero code changes
required — existing CUDA kernels compile via HIP translation.

Test results (Qwen2.5-1.5B Q4_K_M, single MI300X):
- WHT roundtrip: PASS (max error 2.98e-07)
- turbo3 prefill: +3% vs f16 (25,200 vs 24,453 tok/s)
- turbo3 decode: 88% of f16 (160 vs 181 tok/s)
- turbo4 prefill: +4% vs f16 (25,427 vs 24,453 tok/s)
- turbo4 decode: 89% of f16 (161 vs 181 tok/s)

MI355X (gfx950) compiles but needs gfx950 added to llama.cpp's
MMQ kernel dispatch (upstream issue, not TurboQuant-specific).

Tested-by: Andy Luo <andyluo7@users.noreply.github.com>
github-actions bot added the documentation label (Improvements or additions to documentation) on Apr 7, 2026
Add AMD Instinct MI355X (gfx950) architecture support:

Code changes:
- vendors/hip.h: Add CDNA4 define for __gfx950__, include in CDNA family
- common.cuh: Add GGML_CUDA_CC_CDNA4 constant and IS_CDNA4 macro
- mma.cuh: Route CDNA4 to compatible MFMA instructions
  * bf16: mfma_f32_16x16x16bf16_1k (same as CDNA3)
  * int8: mfma_i32_16x16x32_i8 (same as CDNA3)
  * f32: mfma_f32_16x16x4f32 (CDNA2 path, NOT xf32 which doesn't exist on gfx950)
- mmq.cuh: Include CDNA4 in stream-k dispatch
- common.cuh: Exclude CDNA4 from CDNA3-specific e4m3_fnuz FP8 path (gfx950 uses standard e4m3fn)

MI355X test results (Qwen2.5-1.5B Q4_K_M, single GPU):
- turbo3: 39,140 tok/s prefill (98% of f16), 162 tok/s decode (64%)
- turbo4: 39,232 tok/s prefill (98% of f16), 214 tok/s decode (84%)
- WHT roundtrip: PASS (max error 2.98e-07)

Note: non-FA MMQ path crashes on gfx950 (xf32 MFMA unsupported).
TurboQuant types force FA and work correctly.

Tested-by: Andy Luo <andyluo7@users.noreply.github.com>
andyluo7 changed the title from "docs: AMD Instinct MI300X (gfx942) ROCm test results" to "feat: AMD Instinct MI300X + MI355X (gfx942/gfx950) ROCm support" on Apr 7, 2026
TheTom force-pushed the feature/turboquant-kv-cache branch from 10cb187 to 0d6b38a on April 8, 2026

TheTom (Owner) commented Apr 9, 2026

Thanks for adding AMD Instinct support! Built locally on M5 Max (Metal) for posterity: clean build, binary loads, no regressions on the Mac side. The CDNA4 changes are properly gated behind __gfx950__, so there is zero impact on the Metal/CUDA paths. Merging.

TheTom merged commit 157cb85 into TheTom:feature/turboquant-kv-cache on Apr 9, 2026
1 check passed

Labels

documentation, ggml, Nvidia GPU
