
ggml: add GATED_DELTA_NET op #19504

Merged
am17an merged 5 commits into ggml-org:master from am17an:gated_delta_net on Mar 7, 2026

Conversation

@am17an (Contributor) commented Feb 11, 2026

Add a CPU/CUDA implementation of GATED_DELTA_NET, used in qwen3next and a number of upcoming attention models. This is a basic vector implementation, not the chunked one, although it works for n_tokens > 1 as a reference implementation. I tested it against build_delta_net_autoregressive and the results were good. I plan to add the chunked implementation for CPU and CUDA.
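For reference, the token-by-token recurrence this op computes can be sketched in NumPy as follows (a minimal sketch, not the ggml op's actual signature; the names `q`, `k`, `v`, `g`, `beta` and the shapes are illustrative assumptions):

```python
import numpy as np

def gated_delta_net_ref(q, k, v, g, beta, S):
    """Autoregressive (vector) reference for the gated delta rule.

    q, k: (T, d_k); v: (T, d_v); g, beta: (T,) per-token scalars;
    S: (d_v, d_k) recurrent state, updated in place token by token.
    """
    T = q.shape[0]
    o = np.zeros((T, v.shape[1]))
    for t in range(T):
        S = S * g[t]                            # gated decay of the state
        err = v[t] - S @ k[t]                   # delta-rule prediction error
        S = S + beta[t] * np.outer(err, k[t])   # rank-1 state update
        o[t] = S @ q[t]                         # read-out for this token
    return o, S
```

The chunked variant solves the same recurrence but processes blocks of tokens with matrix products; this loop is the simple form that doubles as a correctness reference.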

master:

| model | size | params | backend | threads | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CPU | 16 | 1 | tg32 | 4.77 ± 0.03 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CPU | 16 | 1 | tg32 @ d1024 | 4.55 ± 0.13 |

sched_reserve: graph nodes = 14990 (with bs=512), 6242 (with bs=1)

`ggml_op_gated_delta_net` added to the qwen3next graph (not added in this PR):

| model | size | params | backend | threads | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| qwen35moe ?B Q4_K - Small | 18.55 GiB | 34.66 B | CPU | 16 | 1 | tg32 | 11.08 ± 0.20 |
| qwen35moe ?B Q4_K - Small | 18.55 GiB | 34.66 B | CPU | 16 | 1 | tg32 @ d1024 | 11.21 ± 0.07 |

sched_reserve: graph nodes = 14990 (with bs=512), 5342 (with bs=1)

@am17an am17an requested a review from ggerganov as a code owner February 11, 2026 07:09
@am17an am17an requested a review from pwilkin February 11, 2026 07:09
@github-actions bot added the `testing` (Everything test related) and `ggml` (changes relating to the ggml tensor library for machine learning) labels Feb 11, 2026
@ggerganov (Member) commented Feb 11, 2026

I think it is too early to implement the dedicated delta net ops. There are still many things to optimize in the existing implementation (you can keep track of my progress in #19375). After that we have to consolidate the KDA version of the delta net (#18792). Btw, the L2 norm should not be part of this op - fixed in my branch. Also, I'm not sure how to handle the 2 variants of this operator (autoregressive and chunked).

So I think we can experiment with a dedicated op in a branch, but merging this in master will likely take time.

@am17an (Contributor, author) commented Feb 11, 2026

@ggerganov I defer to your judgement; my thinking was that qwen3.5 is already a major model series, so even if the op is just for that model, it makes sense.

For KDA, AFAIK the gate is a matrix, so it will just be another dot product instead of a scale. For chunked vs autoregressive, we have the vec FA path for CPU, which now serves as a reference kernel. I was thinking it would be the same here: the autoregressive kernel remains the simple kernel while chunking is the optimization; both solve the same recurrence.
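The difference between the two gate variants can be illustrated with a single-token step (a hypothetical NumPy sketch; `gate`, the shapes, and the broadcasting layout are assumptions for illustration, not the actual KDA interface):

```python
import numpy as np

def delta_step(S, q, k, v, gate, beta):
    """One token of the delta rule. `gate` is either a scalar (GDA-style)
    or a per-channel vector of shape (d_k,) (KDA-style); NumPy broadcasting
    makes the decay a uniform scale in the first case and an elementwise
    per-key-channel product in the second."""
    S = S * gate                        # scalar: scale; vector: per-channel decay
    err = v - S @ k                     # prediction error
    S = S + beta * np.outer(err, k)     # rank-1 update
    return S @ q, S
```

With `gate = g * np.ones(d_k)` the vector path reduces exactly to the scalar path, which is a convenient sanity check when consolidating the two variants into one op.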

@ggerganov (Member)

Ok, let's prototype a branch that also has this op together with the CUDA implementation rebased on #19375. I will then add the Metal version of the kernel and from there we can consider a quicker merge if things are looking good. Also, want to see if having this op will allow the CUDA graphs to be more easily enabled.

@pwilkin (Contributor) commented Feb 11, 2026

So this is basically what the Transformers implementations have as the "recurrent" implementation, right? No chunking, just iterating over tokens.

@am17an (Contributor, author) commented Feb 11, 2026

@pwilkin Yes, just computing the recurrence token by token.

@ggerganov (Member) commented Feb 11, 2026

Btw, should also consider small batch sizes larger than 1 to be handled by this operator too. I'm not sure where the break-even point would be, but I imagine that processing a few tokens auto-regressively (i.e. more than 1 and less than ~16) would be more efficient compared to the chunking path. Also don't forget that dim 3 will handle separate sequences - though from a quick look, this implementation already accounts for that.

@am17an (Contributor, author) commented Feb 11, 2026

> Btw, should also consider small batch sizes larger than 1 to be handled by this operator too. I'm not sure where the break-even point would be, but I imagine that processing a few tokens auto-regressively (i.e. more than 1 and less than ~16) would be more efficient compared to the chunking path.

Yes, for a small number of tokens we can just run a loop even in CUDA. I have not looked into the chunked impl yet, but I will invest some time in finding the break-even point.

> Also don't forget that dim 3 will handle separate sequences - though from a quick look, this implementation already accounts for that.

I think this should be fine - the work is split among dim1 * dim3 (heads * sequences).
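The independence behind that split can be sketched as follows (a hypothetical NumPy layout, not the ggml tensor layout: each (sequence, head) pair carries its own state, so the two outer loops are freely parallelizable and only the token loop is sequential):

```python
import numpy as np

def gdn_per_seq_head(q, k, v, g, beta, S):
    """q, k: (n_seq, n_head, T, d_k); v: (n_seq, n_head, T, d_v);
    g, beta: (n_seq, n_head, T); S: (n_seq, n_head, d_v, d_k).
    The recurrences never mix across (seq, head), so work can be
    distributed among n_head * n_seq independent workers."""
    n_seq, n_head, T = g.shape
    o = np.zeros_like(v)
    for s in range(n_seq):           # dim 3: separate sequences
        for h in range(n_head):      # dim 1: heads
            for t in range(T):       # only this loop is sequential
                St = S[s, h] * g[s, h, t]
                err = v[s, h, t] - St @ k[s, h, t]
                S[s, h] = St + beta[s, h, t] * np.outer(err, k[s, h, t])
                o[s, h, t] = S[s, h] @ q[s, h, t]
    return o, S
```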

@ymcki (Contributor) commented Feb 11, 2026

Great performance gain for inference. Looking forward to seeing your implementation done for the major backends.

If you plan to do the chunked version as well, it would be great if it were based on the block implementation done in fla:

https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/chunk_intra_token_parallel.py
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/chunk_intra.py

@pwilkin (Contributor) left a comment

Looks clean to me. Are you planning on doing the chunked version here as well, or in a separate op / PR?

@ggerganov (Member)

Converted to draft since I am not sure my comment was clear: #19504 (comment). First we will prototype a new branch, and after that we will consider adding the new op.

@pwilkin (Contributor) commented Feb 11, 2026

Should we use this PR or will you create a dedicated branch?

@am17an (Contributor, author) commented Feb 11, 2026

@ggerganov I removed the norm and also added the autoregressive CUDA op in 01eda69; it passes test-backend-ops. I have not yet rebased on #19375.

@github-actions bot added the `model` (Model specific) and `Nvidia GPU` (Issues specific to Nvidia GPUs) labels Feb 11, 2026
@ggerganov (Member) commented Feb 11, 2026

Just a heads up, I will be rebasing the #19375 branch from time to time. Hope it's not a big issue - just always put your commits on top. I'm hoping to merge in a day or two.

@am17an (Contributor, author) commented Feb 11, 2026

I did a quick perf test with this PR + #19375 + replacing the autoregressive path for qwen3next with gated_delta_net, on a 5090.

master:

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 | 83.92 ± 0.39 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d1024 | 84.45 ± 0.36 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d2048 | 84.20 ± 0.61 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d4096 | 83.82 ± 0.56 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d8192 | 83.43 ± 1.73 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d16384 | 83.56 ± 0.47 |

PR:

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 | 105.95 ± 0.36 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d1024 | 105.05 ± 0.91 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d2048 | 105.33 ± 0.42 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d4096 | 105.10 ± 0.50 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d8192 | 98.13 ± 1.79 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d16384 | 97.22 ± 0.49 |

@ggerganov (Member)

For reference, what do you get with CUDA graphs forced enabled:

```diff
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index f3d8317e1..605cb3ed4 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2894,7 +2894,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
 #endif
         }
 
-        if (node->op == GGML_OP_ADD &&
+        if (false && node->op == GGML_OP_ADD &&
             node->src[1] && node->src[1]->ne[1] > 1 &&
             (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
             (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
```

@am17an (Contributor, author) commented Feb 11, 2026

With CUDA graphs force-enabled:

| model | size | params | backend | ngl | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 | 111.89 ± 2.48 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d1024 | 135.26 ± 6.93 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d2048 | 135.89 ± 4.95 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d4096 | 134.77 ± 4.67 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d8192 | 123.30 ± 6.07 |
| qwen3next 80B.A3B Q2_K - Medium | 27.12 GiB | 79.67 B | CUDA | 99 | 1 | tg32 @ d16384 | 121.22 ± 5.28 |

@am17an (Contributor, author) commented Mar 7, 2026

I actually see a huge difference in PP on CPU when just using the autoregressive kernel instead of the current one, i.e. using the fused op regardless of n_tokens. But I think I will optimize this later.

@am17an am17an merged commit c5a7788 into ggml-org:master Mar 7, 2026
76 of 78 checks passed
@am17an am17an deleted the gated_delta_net branch March 7, 2026 07:41
@jacekpoplawski (Contributor) commented Mar 7, 2026

Great speedup on tg (Qwen Next and Qwen 3.5)!

```
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
  Device 1: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
  Device 2: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
```

@jeffbolznv (Contributor)

Hi @ProgenyAlpha, just wanted to check whether you still plan to submit a PR for the vulkan backend support.

@CISC (Member) commented Mar 8, 2026

@am17an (Contributor, author) commented Mar 8, 2026

@CISC I'm not sure who maintains the MUSA backend, but it seems like a compiler bug.

arkavo-com added a commit to arkavo-ai/llama.cpp that referenced this pull request Mar 8, 2026
Add a fused Metal kernel for the gated delta net recurrence op
(ggml-org#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.

Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.

Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
  tg128: 170 -> 213 t/s (+25%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@ggerganov (Member)

@yeahdongcn PTAL at the MUSA issue above.

@am17an In the meantime we can change supports_op to return false for MUSA

@yeahdongcn (Collaborator)

> @yeahdongcn PTAL at the MUSA issue above.
>
> @am17an In the meantime we can change supports_op to return false for MUSA

No problem. I'll try a local build first and see if I should open an internal ticket. Thanks!

@ProgenyAlpha (Contributor)

> Hi @ProgenyAlpha, just wanted to check whether you still plan to submit a PR for the vulkan backend support.

I wasn't sure where the thread was going, so I wanted to let you guys cook and see how things unfolded before jumping back in. I'll rebase and work on that this week if I have time. Thanks for pinging me!

ggerganov pushed a commit that referenced this pull request Mar 10, 2026
bartowski1182 pushed a commit to bartowski1182/llama.cpp that referenced this pull request Mar 10, 2026
* ggml: add GATED_DELTA_NET op

* remove the transpose

* add KDA

* add qwen35 dense

* llama : check for fused gated delta net backend support

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
ggerganov added a commit that referenced this pull request Mar 11, 2026
* metal : add Metal backend for GGML_OP_GATED_DELTA_NET

Add a fused Metal kernel for the gated delta net recurrence op
(#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.

Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.

Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
  tg128: 170 -> 213 t/s (+25%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : validate contiguity of all input tensors in supports_op

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : add algorithm equivalence comment for GDA decay path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* cont : unslop + optimize

* cont : clean-up

---------

Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
ggerganov added a commit that referenced this pull request Mar 11, 2026
* llama : enable chunked fused GDN path

* models : avoid Q and K repeats when using fused GDA

* cont : fix comment

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cont : fix the fix

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* cont : fix

* metal : add GDN kernel (#20361)

* metal : add Metal backend for GGML_OP_GATED_DELTA_NET

Add a fused Metal kernel for the gated delta net recurrence op
(#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.

Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.

Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
  tg128: 170 -> 213 t/s (+25%)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : validate contiguity of all input tensors in supports_op

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* metal : add algorithm equivalence comment for GDA decay path

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* cont : unslop + optimize

* cont : clean-up

---------

Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* CUDA: AR gated delta net improvements (#20391)

* Add FastDiv to gated_delta_net_cuda

* Shard columns across warps

This reduces register pressure (avoids spill for S_v = 128) and gives
the warp-scheduler more CTAs to schedule (thus hiding data-access
latencies).

* Remove unneded include in gated_delta_net.cu

* Improve comments

* Apply code-formating

* Make sharding HIP-compatible

1. Use ggml_cuda_get_physical_warp_size() to determine warp size flexibly
2. Add test with partial warp to test sum reduction on CUDA

* Remove fastdiv_s64, as we can treat neqk1 and rq3 as uint32_t

* Rename variables

* Enable GDN also for prefill, move TODO for chunked_GDN

* Actually remove the TODO from 2068908

* Get warp size at runtime

warp_size is not known at compile time in hip host code.

* Don't expose ggml_cuda_get_physical_warp_size on host

---------

Co-authored-by: uvos <devnull@uvos.xyz>

* llama : refactor llm_build_delta_net_base API

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
Co-authored-by: Paul Flynn <paul@arkavo.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Oliver Simons <osimons@nvidia.com>
Co-authored-by: uvos <devnull@uvos.xyz>
ProgenyAlpha pushed a commit to ProgenyAlpha/llama.cpp that referenced this pull request Mar 12, 2026