
vulkan: fix turbo3 build + coopmat FA after April upstream sync #87

Merged
TheTom merged 1 commit into TheTom:feature/turboquant-kv-cache from apollosenvy:pr/vulkan-turbo3-april-fix on Apr 18, 2026

Conversation


@apollosenvy apollosenvy commented Apr 18, 2026

After the April upstream sync (PR #80) rebased feature/turboquant-kv-cache onto ggml-org master, the Vulkan turbo3 path no longer builds and no longer runs on coopmat-capable AMD devices. This lines up with issues #50, #64 and #81.

Two independent causes:

1. Build breakage. Upstream PR #21572 (commit 1f30ac0) moved fp16 RTE rounding to a runtime SPIR-V patch and dropped the _rte shader variants plus rte.glsl. Simon's turbo3 Vulkan work (#62, ff8bb73) was written against the pre-drop base, and the rebase stacked the two on top of each other, so the tree now has:

  • Dangling cpy_f32_X_rte_len / cpy_f32_X_rte_data references in ggml-vulkan.cpp
  • A two-arg SET_ROWS(itype, rte) macro invoked with one arg
  • #include "rte.glsl" in copy_to_quant.comp pointing at a deleted file
  • set_rows_X_rte shader-gen entries for a define (RTE16) that's no longer used
  • flash_attn_*_turbo3_0_*_int8 MMQ variants that fail to compile (no MMQ path for turbo3)

2. Runtime assertion on coopmat-capable devices. CREATE_FA only registered turbo3 for FA_SCALAR. When coopmat1_fa_support is true (RADV on a 7900 XTX, for instance), the tuning heuristic picks FA_COOPMAT1 for most shapes, so ggml_vk_flash_attn landed on an uninitialized pipeline (name="", initialized=0, wg_denoms={0,0,0}) and tripped the Br == pipeline->wg_denoms[0] assertion on the first prefill. End-to-end, llama-cli -fa on -ctk turbo3 -ctv turbo3 -dev Vulkan0 couldn't decode a single token.
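The failure mode can be sketched as a lookup that hands back a default-constructed pipeline. Everything below (type and field names) is a hypothetical stand-in for illustration, not the real ggml-vulkan.cpp code:

```cpp
#include <array>
#include <cstdint>
#include <string>

// Hypothetical stand-in for the pipeline bookkeeping described above;
// field names mirror the PR text, not the actual ggml-vulkan.cpp types.
struct vk_pipeline_sketch {
    std::string name;
    bool initialized = false;
    std::array<uint32_t, 3> wg_denoms{0, 0, 0};
};

// Mirrors the Br == pipeline->wg_denoms[0] precondition: a pipeline that
// was never registered keeps wg_denoms={0,0,0} and can never satisfy it.
bool fa_pipeline_usable(const vk_pipeline_sketch &p, uint32_t Br) {
    return p.initialized && Br == p.wg_denoms[0];
}
```

With only the FA_SCALAR entry registered, the coopmat1 slot stays in its default state, so the check above can never pass once the heuristic routes a shape to it.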

Changes

ggml-vulkan.cpp:

  • Collapse if (device->float_controls_rte_fp16) around cpy_f32_quant into a single block (matches upstream post-1f30ac0ce).
  • Simplify SET_ROWS to one arg.
  • Add CREATE_FA(GGML_TYPE_TURBO3_0, turbo3_0, FA_COOPMAT1, _cm1) and the _cm2 counterpart alongside the other quant types.

copy_to_quant.comp: drop the rte.glsl include.

vulkan-shaders-gen.cpp: drop the set_rows_X_rte entries; skip the MMQ flash_attn variant when tname == "turbo3_0".
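The generator-side skip amounts to a guard in the variant loop. This is a minimal sketch with simplified variant names, not the actual vulkan-shaders-gen.cpp structure:

```cpp
#include <string>
#include <vector>

// Illustrative sketch of the generator-side guard; the real loop in
// vulkan-shaders-gen.cpp differs in structure, and the variant name
// built below is a simplification of the generated shader names.
std::vector<std::string> mmq_flash_attn_variants(const std::vector<std::string> &tnames) {
    std::vector<std::string> out;
    for (const auto &tname : tnames) {
        if (tname == "turbo3_0") {
            continue; // no MMQ code path for turbo3, so never emit its variant
        }
        out.push_back("flash_attn_" + tname + "_int8");
    }
    return out;
}
```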

Verified

7900 XTX (gfx1100, RADV NAVI31, ROCm 7.2.1, Vulkan 1.4.341, spirv-headers 1.4.341.0):

  • Clean HIP + Vulkan build, no shader compile errors.
  • test-backend-ops -o SET_ROWS -b Vulkan0: 147/147
  • test-backend-ops -o FLASH_ATTN_EXT -b Vulkan0 -p type_KV=turbo3: 530 cases, all pass (was aborting on case 3 before).
  • test-backend-ops -o FLASH_ATTN_EXT -b ROCm0 -p type_KV=turbo3: still green.
  • llama-cli -ngl 99 -fa on -ctk turbo3 -ctv turbo3 -dev Vulkan0 on Qwen3-8B Q4_K_M: no more abort.

HIP backend bench (Qwen3.5-27B Q4_K_M, 7900 XTX at 1350 MCLK / 402W):

| KV type | pp128 (t/s) | pp2048 (t/s) | tg32 (t/s) | tg128 (t/s) |
|---------|------------:|-------------:|-----------:|------------:|
| F16     | 666.67      | 893.98       | 23.06      | 20.98       |
| turbo3  | 664.52      | 867.40       | 21.45      | 20.13       |
| turbo4  | 663.81      | 874.68       | 21.24      | 20.17       |

Not fixed here

The Vulkan turbo3 decode produces incoherent text on head_dim=128 models even with this fix applied (HIP on the same model is fine). Filed separately as #88.

Origin's April upstream-sync rebase interleaved two changes that left the
Vulkan turbo3 KV path broken:

  * ggml-org/llama.cpp upstream PR ggml-org#21572 (1f30ac0) moved fp16 RTE
    rounding to a runtime SPIR-V patch and dropped the _rte shader
    variants plus rte.glsl itself.
  * TheTom/llama-cpp-turboquant PR TheTom#62 (ff8bb73) added turbo3 KV
    support against a base that still had those variants.

After the rebase, the tree had dangling cpy_f32_*_rte_len / _data
references, a two-arg SET_ROWS macro called with one arg, a
#include "rte.glsl" in a shader whose header no longer exists, and
MMQ shader variants generated for turbo3_0 even though the flash_attn
MMQ path has no turbo3 code. The result was that ggml-vulkan.cpp
failed to compile on a clean checkout (spirv-headers + all of the
above) and the shader-gen emitted garbage variants.

Separately, turbo3 flash-attn pipelines were only wired up for
FA_SCALAR. On a coopmat-capable device (e.g. RADV on a 7900 XTX) the
tuning heuristic picks FA_COOPMAT1 for most shapes, which landed in
ggml_vk_flash_attn with an uninitialized pipeline (wg_denoms={0,0,0})
and tripped the Br == wg_denoms[0] assertion as soon as a prefill
ubatch was dispatched. End-to-end llama-cli on Vulkan + -ctk turbo3
aborted on the first real forward pass.

Changes:

  * Drop the if (float_controls_rte_fp16) / else branches around
    cpy_f32_quant pipeline creation and collapse SET_ROWS to a single
    variant, matching upstream post-1f30ac0ce.
  * Remove the #include "rte.glsl" from copy_to_quant.comp.
  * Skip the MMQ flash_attn shader variant for turbo3_0 in the shader
    generator (no MMQ code path for it).
  * Register CREATE_FA(GGML_TYPE_TURBO3_0, turbo3_0, FA_COOPMAT1, _cm1)
    and the _cm2 counterpart alongside the other quant types.

Verified on AMD 7900 XTX (gfx1100 / RADV NAVI31, ROCm 7.2.1 + Vulkan
1.4.341, spirv-headers 1.4.341.0):

  * Full HIP+Vulkan build is clean with no shader compile errors.
  * test-backend-ops -o SET_ROWS -b Vulkan0 : 147/147
  * test-backend-ops -o FLASH_ATTN_EXT -b Vulkan0 -p type_KV=turbo3 :
    530 cases pass (previously aborted on case 3).
  * test-backend-ops -o FLASH_ATTN_EXT -b ROCm0 -p type_KV=turbo3 :
    still green (no HIP regression).
  * llama-cli on Qwen3-8B Q4_K_M with -ngl 99 -fa on -ctk turbo3
    -ctv turbo3 on Vulkan0 no longer aborts. The remaining head_dim=128
    correctness issue on the Vulkan turbo3 decode path is pre-existing
    and orthogonal to this change.

llama-bench on Qwen3.5-27B Q4_K_M, 7900 XTX OC, HIP backend:

  F16     tg128=20.98   turbo3 tg128=20.13   turbo4 tg128=20.17

Refs: TheTom/llama-cpp-turboquant issues TheTom#50, TheTom#64, TheTom#81
@TheTom TheTom merged commit 627ebbc into TheTom:feature/turboquant-kv-cache Apr 18, 2026
23 of 50 checks passed
jimbothigpen pushed a commit to jimbothigpen/frankenturbo2 that referenced this pull request May 2, 2026
vulkan: fix turbo3 build + coopmat FA after April upstream sync
TheTom added a commit that referenced this pull request May 3, 2026
Mirror of @apollosenvy's turbo3_0 Vulkan SET_ROWS port (PR #33 + #87)
to the other two turbo types. Reported by @dpblnt in #50 with a clean
matrix on RX 9060 XT showing turbo3 V works on Vulkan but turbo2/turbo4
V abort with:

  pre-allocated tensor (cache_v_l*) in a buffer (Vulkan0)
  that cannot run the operation (SET_ROWS)

at llama_context::sched_reserve() time, before any compute runs.

Mechanical port across 4 files:

- vulkan-shaders/types.glsl: block_turbo2_0 + block_turbo4_0 struct
  declarations matching the C side (ggml-common.h).

- vulkan-shaders/copy_to_quant.comp: SET_ROWS quantize main() blocks
  for turbo2 (4 centroids, 2-bit pack, no signs byte) and turbo4
  (16 centroids, 4-bit nibble pack, no signs byte). WHT setup and
  reduction structure identical to turbo3 (QK = 128 across all three).
  Centroid + midpoint tables ported from CENTROIDS_2BIT and
  CENTROIDS_4BIT in ggml-turbo-quant.c.

- vulkan-shaders/vulkan-shaders-gen.cpp: turbo2_0 and turbo4_0 added
  to the set_rows iteration list at line ~789.

- ggml-vulkan.cpp: SET_ROWS pipeline registrations + supports_op
  switch + dispatch element-count all extended with TURBO2_0 and
  TURBO4_0 cases.

## Verified on llvmpipe Vulkan (software rasterizer, AMD MI300X cloud droplet)

Patched ggml-vulkan.cpp temporarily during repro to allow llvmpipe
(normally filtered out as eCpu); patch reverted before commit. The
SET_ROWS abort is a backend-capability check at graph build time so
it fires regardless of GPU vs CPU Vulkan backend.

| ctk / ctv         | tg16 (t/s) | status        |
|-------------------|-----------:|---------------|
| q4_0 / q4_0       | 17.68      | baseline      |
| q4_0 / turbo3     | 5.91       | already worked|
| q4_0 / turbo4     | 6.14       | was aborting  |
| q4_0 / turbo2     | 5.65       | was aborting  |

llvmpipe perf numbers are not meaningful (CPU-emulated Vulkan); they
are reported here only to confirm the abort is gone and the kernels
run end-to-end without divergence.

## Needs GPU validation

Cannot validate GPU shader correctness on the droplet (MI300X SR-IOV
VF does not expose itself to RADV/amdvlk on cloud). Specifically:
- Subgroup shuffle / ballot behavior on real GPU subgroup sizes
- Shader compilation under non-llvmpipe Vulkan drivers
- PPL / quality on the actual quantization math

@dpblnt @apollosenvy if either of you has cycles, would appreciate
a quick rebuild on RDNA Vulkan (gfx1100/gfx1200) to confirm:
1. The SET_ROWS abort that triggered #50 is gone
2. Output coherence on turbo4 V (not garbage tokens)
3. PPL stays in the expected ballpark vs the CUDA / Metal
   implementations of the same quants

Closes #50.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
sbaier1 pushed a commit to sbaier1/llama-cpp-turboquant that referenced this pull request May 8, 2026
Mirror of @apollosenvy's turbo3_0 Vulkan SET_ROWS port (PR TheTom#33 + TheTom#87)
