
WIP: Qwen3Next #1266

Merged
ikawrakow merged 76 commits into main from ik/qwen3next on Feb 16, 2026

Conversation

@ikawrakow (Owner) commented Feb 13, 2026

Starting from PR #1251, this is WIP to integrate Qwen3Next.

Massive upgrade to CPU-only TG performance (7.2 -> 24 t/s on a Ryzen-3995WX); it seems to be working correctly.

CPU batch processing is still not fully correct; I haven't yet found where the issue is. After various optimizations, CPU PP-512 performance went up from 160 t/s on PR #1251 to 240 t/s in this PR.

Zero changes to CUDA compared to #1251, so CUDA PP is still massively slower than llama.cpp's.

Marking as draft for now.

Update: Spent a few hours fixing the bad CUDA performance. PP is now better than mainline by about 25%; TG is about 5% lower than mainline. So I'll remove the draft label, but I think it needs some more work.

Update 2: TG on CUDA should now be on par with mainline.

Update 3: PP on the CPU is now fixed, so the PR is fully functional.

While I was fixing this PR, PR-19375 in mainline was merged. That PR optimizes the Qwen3-Next compute graph, achieving a non-negligible performance improvement. Nothing along these lines has happened in ik_llama.cpp yet, so it is interesting to compare performance. The following uses llama.cpp build 27b93cbd1 (8064). Note that sweep-bench does not work correctly for a recurrent model, so this is just a plain llama-bench comparison. The CUDA runs are with full offload on 2x3090; the CPU-only run is on a Ryzen-3995WX.

| model | backend | test | t/s (llama.cpp) | t/s (ik_llama.cpp) | Speedup |
| --- | --- | --- | ---: | ---: | ---: |
| qwen3next 80B.A3B IQ4_XS | CPU | pp512 | 98.13 ± 1.27 | 239.06 ± 6.11 | 2.436 |
| qwen3next 80B.A3B IQ4_XS | CPU | tg128 | 10.17 ± 0.38 | 23.00 ± 0.04 | 2.262 |
| qwen3next 80B.A3B IQ4_XS | CUDA | pp2048 | 2124.74 ± 2.06 | 2269.94 ± 33.85 | 1.068 |
| qwen3next 80B.A3B IQ4_XS | CUDA | tg128 | 103.46 ± 0.42 | 88.20 ± 0.01 | 0.853 |

I think this is not too bad, considering that mainline developers have been optimizing Qwen3-Next since last November, while this is the very first ik_llama.cpp PR for the model. It was generated by @YurkoHoshko with the help of Codex-5.3, and I started looking at it more seriously just 2 days ago.

yurko and others added 30 commits February 6, 2026 12:13
It was single-threaded and was taking ~25% of the computation time
during TG. It is now down to 2%.

Strangely enough, I measure 13.6 t/s with llama-bench, but if I
let the model give me an actual response with llama-cli, I get close
to 17 t/s.
For Qwen3Next there is a scale op on a largish tensor (548k elements)
that has a single row for TG, so it was done in a single thread.
We now simply use blocks of 1024 elements.
@YurkoHoshko mentioned this pull request Feb 14, 2026
@ChicoPinto70 commented Feb 14, 2026

THANK YOU, Ikawrakow and YurkoHoshko!!! An 80B model running at 12.7 t/s, CPU-only on my dual Xeon!!!!! Christmas came early this year!!! 👍

@ProgenyAlpha

CPU Perplexity Regression — DeltaNet (Qwen3-Coder-Next Q4_K_M)

We've been digging into the CPU backend performance on DeltaNet models and found a significant PPL gap we can't fully explain. Sharing our findings in case you can spot what we're missing — we're still learning the IK fork's ggml internals.

Setup: Qwen3-Coder-Next Q4_K_M, wikitext-2-raw, 2 chunks (512 tokens)

| Build | Backend | Threads | PPL |
| --- | --- | ---: | ---: |
| ggml-org mainline | CPU | 12 | 4.85 |
| ggml-org mainline | Vulkan | – | 4.80 |
| IK (9a34efa) | Vulkan | – | 4.81 |
| IK (9a34efa) | CPU | 1 | 6.79 |
| IK (9a34efa) | CPU | 12 | 7.30 |
| IK (e1fa9e2, base merge) | CPU | 12 | 7.30 |

Two issues we see:

  1. dup/cont threading corruption — At -t 1, the CONT source and destination tensor means match exactly (-1141.27). At -t 12, the destination mean collapses to ~0.003. This accounts for ~0.55 PPL.

  2. Baseline CPU regression (~2 PPL) — Even at -t 1 with IQK disabled (PPL 6.75), there's a ~2 PPL gap vs mainline CPU. This is present at the base merge commit (e1fa9e2), before any of the optimization chain commits. Vulkan gets correct PPL on the same graph, so the graph construction looks right — something in the CPU kernel path diverges.

We've eliminated IQK mul_mat, cont fusion, mul broadcast, repeat fast paths, and the PR-specific ggml.c changes (sub revert, repeat ne[0]==1, fused delta-net op removal) as causes. The regression appears to come from pre-existing IK main branch ggml.c differences vs ggml-org.

Is this a known issue with DeltaNet on CPU, or can you point us toward what we should be looking at? Happy to run any tests you'd find useful — we have the model and benchmark set up on a dedicated machine. Don't want to keep chasing this if it's already on your radar or if we're fundamentally misunderstanding something.

Needs a non-contiguous variant of sum_rows.
On the CPU this gave a 30+% improvement in TG performance;
on CUDA it is a disappointing 6-7%. I guess this is because
Georgi's cont CPU implementation was so bad that skipping
it made such a big difference.
@ikawrakow (Owner, Author) commented

@ProgenyAlpha Yes, I know there is still something wrong in the chunked delta-net implementation on the CPU. I said that in the PR description, no?

But the PR is still usable because:

  • Everything works fine on CUDA
  • Many people will use it in hybrid GPU/CPU mode where the attention part is always computed on the GPU, so that's fully correct too
  • Even if prompt processing is not 100% correct when running CPU-only, in my admittedly limited testing it still produces relevant and fully coherent responses when used CPU-only.

I have checked all ops in the chunked delta-net part, and I don't immediately see anything wrong. I guess, I need to invest more time and go over all ops one-by-one.

@ikawrakow ikawrakow marked this pull request as ready for review February 15, 2026 17:49
@YurkoHoshko (Contributor) commented Feb 15, 2026

Impressive work, @ikawrakow - thank you so much for picking this up and applying your magic to this semi-working PR of mine! This is a great unlock (I believe it can be a base for Kimi Linear, and the future Qwen 3.5 also has a very similar architecture) and a demonstration of how capable ik_llama.cpp is.

From my side, to give credit where credit is due: @pwilkin did a lot of work to build the initial implementation for llama.cpp, and I don't think Codex would've been able to do so much without a reference implementation in llama.cpp.

@ProgenyAlpha commented Feb 15, 2026

@ikawrakow @YurkoHoshko Thanks! We weren't trying to say you weren't aware; we just wanted to share our findings after our deep dive. No expectation of a fix at all: we just wanted the forensics on record, since we'd already spent a full day tracing it down to ggml_compute_forward_dup on transposed DeltaNet tensors. Hopefully that saves some time whenever you get around to it.

Honestly, massive thanks to both of you for making this happen. Running an 80B MoE with DeltaNet attention on a Radeon 890M is something we didn't think was possible a month ago. Vulkan performance has been rock solid for us.

Really excited to see where this goes with Kimi Linear and Qwen 3.5. If there's anything we can help test or benchmark on the AMD/Vulkan side, we're happy to put cycles into it. This project deserves more contributors and we'd love to give back where we can.

@ikawrakow ikawrakow merged commit e30198a into main Feb 16, 2026