
vulkan: chunked parallel kernel for GATED_DELTA_NET #20377

Draft

ProgenyAlpha wants to merge 14 commits into ggml-org:master from ProgenyAlpha:vulkan-gdn-chunked

Conversation

@ProgenyAlpha
Contributor

Follow-up to #20334. Adds the chunked parallel kernel infrastructure for Vulkan GATED_DELTA_NET, split out per @0cc4m's review feedback.

Depends on #20334 and #20340

Three new compute shaders implementing the chunked algorithm:

  • gated_delta_net_chunk_intra.comp — intra-chunk parallel computation
  • gated_delta_net_chunk_inter.comp — inter-chunk state propagation
  • gated_delta_net_chunk_output.comp — output reconstruction

Includes the rq1neq1 broadcast fix to match #20340's interleaved Q/K layout (head_id % neq1 instead of head_id / rq1).
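The broadcast fix can be sketched as follows. This is a hypothetical illustration, not the shader's actual code: neq1 (number of K/V heads) and rq1 (Q heads per K/V head) are names from the description above, and the two layouts are assumptions for the sake of the example.

```python
def kv_head_old(head_id: int, rq1: int) -> int:
    # old mapping: assumes Q heads are grouped per K/V head
    return head_id // rq1

def kv_head_new(head_id: int, neq1: int) -> int:
    # fixed mapping for #20340's interleaved Q/K layout
    return head_id % neq1
```

With 4 Q heads and neq1 = 2, the old mapping assigns heads [0, 0, 1, 1] while the fixed mapping assigns [0, 1, 0, 1], matching an interleaved layout.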

Chunked dispatch is currently disabled (GDN_CHUNK_THRESHOLD = UINT32_MAX) — the autoregressive path handles all token counts. Enabling it will need cooperative matrix support for the output kernel to be competitive.

16/16 backend-ops tests passing (includes chunked-specific test configs with n_seq_tokens=64/128).

890M benchmarks (Qwen3-Coder-Next REAM Q4_K_M):

| Metric | Base (#20334) | With #20340 + chunked infra | Change |
| --- | --- | --- | --- |
| PP-512 | 165.31 t/s | 215.46 t/s | +30.3% |
| TG-128 | 21.16 t/s | 21.68 t/s | +2.5% |

The PP improvement comes from #20340's chunked op path feeding our autoregressive shader more efficiently. The Vulkan chunked dispatch itself isn't active yet — that's the next optimization pass.

@github-actions github-actions bot added labels model (Model specific), testing (Everything test related), Nvidia GPU (Issues specific to Nvidia GPUs), Vulkan (Issues specific to the Vulkan backend), and ggml (changes relating to the ggml tensor library for machine learning) Mar 11, 2026
@lemmi

lemmi commented Mar 11, 2026

Benchmarks, Strix Halo:

master (e1a3999):

| model | size | params | backend | ngl | fa | mmap | dio | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| qwen3next 80B.A3B Q4_K - Medium | 46.20 GiB | 79.67 B | Vulkan | 99 | 1 | 0 | 1 | pp2048 | 501.96 ± 4.54 |
| qwen3next 80B.A3B Q4_K - Medium | 46.20 GiB | 79.67 B | Vulkan | 99 | 1 | 0 | 1 | tg128 | 38.28 ± 0.02 |
| qwen35moe 35B.A3B Q8_0 | 34.36 GiB | 34.66 B | Vulkan | 99 | 1 | 0 | 1 | pp2048 | 701.72 ± 3.59 |
| qwen35moe 35B.A3B Q8_0 | 34.36 GiB | 34.66 B | Vulkan | 99 | 1 | 0 | 1 | tg128 | 44.07 ± 0.05 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 1 | 0 | 1 | pp2048 | 195.42 ± 1.47 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 1 | 0 | 1 | tg128 | 18.46 ± 0.03 |

PR (795f15c):

| model | size | params | backend | ngl | fa | mmap | dio | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| qwen3next 80B.A3B Q4_K - Medium | 46.20 GiB | 79.67 B | Vulkan | 99 | 1 | 0 | 1 | pp2048 | 558.94 ± 6.59 |
| qwen3next 80B.A3B Q4_K - Medium | 46.20 GiB | 79.67 B | Vulkan | 99 | 1 | 0 | 1 | tg128 | 46.39 ± 0.04 |
| qwen35moe 35B.A3B Q8_0 | 34.36 GiB | 34.66 B | Vulkan | 99 | 1 | 0 | 1 | pp2048 | 839.17 ± 4.62 |
| qwen35moe 35B.A3B Q8_0 | 34.36 GiB | 34.66 B | Vulkan | 99 | 1 | 0 | 1 | tg128 | 53.54 ± 0.05 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 1 | 0 | 1 | pp2048 | 221.53 ± 2.29 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 1 | 0 | 1 | tg128 | 21.39 ± 0.02 |

That's 10-20% better PP performance, depending on the model.

@ProgenyAlpha
Contributor Author

ProgenyAlpha commented Mar 11, 2026

@lemmi Great numbers, thanks for testing.

Updated the PR to actually enable the chunked Vulkan dispatch — it's now gated on shader core count (> 16 CUs) instead of being disabled. On my 890M (16 CUs) the 3-dispatch overhead makes chunked slower than autoregressive, so it stays off there. On your 8060S (32 CUs) it should activate automatically for n_tokens > 64 with d128 non-KDA configs.
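The gating described above can be sketched as a simple predicate. This is a hypothetical illustration: the constant names and exact condition shape are assumptions; only the values (16 CUs, 64 tokens, d128, non-KDA) come from the comment.

```python
# Assumed thresholds, taken from the comment text above
GDN_CHUNK_MIN_CUS = 16
GDN_CHUNK_MIN_TOKENS = 64

def use_chunked(n_tokens: int, shader_core_count: int,
                head_dim: int, is_kda: bool) -> bool:
    # chunked dispatch only pays off with enough CUs and a big enough batch
    return (shader_core_count > GDN_CHUNK_MIN_CUS
            and n_tokens > GDN_CHUNK_MIN_TOKENS
            and head_dim == 128
            and not is_kda)
```

On a 16-CU part like the 890M the predicate stays false, so the autoregressive path keeps handling everything.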

I can't validate the chunked dispatch path myself since I only have the integrated 890M. If you get a chance to test the latest push, that would tell us whether the chunked shaders actually help PP on discrete hardware or if they need more work (coopmat for the output kernel is the next step if so).

@lemmi

lemmi commented Mar 11, 2026

Small clarification: the 8060S is the iGPU on Strix Halo (aka Ryzen AI MAX+ 395) and has 40 CUs.

Performance tanked with the latest patch:
Before:

| model | size | params | backend | ngl | fa | mmap | dio | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| qwen3next 80B.A3B Q4_K - Medium | 46.20 GiB | 79.67 B | Vulkan | 99 | 1 | 0 | 1 | pp2048 | 564.42 ± 6.14 |
| qwen35moe 35B.A3B Q8_0 | 34.36 GiB | 34.66 B | Vulkan | 99 | 1 | 0 | 1 | pp2048 | 848.64 ± 1.26 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 1 | 0 | 1 | pp2048 | 229.15 ± 1.30 |

After:

| model | size | params | backend | ngl | fa | mmap | dio | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| qwen3next 80B.A3B Q4_K - Medium | 46.20 GiB | 79.67 B | Vulkan | 99 | 1 | 0 | 1 | pp2048 | 366.17 ± 2.17 |
| qwen35moe 35B.A3B Q8_0 | 34.36 GiB | 34.66 B | Vulkan | 99 | 1 | 0 | 1 | pp2048 | 490.98 ± 1.61 |
| qwen35moe 122B.A10B Q5_K - Medium | 85.60 GiB | 122.11 B | Vulkan | 99 | 1 | 0 | 1 | pp2048 | 165.62 ± 2.21 |

@ProgenyAlpha
Contributor Author

@0cc4m Rebased on master. Chunked kernels work, but the scalar output kernel is too slow without coopmat, so the threshold is disabled for now. I've got a coopmat output kernel already in the works. Do you want me to add it here, keep this PR as infrastructure and open a separate PR for the coopmat kernel, or stop here?

@jeffbolznv
Contributor

Do I understand correctly that to see a gain you need to merge this PR with another?

What exact command line are you using where you see a 30% gain? I only see GDN taking about 5% of the time running llama-bench -m c:\models\Qwen3.5-35B-A3B-Q4_K_M.gguf -fa 1 -p 512 -n 128 --prio 1 -r 10. Can you share a GGML_VK_PERF_LOGGER log?

@ProgenyAlpha ProgenyAlpha marked this pull request as draft March 13, 2026 19:33
@ProgenyAlpha
Contributor Author

@jeffbolznv Hey! This is noted in the PR description, but the 30% PP gain comes from #20340's chunked op path on the graph side feeding my GDN Vulkan autoregressive shader (#20334) more efficiently, not from the Vulkan chunked shaders. Both #20334 and #20340 are already merged into master, so that improvement is already live.

The Vulkan chunked dispatch in this PR is actually disabled (GDN_CHUNK_THRESHOLD = UINT32_MAX), as the new shaders aren't running yet. They're infrastructure for the next step: a coopmat output kernel to make chunked competitive with autoregressive (I'm almost there on my hardware, within 5%).

So with this PR as-is, you'd see near identical performance to master since the chunked path doesn't activate. I was waiting to find out how 0cc4m would like to handle this or anyone in a position to give feedback.

I can close out this PR until I've done more thorough validation and testing and reopen then, if preferred.

@jeffbolznv
Contributor

I mostly want to understand what kind of use case/benchmark you're accelerating on so I can see how much theoretical upside there is.

ProgenyAlpha and others added 7 commits March 13, 2026 21:45
Three-dispatch chunked pipeline for prompt processing acceleration:
intra-chunk WY decomposition, inter-chunk state propagation, output
combination. Currently disabled (threshold=UINT32_MAX).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add gated_delta_net_chunk_output_cm1.comp — a cooperative matrix variant
of the chunked output kernel that replaces the O(N²) scalar intra-chunk
loop with an f16 coopmat GEMM: A_decayed[64×64] @ vnew[64×128].

Kernel structure:
- Phase 1: Q@K^T via coopmat (unchanged from scalar variant)
- Phase 2a: Build causal decay mask → sh_adecay (f16, clamped)
- Phase 2b: Stage vnew into sh_kv (f16, pre-scaled by 1/√d)
- Pass 1: Inter-chunk Q@S → dst (scalar, 128 threads)
- Pass 2: Intra-chunk coopmat GEMM (full chunks) or scalar fallback
  (partial last chunk). 3 barriers total, 62.7KB shared memory.

Pipeline registered but not yet dispatched (threshold remains disabled).
Test tolerance bumped to 5e-3 for n_seq_tokens≥64 to account for f16
intermediate precision in the coopmat path.

16/16 backend tests pass.
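The causal decay mask built in phase 2a can be modeled roughly in numpy. This is a toy sketch: the exact decay form is an assumption, with gcum taken as a per-token cumulative log-decay and entry (i, j) scaling attn by exp(gcum[i] - gcum[j]) for j <= i (the real kernel additionally clamps before the f16 store).

```python
import numpy as np

def decayed_causal(attn: np.ndarray, gcum: np.ndarray) -> np.ndarray:
    # apply the causal decay mask: entries above the diagonal stay zero
    n = attn.shape[0]
    out = np.zeros_like(attn)
    for i in range(n):
        for j in range(i + 1):
            out[i, j] = attn[i, j] * np.exp(gcum[i] - gcum[j])
    return out
```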
Lower GDN_CHUNK_THRESHOLD from UINT32_MAX to 2 and prefer the coopmat
output pipeline (cm1) when available, falling back to the scalar variant.

PP-512: ~206 → ~210 t/s on Radeon 890M (RDNA3.5).
Comprehensive documentation for PR ggml-org#20377 covering architecture,
benchmarks, PPL validation, per-kernel timing, and scaling analysis.
Includes side-by-side autoregressive vs chunked comparison on 890M.
Merge the inter-chunk state propagation and output computation into a
single dispatch, reducing the chunked pipeline from 3 dispatches to 2.

State lives in registers across the sequential chunk loop. vnew is
computed in-kernel and passed to the coopmat GEMM via shared memory
(f16, packed with subgroup shuffles). This eliminates the VNew scratch
buffer (wu_size) and H_snapshots buffer (h_size) — ~786KB/head/seq
saved for PP-512.

Architecture per chunk:
  Step 1: Load K, Q, gcum → shared (all 256 threads)
  Step 2: Q@K^T coopmat → sh_attn (all 256 threads)
  Step 3: Decay mask + O_inter = Q@state → dst (parallel)
  Step 4: vnew = U - W@state → sh_kv (128 threads + k_gated assist)
  Step 5: O_intra = A_decayed @ vnew coopmat GEMM → dst
  Step 6: state = exp(decay) * state + delta

Shared memory: 63,744 / 65,536 bytes. 16/16 backend tests pass.
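The fused per-chunk loop (steps 3-6 above) can be sketched as a toy numpy model. Shapes and the state-update "delta" term (taken here as K^T @ vnew) are assumptions distilled from the commit message, not the shader's actual code.

```python
import numpy as np

def process_chunk(Q, K, U, W, A_decayed, state, decay):
    O_inter = Q @ state                          # step 3: inter-chunk contribution
    vnew = U - W @ state                         # step 4: computed in-kernel
    O_intra = A_decayed @ vnew                   # step 5: the coopmat GEMM
    state = np.exp(decay) * state + K.T @ vnew   # step 6: state update (delta assumed)
    return O_inter + O_intra, state
```

Because state lives in registers across this loop, the VNew and H_snapshots scratch buffers described above are no longer needed.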
This reverts commit 08c355c01f3a298ef943216d4c55367a1c967286.
PR ggml-org#20443 removed redundant state transposes from the graph and updated
the autoregressive shader to use col*S_V+i (coalesced) instead of
i*S_V+col (strided). The chunked inter kernel was not updated, causing
uncoalesced state reads and a ~8% PP regression.

Fix state_in load and final_out write to match the new layout.
h_snapshots (h_out/h_in) are internal scratch and keep their existing
layout since inter and output kernels agree.

PP-512: 202 → 218 t/s. 16/16 tests pass.
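The layout change from #20443 comes down to which index varies fastest in memory. A small sketch of the two offset formulas (S_V is the state stride, a name from the commit; the value here is illustrative):

```python
S_V = 128  # illustrative stride value

def offset_coalesced(i, col):
    # consecutive threads (consecutive i) touch consecutive addresses
    return col * S_V + i

def offset_strided(i, col):
    # consecutive threads are S_V elements apart: uncoalesced reads
    return i * S_V + col
```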
@ProgenyAlpha
Contributor Author

I mostly want to understand what kind of use case/benchmark you're accelerating on so I can see how much theoretical upside there is.

The autoregressive kernel dispatches one workgroup per attention head. For Qwen3-Next (n_head_kv=2), that's 2 workgroups per GDN layer.

On my 890M (16 CU), GDN is ~8% of PP-512 time. Everything else (MLP, matmuls, norms) saturates all 16 CUs. Chunked doesn't show any improvement here because there's nothing for the iGPU to give.

On a 7900 XTX (96 CU, ~960 GB/s), the non-GDN ops scale with both CU count and bandwidth. Roughly 10x faster than my shared DDR5. The GDN op also gets faster from bandwidth (~3x), but it doesn't scale with CU count, still 2 workgroups, 94 CUs idle.

Dirty math (Amdahl's law, all estimates):

| | 890M (16 CU) | 7900 XTX (96 CU, est.) |
| --- | --- | --- |
| Non-GDN | 92 units | ~9 units |
| GDN (auto) | 8 units | ~3 units (bandwidth helps, CU count doesn't) |
| Total | 100 | ~12 |
| GDN % of pipeline | 8% | ~25% |

GDN's share grows from ~8% to ~25% of the pipeline. Chunked dispatches 16 workgroups instead of 2 for PP-512. If chunking allowed something like a ~4× improvement on the GDN portion, the rough Amdahl math would put total PP around ~9.75 vs ~12 (~19%). Obviously that depends heavily on whether the kernel actually scales that way.

These are rough numbers. The bandwidth scaling on the GDN op, the actual compute vs memory bound split, the dispatch overhead of 3 stages, all of that needs real profiling data to pin down.

Based on this rough Amdahl model, GDN's relative share could grow on larger GPUs where the rest of the pipeline scales with CU count but the autoregressive kernel remains limited to a small number of workgroups. I can't prove the exact crossover point locally on 16 CUs, but the theoretical upside on larger GPUs makes it seem worth exploring.
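The dirty math above reduces to a few lines of arithmetic. All unit times are the estimates from the comment, and the 4x GDN speedup is hypothetical:

```python
# Estimated 7900 XTX time split, in the comment's arbitrary units
non_gdn = 9.0
gdn = 3.0
gdn_speedup = 4.0  # hypothetical chunked gain on the GDN portion

total_before = non_gdn + gdn                # ~12
total_after = non_gdn + gdn / gdn_speedup   # ~9.75
gain = 1.0 - total_after / total_before     # ~0.19, i.e. "~19%"
```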

Remove verbose algorithm comments, section dividers, stale inline
constant annotations, and unused extensions. Match llama.cpp codebase
style (minimal comments, no section decorators).

No functional changes. 16/16 tests pass.
Load both s_w and s_kg before the first barrier instead of using
separate barriers for each. Reduces per-token barriers from 3 to 2,
eliminating 64 barriers per chunk.

GDN per-op: 6818 → 5205 µs (-23.6%). 16/16 tests pass.
Remove unnecessary barrier after A-matrix dot product writes. Each
thread writes only to its own row; s_A isn't read cross-thread until
forward substitution. Cuts A-matrix barriers from 128 to 65 (one
per broadcast + one before forward sub).

Pad s_A stride from 64 to 65 to eliminate bank conflicts in the W/U
accumulation phase where all active threads read A(tid, j) with the
same j value.

GDN per-op: 5205 → 5136 µs. Combined with inter fusion: 6818 → 5136 µs
(-24.7%). 16/16 tests pass.
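Why padding the s_A stride from 64 to 65 helps can be shown with a small bank-index model. Assuming 32 shared-memory banks (typical, but an assumption for this sketch): with stride 64, every thread reading A(tid, j) at the same j lands in the same bank.

```python
BANKS = 32  # assumed bank count

def bank(row, col, stride):
    # which shared-memory bank a given element maps to
    return (row * stride + col) % BANKS

j = 7  # all active threads read the same column j
banks_stride64 = {bank(t, j, 64) for t in range(32)}  # one bank: 32-way conflict
banks_stride65 = {bank(t, j, 65) for t in range(32)}  # spread over all 32 banks
```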
@github-actions github-actions bot added the documentation label (Improvements or additions to documentation) Mar 15, 2026
Intra:
- Strip all section/inline comments to match codebase style
- Add [[unroll]] to fixed-bound loops (A-matrix zero, W/U tile init/write)
- Guard chunk_len==0 underflow on s_decay[chunk_len-1]

Inter:
- Strip final comment

No functional changes. 16/16 tests pass.
- Raise GDN_CHUNK_THRESHOLD from 2 to CHUNK_SIZE (64). Chunked path
  only activates when there's at least one full chunk. Below that,
  autoregressive is faster and the 3-dispatch overhead isn't justified.
- Add maxStorageBufferRange guard on scratch allocation. Falls back to
  autoregressive if the scratch buffers would exceed device limits.
- Fix inaccurate shared memory stride comment in cm1 output kernel.

16/16 tests pass.
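The two activation guards in this commit can be sketched as follows. This is a hypothetical illustration: the function name and the scratch-size parameterization are assumptions; only CHUNK_SIZE = 64 and the maxStorageBufferRange fallback come from the commit text.

```python
CHUNK_SIZE = 64  # GDN_CHUNK_THRESHOLD now equals one full chunk

def gdn_use_chunked(n_tokens: int, scratch_bytes: int,
                    max_storage_buffer_range: int) -> bool:
    if n_tokens < CHUNK_SIZE:
        return False  # no full chunk: autoregressive is faster
    if scratch_bytes > max_storage_buffer_range:
        return False  # scratch would exceed the device limit
    return True
```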
@ProgenyAlpha
Contributor Author

ProgenyAlpha commented Mar 15, 2026

Fresh data after several rounds of optimization.

I spent the last couple of days trying to make the chunked kernel more efficient and reduce the per-op GDN time. Today I think I've finally broken even, reducing it from 6818 µs to 5136 µs (-24.7%) after fusing broadcasts in the inter kernel and removing unnecessary barriers in the intra kernel. Chunked is now as efficient as autoregressive on 16 CUs across both models tested.

Per-kernel breakdown after optimization (890M 16 CU, Qwen3-Next Q4_K_M, PP-512)
  • Intra (WY decomposition): 42%
  • Inter (state propagation): 27%
  • Output (coopmat GEMM): 31%
Throughput benchmarks (back-to-back, same build)

Qwen3-Next REAM Q4_K_M (36 GDN layers, n_head_kv=2):

| Test | Autoregressive | Chunked coopmat |
| --- | --- | --- |
| PP-512 | 226 t/s | 223 t/s |
| PP-1024 | 226 t/s | 232 t/s |
| PP-2048 | 224 t/s | 226 t/s |
| PP-4096 | 215 t/s | 221 t/s |
| TG-128 | 20.4 t/s | 21.3 t/s |

Qwen3.5-35B-A3B Q4_K_M (60 GDN layers, n_head_kv=2):

| Test | Autoregressive | Chunked coopmat |
| --- | --- | --- |
| PP-512 | 320 t/s | 323 t/s |
| PP-2048 | 318 t/s | 315 t/s |
| PP-4096 | 308 t/s | 308 t/s |
| TG-128 | 20.6 t/s | 21.0 t/s |
PPL validation (back-to-back, Qwen3-Next, WikiText-2):

| Context | Master (auto) | Chunked coopmat |
| --- | --- | --- |
| 512 | 15.79 ± 1.03 | 15.81 ± 1.03 |
| 4096 | 9.80 ± 0.21 | 9.84 ± 0.21 |

Lossless.

What changed since last push
  • Inter kernel: fused w + k_gated broadcasts into single barrier (3→2 per token)
  • Intra kernel: removed unnecessary post-dot-product barrier, padded s_A stride 64→65 for bank conflicts
  • Fixed chunked inter kernel state layout to match #20443 (graph: remove redundant GDN state transposes)

I think I'm going to reach out for testers on higher CU count hardware (8060S, 7900 XTX, 9070) to see how the coopmat GEMMs scale.

@github-actions github-actions bot added the script label (Script related) Mar 15, 2026
@lemmi

lemmi commented Mar 15, 2026

Strix Halo (8060S, 40 CUs). TG is basically unaffected; PP is taking a measurable hit:

pp2048:

| model | n_ubatch | master t/s | pr20377 t/s |
| --- | --- | --- | --- |
| qwen3next 80B.A3B Q4_K - Medium | 256 | 462.69 ± 5.90 | 438.23 ± 5.35 |
| qwen3next 80B.A3B Q4_K - Medium | 512 | 579.91 ± 5.42 | 544.61 ± 5.62 |
| qwen3next 80B.A3B Q4_K - Medium | 1024 | 657.74 ± 4.32 | 588.84 ± 3.58 |
| qwen3next 80B.A3B Q4_K - Medium | 2048 | 564.90 ± 5.59 | 536.66 ± 10.92 |
| qwen35moe 35B.A3B Q8_0 | 256 | 715.88 ± 13.33 | 675.13 ± 11.41 |
| qwen35moe 35B.A3B Q8_0 | 512 | 897.60 ± 4.59 | 829.42 ± 6.89 |
| qwen35moe 35B.A3B Q8_0 | 1024 | 918.49 ± 1.39 | 873.19 ± 4.74 |
| qwen35moe 35B.A3B Q8_0 | 2048 | 842.82 ± 7.75 | 738.21 ± 14.07 |

@ProgenyAlpha
Contributor Author

@lemmi I'm finalizing a test script. Do you know which commit you ran for the above results?

@lemmi

lemmi commented Mar 15, 2026

@ProgenyAlpha 88c0296

@randomm

randomm commented Mar 15, 2026

Benchmark results from AMD Strix Halo (Radeon 8060S, RDNA 3.5, 40 CUs, 128 GB unified LPDDR5X-8000).

Baseline: pre-built toolbox binary (b8334, includes #20334 + #20340)
PR build: native build from this branch (c671565), -ngl 999 -fa 1 -mmp 0 -p 512 -n 128

Note: baseline is a toolbox container build, PR is compiled natively — non-GDN regressions likely reflect compiler/flag differences rather than PR changes. Will follow up with a native master build for an apples-to-apples comparison.

Qwen3.5-122B-A10B (MoE + GDN, Q3_K_XL, 53 GiB)

| Build | pp512 (t/s) | tg128 (t/s) |
| --- | --- | --- |
| Baseline (b8334) | 291.49 ± 5.03 | 23.69 ± 0.62 |
| PR #20377 | 302.86 ± 1.66 | 24.17 ± 0.01 |
| Delta | +3.9% | +2.0% |

GPT-OSS-120B (standard MoE, no GDN — control, Q4_K_XL, 59 GiB)

| Build | pp512 (t/s) | tg128 (t/s) |
| --- | --- | --- |
| Baseline (b8334) | 649.21 ± 5.35 | 60.15 ± 0.04 |
| PR #20377 | 603.41 ± 2.33 | 55.31 ± 0.20 |
| Delta | -7.1% | -8.1% |

Nemotron-3-Super-120B (Mamba-2 + MoE, Q4_K_XL, 78 GiB)

| Build | pp512 (t/s) | tg128 (t/s) |
| --- | --- | --- |
| Baseline (b8334) | 199.00 ± 0.61 | 13.53 ± 0.09 |
| PR #20377 | 189.39 ± 0.92 | 13.12 ± 0.11 |
| Delta | -4.8% | -3.0% |

Summary

  • GDN model (Qwen3.5) shows a modest improvement from the broadcast fix and infrastructure changes.
  • Chunked dispatch is disabled (GDN_CHUNK_THRESHOLD = UINT32_MAX), so the chunked shaders aren't active yet.
  • Non-GDN models show small regressions vs the toolbox baseline — likely a build environment difference, not a PR regression.

Happy to test with the chunked dispatch enabled or run GGML_VK_PERF_LOGGER profiles if that would help.

@randomm

randomm commented Mar 15, 2026

Follow-up with native master build (89d0aec, b8357) for an apples-to-apples comparison. Same hardware (8060S, 40 CUs), same flags (-ngl 999 -fa 1 -mmp 0 -p 512 -n 128 -r 3).

Qwen3.5-122B-A10B (MoE + GDN, Q3_K_XL, 53 GiB)

| Build | pp512 (t/s) | tg128 (t/s) |
| --- | --- | --- |
| Master (89d0aec) | 318.18 ± 1.42 | 24.45 ± 0.02 |
| PR #20377 (c671565) | 302.86 ± 1.66 | 24.17 ± 0.01 |
| Delta | -4.8% | -1.1% |

GPT-OSS-120B (standard MoE, no GDN, Q4_K_XL, 59 GiB)

| Build | pp512 (t/s) | tg128 (t/s) |
| --- | --- | --- |
| Master (89d0aec) | 636.04 ± 1.74 | 59.37 ± 0.05 |
| PR #20377 (c671565) | 603.41 ± 2.33 | 55.31 ± 0.20 |
| Delta | -5.1% | -6.8% |

Nemotron-3-Super-120B (Mamba-2 + MoE, Q4_K_XL, 78 GiB)

| Build | pp512 (t/s) | tg128 (t/s) |
| --- | --- | --- |
| Master (89d0aec) | 199.71 ± 0.95 | 13.82 ± 0.05 |
| PR #20377 (c671565) | 189.39 ± 0.92 | 13.12 ± 0.11 |
| Delta | -5.2% | -5.1% |

Analysis

With a native-to-native comparison, PR #20377 shows a ~5% regression across all models on Strix Halo (40 CUs), including non-GDN models. This suggests the regression isn't GDN-specific — it might be in shared shader infrastructure or pipeline setup code.

The earlier toolbox-vs-PR comparison masked this by attributing the gap to build environment differences, but now we can see it's a real regression on this hardware. Worth investigating whether it's specific to RDNA 3.5 / 40 CU configs.

@0cc4m
Contributor

0cc4m commented Mar 15, 2026

That suggests you ran the tests back-to-back, so the PR run was throttling due to thermal constraints.
