[WIP] Qwen 3 Next experiment #1251

Closed

YurkoHoshko wants to merge 52 commits into ikawrakow:main from YurkoHoshko:main
Conversation

@YurkoHoshko
Contributor

@YurkoHoshko YurkoHoshko commented Feb 7, 2026

Disclaimer

This PR was fully AI-generated as a test of Codex 5.3 capabilities. It is by no means an optimized version that follows ik_llama.cpp best practices, and it is not meant to be a contribution in its current shape. Opening this PR per request from ikawrakow on the issue (#1229 (comment)); it is meant to serve mostly as a reference for the affected code paths.

Testing methodology
Mainline was cross-referenced throughout development, with all ops tested and compared for correctness. Perplexity also seemed within the norm.

I ran this model with OpenCode / Pi agents and it seems to work, tool calls are good.

Original bench

**Benchmark**

Command: CUDA_VISIBLE_DEVICES=0,1 build/bin/llama-sweep-bench -m /models/qwen-3-coder-next-mxfp4.gguf -c 8192 -t 8 -fa on --jinja -ngl 999 --n-cpu-moe 25 -rtr --temp 1 --top-p 0.95 --top-k 40 --min-p 0.01

Mainline:

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|   512 |    128 |      0 |    2.822 |   181.41 |    4.496 |    28.47 |
|   512 |    128 |    512 |    2.702 |   189.47 |    4.459 |    28.71 |
|   512 |    128 |   1024 |    2.650 |   193.19 |    4.469 |    28.64 |
|   512 |    128 |   1536 |    2.639 |   194.00 |    4.525 |    28.29 |
|   512 |    128 |   2048 |    2.653 |   192.99 |    4.784 |    26.76 |
|   512 |    128 |   2560 |    2.623 |   195.17 |    4.501 |    28.44 |
|   512 |    128 |   3072 |    2.583 |   198.24 |    4.521 |    28.31 |
|   512 |    128 |   3584 |    2.608 |   196.34 |    4.653 |    27.51 |
|   512 |    128 |   4096 |    2.684 |   190.76 |    4.538 |    28.20 |
|   512 |    128 |   4608 |    2.574 |   198.88 |    4.547 |    28.15 |
|   512 |    128 |   5120 |    2.696 |   189.91 |    4.542 |    28.18 |
|   512 |    128 |   5632 |    2.680 |   191.07 |    4.536 |    28.22 |
|   512 |    128 |   6144 |    2.640 |   193.97 |    4.562 |    28.06 |
|   512 |    128 |   6656 |    2.567 |   199.47 |    4.568 |    28.02 |
|   512 |    128 |   7168 |    2.571 |   199.14 |    4.712 |    27.17 |
|   512 |    128 |   7680 |    2.685 |   190.71 |    4.612 |    27.75 |

This PR:

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|   512 |    128 |      0 |    2.758 |   185.66 |    4.551 |    28.12 |
|   512 |    128 |    512 |    2.584 |   198.15 |    4.543 |    28.18 |
|   512 |    128 |   1024 |    2.599 |   197.00 |    4.541 |    28.19 |
|   512 |    128 |   1536 |    2.711 |   188.89 |    4.583 |    27.93 |
|   512 |    128 |   2048 |    2.624 |   195.09 |    4.676 |    27.37 |
|   512 |    128 |   2560 |    2.662 |   192.33 |    4.591 |    27.88 |
|   512 |    128 |   3072 |    2.689 |   190.38 |    4.614 |    27.74 |
|   512 |    128 |   3584 |    2.669 |   191.80 |    4.585 |    27.91 |
|   512 |    128 |   4096 |    2.644 |   193.65 |    4.610 |    27.77 |
|   512 |    128 |   4608 |    2.627 |   194.92 |    4.575 |    27.98 |
|   512 |    128 |   5120 |    2.661 |   192.37 |    4.574 |    27.99 |
|   512 |    128 |   5632 |    2.614 |   195.84 |    4.577 |    27.97 |
|   512 |    128 |   6144 |    2.704 |   189.33 |    4.606 |    27.79 |
|   512 |    128 |   6656 |    2.608 |   196.28 |    4.564 |    28.05 |
|   512 |    128 |   7168 |    2.680 |   191.04 |    4.663 |    27.45 |
|   512 |    128 |   7680 |    2.660 |   192.50 |    4.677 |    27.37 |

There are a few more PRs in mainline that might be interesting or useful.

Update from Feb 7th

Spent some more time throwing Codex at this problem / PR and moved over some more code - it seems to work, but I can't vouch for the changes because they are over my head :) Got some gains.

New benchmark: Qwen3Next, PP 16384 / TG 128 (`ik_llama.cpp` vs `llama.cpp`)

Date: 2026-02-08

  • Container: iktest2
  • Model: /models/qwen3-next-coder.gguf
  • Prompt processing: -p 16384
  • Token generation: -n 128
  • Batch settings: -b 3072 -ub 768
  • Threads: -t 8
  • Repetitions: -r 1
  • Mmap: -mmp 0

CUDA runs:

  • CUDA_VISIBLE_DEVICES=0
  • -fa 1 -ngl 999 --n-cpu-moe 47

CPU-only runs:

  • -fa 0 -ngl 0 --n-cpu-moe 0

Hardware note:

  • GPU0 (bench target): NVIDIA GeForce RTX 5060 Ti, 16311 MiB total (CUDA_VISIBLE_DEVICES=0 for CUDA runs).
  • GPU1 (not used for these runs): NVIDIA GeForce RTX 3060, 12288 MiB total.
  • Observed during active ik CUDA run (p=8192,b=2048,ub=512,n-cpu-moe=45): GPU0 memory used ~12074 MiB (~3775 MiB free), from nvidia-smi.

Results

| Build        | Backend  | PP 16384 (tok/s) | TG 128 (tok/s) |
|--------------|----------|------------------|----------------|
| ik_llama.cpp | CUDA     | 207.891304       | 27.263562      |
| llama.cpp    | CUDA     | 185.764649       | 24.145662      |
| ik_llama.cpp | CPU-only | 45.739881        | 12.172113      |
| llama.cpp    | CPU-only | 47.835420        | 6.991398       |

Relative (ik vs llama.cpp)

  • CUDA PP: +11.91%
  • CUDA TG: +12.91%
  • CPU PP: -4.38%
  • CPU TG: +74.10%
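The deltas above can be sanity-checked with a few lines of plain Python; the throughput numbers are copied from the results table, and the formula is simply the relative change of ik over llama.cpp:

```python
# Recompute the relative deltas as (ik / llama.cpp - 1) * 100.
# Throughput values are copied from the results table above.
pairs = {
    "CUDA PP": (207.891304, 185.764649),
    "CUDA TG": (27.263562, 24.145662),
    "CPU PP":  (45.739881, 47.835420),
    "CPU TG":  (12.172113, 6.991398),
}
deltas = {name: 100 * (ik / ml - 1) for name, (ik, ml) in pairs.items()}
for name, d in deltas.items():
    print(f"{name}: {d:+.2f}%")
```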

Additional CUDA rerun (requested lower n-cpu-moe ballpark)

Adjusted config:

  • -p 8192 -n 128 -b 2048 -ub 512 -t 8 -fa 1 -ngl 999 -mmp 0
  • single GPU: CUDA_VISIBLE_DEVICES=0

Fit checks on ik:

  • --n-cpu-moe 25 -> fail to load model
  • --n-cpu-moe 40 -> fail to create context
  • --n-cpu-moe 45 -> works

Working comparison at --n-cpu-moe 45:

| Build        | Backend | PP 8192 (tok/s) | TG 128 (tok/s) |
|--------------|---------|-----------------|----------------|
| ik_llama.cpp | CUDA    | 201.613283      | 24.884600      |
| llama.cpp    | CUDA    | 145.100895      | 24.595058      |

ik rerun with -rtr 1 at the same config (--n-cpu-moe 45):

| Build                 | Backend | PP 8192 (tok/s) | TG 128 (tok/s) |
|-----------------------|---------|-----------------|----------------|
| ik_llama.cpp (-rtr 1) | CUDA    | 232.340508      | 27.895722      |

@YurkoHoshko
Contributor Author
this doc was used solely by AI during development and I included it just to keep track of what was going on - please ignore.

@YurkoHoshko
Contributor Author
this was used to compare to mainline as coding agent was advancing - please ignore

@YurkoHoshko YurkoHoshko marked this pull request as draft February 7, 2026 20:31
@YurkoHoshko
Contributor Author

noticed some stability issues, investigating

@ProgenyAlpha

CPU fused DeltaNet threading bug — narrowed down

Hardware: AMD Ryzen AI 9 HX PRO 370 (12c/24t, AVX-512), 96GB DDR5, CPU-only, Q4_K GGUF via Ollama-converted blob.

I've been debugging the CPU-only fused path and can confirm @ikawrakow's finding that it's broken. Here's where it's broken:

The race condition is NOT inside ggml_compute_forward_delta_net_f32. It's in the surrounding ggml_cont_4d(ggml_permute(...)) ops that prepare tensors for the kernel.

Proof:

| Config | Result |
|--------|--------|
| `-t 1`, fused (any build) | Correct output |
| `-t 12 --no-fused-delta` | Correct output |
| `-t 12`, fused (default) | Garbage |
| `-t 12`, fused, kernel patched to `if (params->ith != 0) return;` (only thread 0 runs, all 32 heads) | Still garbage |

The last row is the key finding. Even with the kernel itself forced single-threaded, -t 12 still produces garbage — because the other 11 threads are still running the CONT ops that transform [S,H,T,B]→[S,T,H,B] for q, k, v, g, beta before the kernel executes. The kernel receives already-corrupted inputs.
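For readers unfamiliar with the layout shuffle involved, the [S,H,T,B]→[S,T,H,B] transform those CONT ops perform can be sketched in isolation. This is a NumPy stand-in with made-up sizes, not the ggml code; note that ggml lists dimensions fastest-first, so ggml shape [S,H,T,B] corresponds to NumPy shape (B,T,H,S):

```python
import numpy as np

# Hypothetical sizes: head dim S, heads H, tokens T, batch B.
S, H, T, B = 4, 3, 2, 1
x = np.arange(B * T * H * S, dtype=np.float32).reshape(B, T, H, S)

# Swap the H and T axes, then materialize a contiguous copy -- the
# single-threaded equivalent of ggml_cont_4d(ggml_permute(...)).
y = np.ascontiguousarray(x.transpose(0, 2, 1, 3))  # NumPy shape (B, H, T, S)

# Element (b=0, t=1, h=2, :) must land at (b=0, h=2, t=1, :).
assert (y[0, 2, 1] == x[0, 1, 2]).all()
```

The copy itself is trivially correct when one thread does it; the bug report is about how the multi-threaded version partitions this work.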

Additional verification:

  • Kernel math traced step-by-step (head 0, token 0) — numerically correct
  • No state_in/dst pointer aliasing (confirmed at runtime, DELTA_NET not in ggml_op_can_inplace)
  • Thread memory partitioning has no overlaps (each head writes to non-overlapping output/state regions)
  • GGML barriers confirmed solid (OpenMP barrier after every graph node)

The non-fused path avoids this entirely — it uses standard GGML ops (mul_mat, transpose, sum_rows) with no cont_4d(permute(...)) pattern. Copying the mainline implementation as @ikawrakow suggested would sidestep this bug completely.

Separate issue — Ollama GGUF compat: Ollama-converted Q3CN GGUFs omit the .bias suffix on ssm_dt tensors (blk.N.ssm_dt instead of blk.N.ssm_dt.bias). Both upstream llama.cpp and this PR fail to load them without patching the tensor loader to make ssm_dt.bias optional. This affects anyone pulling Q3CN via ollama pull.

Detailed notes and reproduction Dockerfiles: https://github.com/ProgenyAlpha/ik-deltanet-fix

@ikawrakow
Owner

More observations for the CPU implementation:

  • Codex has decided that the ik_llama.cpp CPU FA implementation is "unstable", and turns it off by default for Qwen3-Next. Removing this code and running with FA enabled fixes the PPL. Which means that Codex has broken the non-FA path somewhere. However, even if PPL values are normal, generation produces garbage (with default parameters).
  • Confirming the findings by @ProgenyAlpha: fused delta-net is broken on the CPU. Turning fused delta-net off fixes generation. It also has interesting effects on performance: PP increases to 162 t/s (up from 110 t/s with fused delta-net), but TG decreases to 7.25 t/s (which is inline with mainline's CPU TG performance on this system).
  • After observing that fused delta-net is broken, went back to check PPL with fused delta-net off. It is still broken without FA, so really need to fix the non-FA path on the CPU.
  • Haha, PPL with fused delta-net off is significantly higher than with fused delta-net on:

Fused delta-net on, FA on:

[1]3.0178,[2]2.8148,[3]3.1174,[4]3.6971,[5]3.9899,[6]3.6590,[7]3.3702,[8]3.4995,[9]3.7545,[10]3.9129,[11]4.1036,[12]4.3368,

Fused delta-net off, FA on:

[1]4.1737,[2]4.6478,[3]5.1281,[4]5.4141,[5]5.5322,[6]4.8940,[7]4.4075,[8]4.4634,[9]4.7326,[10]4.8577,[11]5.0369,[12]5.2737,

So, basically one needs to turn on fused delta-net for prefill, and turn it off for generation.

@ProgenyAlpha

Follow-up: hybrid dispatch test (fused prefill + autoregressive generation)

Based on @ikawrakow's observation that fused gives correct PPL for prefill but broken generation, I patched the dispatch at src/llama-build-context.cpp:~4980 to route T>1 through build_delta_net_fused and T=1 through build_delta_net_autoregressive:

// BEFORE: fused handles everything when enabled
if (use_fused_delta_net) {
    attn_out = build_delta_net_fused(...);
} else {
    attn_out = n_tok == 1 ? build_delta_net_autoregressive(...) : build_delta_net_chunking(...);
}

// AFTER: fused for prefill only, autoregressive for generation
if (use_fused_delta_net && n_tok > 1) {
    attn_out = build_delta_net_fused(...);
} else if (n_tok == 1) {
    attn_out = build_delta_net_autoregressive(...);
} else {
    attn_out = build_delta_net_chunking(...);
}

Result: partial fix. Garbage output is eliminated — first ~100 tokens are correct and on-topic. But generation degrades after that, producing repetition and hallucinated HTML tags. Comparing the same prompt ("Explain how a CPU cache works in 3 paragraphs"):

  • Hybrid (fused PP + auto TG): Correct first paragraph, degenerates mid-second paragraph
  • --no-fused-delta (chunked PP + auto TG): Coherent full 289-token response

The fused kernel and autoregressive path compute DeltaNet state differently, so when autoregressive picks up the state from fused prefill, the mismatch accumulates during generation.

Conclusion: The real fix needs to address the fused kernel's T=1 threading bug in the ggml_cont_4d(ggml_permute(...)) ops directly, rather than routing around it. Or adopt the mainline implementation as suggested.

Updated repo with v15 Dockerfile and results: https://github.com/ProgenyAlpha/ik-deltanet-fix

@ProgenyAlpha

Summary: three distinct CPU bugs in the DeltaNet implementation

For tracking purposes — there are three separate bugs, not one:

  1. Fused generation threading bug — ggml_cont_4d(ggml_permute(...)) ops corrupt tensor data at T=1 with multiple threads. Fused prefill (T>1) is unaffected. This is why -t 1 works but -t 12 produces garbage during generation.

  2. Non-FA path broken — Codex disabled FA by default for Qwen3-Next CPU ("unstable"), and the non-FA code path produces bad PPL ([1]17.5 vs [1]3.0). Per @ikawrakow, removing this override and enabling FA fixes PPL.

  3. Chunked path PPL degradation — build_delta_net_chunking produces significantly worse perplexity than the fused kernel ([1]4.17 vs [1]3.01 with FA on). This is a separate implementation bug — Codex likely "invented" a different chunking algorithm rather than copying mainline.

Bug 1 blocks using fused for generation. Bug 3 means the chunked fallback isn't equivalent. Together they explain why no single flag combination gives both correct PPL and correct generation on CPU.
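As a reference point for what the chunked path has to preserve: here is a minimal sketch of the delta-rule recurrence S <- S(I - beta k k^T) + beta v k^T from the DeltaNet literature, and its chunk-wise affine form. This is illustrative NumPy under textbook assumptions, not the kernel in this PR; the autoregressive and chunked paths must agree to numerical precision, which is exactly what bugs 1 and 3 break:

```python
import numpy as np

rng = np.random.default_rng(0)
d, T, C = 8, 12, 4  # head dim, tokens, chunk size (all hypothetical)
k = rng.standard_normal((T, d))
v = rng.standard_normal((T, d))
beta = rng.uniform(0.1, 0.9, T)
k /= np.linalg.norm(k, axis=-1, keepdims=True)  # unit-norm keys

# Autoregressive: apply the delta rule one token at a time.
S_auto = np.zeros((d, d))
for t in range(T):
    M = np.eye(d) - beta[t] * np.outer(k[t], k[t])
    S_auto = S_auto @ M + beta[t] * np.outer(v[t], k[t])

# Chunked: fold each chunk into one affine map S -> S @ A + Bm, then apply it.
S_chunk = np.zeros((d, d))
for c0 in range(0, T, C):
    A, Bm = np.eye(d), np.zeros((d, d))
    for t in range(c0, c0 + C):
        M = np.eye(d) - beta[t] * np.outer(k[t], k[t])
        A = A @ M
        Bm = Bm @ M + beta[t] * np.outer(v[t], k[t])
    S_chunk = S_chunk @ A + Bm

# The two schedules compute the same state up to float rounding.
assert np.allclose(S_auto, S_chunk)
```

Any implementation whose chunked and per-token states drift apart (as observed in the hybrid dispatch test above) is computing a different recurrence, not just a reordered one.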

@ProgenyAlpha

ProgenyAlpha commented Feb 10, 2026

Mainline llama.cpp + Vulkan iGPU: working alternative while CPU DeltaNet is debugged

While the three CPU bugs above are being sorted, we got Q3CN running well on mainline llama.cpp with Vulkan on the same hardware (N5 Pro, HX370, Radeon 890M iGPU, 96GB DDR5).

Setup:

  • Mainline llama.cpp (HEAD), Vulkan backend, RADV/Mesa driver
  • Official Q4_K_M (48.4GB split GGUF)
  • BIOS: 16GB UMA frame buffer, ReBAR enabled
  • Docker with passthrough

Results (no flash attention):

| Metric | Value |
|--------|-------|
| GPU layers | 49/49 (full offload) |
| Token generation | 15.3 tok/s |
| Prompt processing | 57.0 tok/s |
| Context | 4096 |
| Output quality | Fully coherent 500+ tokens, no degradation |

Results (flash attention on):

| Metric | No FA | FA on | Delta |
|--------|-------|-------|-------|
| TG (short response) | 15.3 tok/s | 17.6 tok/s | +15% |
| TG (500 tokens) | 15.3 tok/s | 15.0 tok/s | ~same |
| PP (short prompt) | 57.0 tok/s | 53.8 tok/s | -6% |
| PP (long prompt) | 57.0 tok/s | 48.2 tok/s | -15% |
| KV cache memory | 1x | 0.5x | -50% |
| Output quality | Perfect | Perfect | Same |

Flash attention on Vulkan RDNA 3.5: slight TG boost for short responses, slight PP regression for longer prompts, but halves KV cache memory — enabling 8K-16K context without hitting memory walls.

Key finding: --no-mmap kills UMA performance. On integrated graphics with shared memory, --no-mmap forces double allocation (CPU RAM + GPU-mapped memory from the same DDR5 pool). This caused ErrorOutOfDeviceMemory at 35+ GPU layers even with 16GB UMA + 39GB GTT (55GB total Vulkan-visible). Dropping --no-mmap (using default mmap) let all 49 layers load cleanly.

@YurkoHoshko
Contributor Author

Thank you - very good advice re. running the model that fully fits on GPU to cover all scenarios - should've done it myself. I also saw other people using tiny models of the same architecture just for numerical verification - a bit above my head, but surely would speed up iteration cycles :) Again - really appreciate your feedback.

I noticed elevated perplexity levels and attempted to fix them yesterday - no progress there yet, will convert this PR back to draft for the time being.

@YurkoHoshko YurkoHoshko marked this pull request as draft February 11, 2026 04:05
@YurkoHoshko YurkoHoshko changed the title Qwen 3 Next experiment [WIP] Qwen 3 Next experiment Feb 11, 2026
yurko added 6 commits February 10, 2026 23:57
- serialize/restore qwen3next cache.s_l in state/session paths
- bump session and sequence-state file versions for format change
- fallback to single-token chunking for mixed repeated seq_id batches
- remove dead build_delta_net_fused lambda
- remove unused llm_build_context::fused_delta member
- drop -fd/-no-fd options and related YAML dump field
- remove fused_delta fields from public/internal context params
- remove fused_delta assignment and logging in context init
@YurkoHoshko
Contributor Author

YurkoHoshko commented Feb 13, 2026

Alrighty, I removed a lot of code to keep things a little easier to review and tried to reduce the number of changes to existing files.

At the moment, perplexity seems fixed:

Details

CPU, w/o fa

build/bin/llama-perplexity  -m /models/qwen3-next-coder.gguf -f wikitext-2-raw/wiki.test.raw  -fa 0 -c 2048 -ngl 0 --chunks 8
...
Final estimate: PPL over 8 chunks for n_ctx=2048 = 4.4145 +/- 0.12548

CPU, w/ fa

build/bin/llama-perplexity  -m /models/qwen3-next-coder.gguf -f wikitext-2-raw/wiki.test.raw  -fa 1 -c 2048 -ngl 0 --chunks 8
...
Final estimate: PPL over 8 chunks for n_ctx=2048 = 4.4381 +/- 0.12648

Hybrid, w/ fa

root@2a3f7fa9b1bd:/ik_llama.cpp# build/bin/llama-perplexity  -m /models/qwen3-next-coder.gguf -f wikitext-2-raw/wiki.test.raw -ngl 99 --n-cpu-moe 40 -fa 1 -c 2048 --chunks 8
...
Final estimate: PPL over 8 chunks for n_ctx=2048 = 4.4173 +/- 0.12561

And benches with offload

CUDA_VISIBLE_DEVICES=0 build/bin/llama-sweep-bench -m /models/qwen3-next-coder.gguf -c 8192 -b 2048 -ub 512 -t 8 -fa on -ngl 999 --n-cpu-moe 35 -rtr --temp 1 --top-p 0.95 --top-k 40 --min-p 0.01

main: n_kv_max = 8192, n_batch = 2048, n_ubatch = 512, flash_attn = 1, n_gpu_layers = 999, n_threads = 8, n_threads_batch = 8

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|   512 |    128 |      0 |    2.688 |   190.51 |    4.526 |    28.28 |
|   512 |    128 |    512 |    2.595 |   197.29 |    4.380 |    29.22 |
|   512 |    128 |   1024 |    2.661 |   192.43 |    4.376 |    29.25 |
|   512 |    128 |   1536 |    2.563 |   199.73 |    4.393 |    29.14 |
|   512 |    128 |   2048 |    2.610 |   196.15 |    4.366 |    29.32 |
|   512 |    128 |   2560 |    2.675 |   191.43 |    4.395 |    29.12 |
|   512 |    128 |   3072 |    2.652 |   193.07 |    4.370 |    29.29 |
|   512 |    128 |   3584 |    2.590 |   197.69 |    4.394 |    29.13 |
|   512 |    128 |   4096 |    2.672 |   191.62 |    4.390 |    29.16 |
|   512 |    128 |   4608 |    2.682 |   190.87 |    4.400 |    29.09 |
|   512 |    128 |   5120 |    2.671 |   191.68 |    4.415 |    29.00 |
|   512 |    128 |   5632 |    2.659 |   192.54 |    4.428 |    28.91 |
|   512 |    128 |   6144 |    2.703 |   189.42 |    4.445 |    28.80 |
|   512 |    128 |   6656 |    2.689 |   190.42 |    4.441 |    28.82 |
|   512 |    128 |   7168 |    2.684 |   190.78 |    4.427 |    28.92 |
|   512 |    128 |   7680 |    2.592 |   197.50 |    4.465 |    28.67 |

without offload

======================================= HAVE_FANCY_SIMD is defined

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|   512 |    128 |      0 |    5.659 |    90.47 |   17.180 |     7.45 |
|   512 |    128 |    512 |    5.082 |   100.75 |   15.599 |     8.21 |
|   512 |    128 |   1024 |    5.583 |    91.71 |   15.178 |     8.43 |
|   512 |    128 |   1536 |    5.173 |    98.98 |   15.477 |     8.27 |
|   512 |    128 |   2048 |    5.684 |    90.08 |   15.293 |     8.37 |
|   512 |    128 |   2560 |    5.146 |    99.49 |   16.412 |     7.80 |
|   512 |    128 |   3072 |    5.134 |    99.73 |   15.276 |     8.38 |
|   512 |    128 |   3584 |    5.915 |    86.56 |   18.191 |     7.04 |
|   512 |    128 |   4096 |    5.630 |    90.95 |   15.398 |     8.31 |
|   512 |    128 |   4608 |    6.621 |    77.34 |   16.310 |     7.85 |
|   512 |    128 |   5120 |    5.909 |    86.64 |   16.029 |     7.99 |
|   512 |    128 |   5632 |    6.799 |    75.31 |   15.454 |     8.28 |
|   512 |    128 |   6144 |    5.834 |    87.76 |   16.714 |     7.66 |
|   512 |    128 |   6656 |    5.935 |    86.27 |   15.853 |     8.07 |
|   512 |    128 |   7168 |    6.119 |    83.67 |   18.262 |     7.01 |
|   512 |    128 |   7680 |    6.101 |    83.92 |   15.067 |     8.50 |

The model seems to work and has no major problems with tool calling.
It was able to write a working "snake" game in React in one shot - so I guess it works?

The perplexity fix seems to have been this one: aaa1b12

I am now following up with more things on a separate branch - just wanted to leave something more or less stable here to review.

@YurkoHoshko YurkoHoshko marked this pull request as ready for review February 13, 2026 05:21
@ikawrakow
Owner

This looks much better than the previous version. However:

  • CPU PPL is still wrong. Wrong as in higher than CUDA and higher than mainline
  • CUDA PP performance continues to be 3X lower than llama.cpp
  • There are still non-trivial changes left in llama.cpp that I'm not 100% convinced are necessary and work correctly (in the sense that they may break other models, but I haven't tested that yet)
  • When running llama-perplexity with a context of 512 and u-batch size of 2048 we get the following log message and completely wrong PPL values for both CUDA and CPU. Oh, it is also much slower: 3X on CUDA; not sure how many times slower on the CPU, as I didn't have the patience to wait for it to finish even one u-batch of 2048.
llama_decode_internal: qwen3next mixed-sequence batch contains repeated seq_id values; falling back to single-token chunking

Do you mind if I take your branch and try to fix it? It will be too tedious to do this via PR comments.

@YurkoHoshko
Contributor Author

YurkoHoshko commented Feb 13, 2026

Please, by all means - I would really appreciate it!

Just for my own education: to test CPU-only inference, is it sufficient to set `--dev none`, or should I recompile with CUDA disabled? I may have been running my tests incorrectly 🤦.

There are still non-trivial changes left in llama.cpp that I'm not 100% convinced are necessary and work correctly (in the sense that they may break other models, but haven't tested that yet)

I ran it side by side with gpt oss 20b for that exact reason - seemed to work (though I need to review my setup to make sure I am doing things correctly - new docker guide is 🔥 )

Thank you for your time!

@ikawrakow
Owner

You can run with `CUDA_VISIBLE_DEVICES="" your_command_goes_here`.

In my case, I prefer to have two separate build folders, one with CUDA enabled (cmake -DGGML_CUDA=ON) and one CPU-only (cmake -DGGML_CUDA=OFF). The advantage of this is that, when working on CPU stuff, I don't need to wait for the long CUDA compilation, and I also don't need to be setting the CUDA devices to empty when I run a test.

@ikawrakow ikawrakow mentioned this pull request Feb 13, 2026
@sar

sar commented Feb 13, 2026

Interested to know your thoughts on and approach for -sm graph with the Qwen3Next architecture; the model fits nicely on 4x 24GB GPUs at Q8_0 but is non-optimally bound on layer split.

@YurkoHoshko
Contributor Author

Closing this PR in favour of #1266

@ikawrakow
Owner

@YurkoHoshko I was hoping you will figure out what is wrong with the CPU chunked delta net ;-)

@ikawrakow
Owner

Interested to know your thoughts and approach for -sm graph on Qwen3Next architecture, model fits nicely on 4x 24GB GPUs at Q8_0 but non-optimally bound on layer split.

There will be no graph parallel for Qwen3Next. The attention architecture is completely different, and I don't see how one can effectively parallelize it over multiple GPUs.

@YurkoHoshko
Contributor Author

@YurkoHoshko I was hoping you will figure out what is wrong with the CPU chunked delta net ;-)

Will look into it over the weekend - apologies for the delay!

@ikawrakow ikawrakow mentioned this pull request Feb 24, 2026