
Full graph parallel for Qwen3.5 (dense and MoE)#1388

Merged
ikawrakow merged 22 commits into main from ik/sm_graph_delta_net on Mar 10, 2026

Conversation

@ikawrakow
Owner

@ikawrakow ikawrakow commented Mar 9, 2026

Graph parallel (a.k.a. split mode graph) support for Qwen3-Next and Qwen3.5 was added in PRs #1292, #1331 and #1347. However, these graph parallel implementations are incomplete, as recurrent attention layers are still computed on a single GPU.

This PR adds a full graph parallel implementation for the Qwen3.5 models.

The tricky part was not the parallelization of the compute graph, as I had expected, but extracting the right portions from the recurrent attention tensors for each GPU.
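
The idea can be pictured with a small sketch. This is a hypothetical illustration, not the actual ik_llama.cpp code: the function name, the state shape, and the assumption that the recurrent state is split along contiguous head ranges are all mine.

```python
import numpy as np

# Hypothetical sketch: a recurrent attention state of shape
# (n_heads, head_dim, head_dim), split across GPUs by contiguous head ranges.
def split_recurrent_state(state, n_gpus):
    """Return per-GPU slices of the recurrent state, split along the head axis."""
    n_heads = state.shape[0]
    # Distribute heads as evenly as possible; earlier GPUs take the remainder.
    counts = [n_heads // n_gpus + (1 if i < n_heads % n_gpus else 0)
              for i in range(n_gpus)]
    offsets = np.cumsum([0] + counts)
    return [state[offsets[i]:offsets[i + 1]] for i in range(n_gpus)]

state = np.zeros((32, 128, 128), dtype=np.float32)
parts = split_recurrent_state(state, 2)
print([p.shape for p in parts])  # → [(16, 128, 128), (16, 128, 128)]
```

The difficulty described above is that the stored data layout has to match the ranges each GPU extracts; when the layout differs between model families, the same slicing pulls out the wrong portions.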

I observe very significant performance improvements - see graph below.

Caveat 1: this PR disables graph parallel for Qwen3-Next. As mentioned above, the tricky part in enabling graph parallel for recurrent attention layers is extracting the correct tensor portions for each GPU. The mainline developers have, for whatever reason, decided to use a different data arrangement for Qwen3-Next than for the Qwen3.5 series. Qwen3-Next mostly works, but something is not 100% correct (yet), so I have disabled graph parallel for it for now.

Caveat 2: There is an issue when using vision with split mode graph. Hence, I have disabled split mode graph for now when --mmproj is present in the command line arguments.

Caveat 3: Reading and writing the cached recurrent state is not yet implemented for full split mode graph. I'm looking into it.

The following two graphs show PP-2048 and TG-128 performance as a function of context length for Qwen3.5-27B-IQ4_XS on a 2x3090 system.

[Graphs: q35dense_pp_new1, q35dense_tg_new1]

@ikawrakow
Owner Author

Here is some data for Qwen3.5-35B-A3B. For such small models (small in terms of active parameters) the synchronization overhead is significant compared to the time it takes to perform the actual computation, so the graph parallel version on main results in lower TG performance than split mode layer. This PR changes the situation, and we observe better PP and TG performance.
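
A toy latency model (my own assumption for illustration, not measurements from this PR) shows why models with few active parameters are so sensitive to synchronization cost: per-token generation time is roughly the compute time divided across GPUs plus a fixed per-token sync cost, so when compute is small the sync term dominates.

```python
# Toy model: per-token TG time = compute time split across GPUs plus a
# fixed synchronization overhead per token (all numbers are invented).
def tg_tokens_per_s(t_compute_ms, n_gpus, t_sync_ms):
    return 1000.0 / (t_compute_ms / n_gpus + t_sync_ms)

print(round(tg_tokens_per_s(8.0, 1, 0.0), 1))  # single GPU: 125.0 t/s
print(round(tg_tokens_per_s(8.0, 2, 5.0), 1))  # heavy sync: 111.1 t/s, slower than 1 GPU
print(round(tg_tokens_per_s(8.0, 2, 1.0), 1))  # reduced sync: 200.0 t/s
```

With heavy sync, two GPUs come out slower than one despite halving the compute; once the sync cost shrinks, the split pays off.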

Qwen3.5-35B-A3B-UD-Q4_K_XL.gguf from Unsloth running on 2x3090.

[Graphs: q3535_pp, q3535_tg]

@ikawrakow
Owner Author

And here are some hybrid inference results. The model is Qwen3.5-122B-A10B-UD-Q4_K_XL from Unsloth, running on 2x3090 with tensor overrides

-ot "blk\.([0-9]|10|11|12)\.ffn_.*_exps\.weight=CPU,blk\.(2[4-9]|3[0-6])\.ffn_.*_exps\.weight=CPU"

(Note: we need to split the overrides into two chunks so that each GPU holds approximately the same amount of tensor data in VRAM, since with split mode layer the tensors of layers 0...23 go to CUDA0 and those of layers 24...47 go to CUDA1.)
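
The layer ranges these patterns select can be checked with a quick script. This is illustrative only: `cpu_layers` and the tensor name used for matching are my own, not part of the llama.cpp code.

```python
import re

# The two override chunks from the -ot argument above (without the =CPU
# suffix); each routes the matching expert tensors to CPU instead of VRAM.
patterns = [
    r"blk\.([0-9]|10|11|12)\.ffn_.*_exps\.weight",  # on CUDA0's half (layers 0..23)
    r"blk\.(2[4-9]|3[0-6])\.ffn_.*_exps\.weight",   # on CUDA1's half (layers 24..47)
]

def cpu_layers(pattern, n_layers=48):
    """Return the layer indices whose expert tensors the pattern sends to CPU."""
    rx = re.compile(pattern)
    return [i for i in range(n_layers)
            if rx.fullmatch(f"blk.{i}.ffn_gate_exps.weight")]

print(cpu_layers(patterns[0]))  # layers 0-12
print(cpu_layers(patterns[1]))  # layers 24-36
```

Each chunk offloads the experts of 13 layers, so the two GPUs end up holding roughly equal amounts of data.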

[Graphs: q35_122_pp, q35_122_tg]

@hksdpc255
Contributor

I am running on 2×3090 without P2P. Using -sm graph improves TG speed from around 20 to 25 t/s. However, with mmproj enabled, sending a simple "Hello" causes a crash, while -sm layer still works.

The model is Qwen3.5-27B, quantized with the following configuration:

Details

```
blk\.0\.attn_gate\.weight=Q8_0 blk\.0\.attn_norm\.weight=BF16 blk\.0\.attn_qkv\.weight=Q6_K blk\.0\.ffn_down\.weight=Q6_K blk\.0\.ffn_gate\.weight=Q6_K blk\.0\.ffn_up\.weight=Q6_K blk\.0\.post_attention_norm\.weight=BF16 blk\.0\.ssm_a=BF16 blk\.0\.ssm_alpha\.weight=BF16 blk\.0\.ssm_beta\.weight=BF16 blk\.0\.ssm_conv1d\.weight=BF16 blk\.0\.ssm_dt\.bias=BF16 blk\.0\.ssm_norm\.weight=BF16 blk\.0\.ssm_out\.weight=BF16 blk\.1\.attn_gate\.weight=Q8_0 blk\.1\.attn_norm\.weight=BF16 blk\.1\.attn_qkv\.weight=Q6_K blk\.1\.ffn_down\.weight=Q6_K blk\.1\.ffn_gate\.weight=Q6_K blk\.1\.ffn_up\.weight=Q6_K blk\.1\.post_attention_norm\.weight=BF16 blk\.1\.ssm_a=BF16 blk\.1\.ssm_alpha\.weight=BF16 blk\.1\.ssm_beta\.weight=BF16 blk\.1\.ssm_conv1d\.weight=BF16 blk\.1\.ssm_dt\.bias=BF16 blk\.1\.ssm_norm\.weight=BF16 blk\.1\.ssm_out\.weight=BF16 blk\.2\.attn_gate\.weight=Q8_0 blk\.2\.attn_norm\.weight=BF16 blk\.2\.attn_qkv\.weight=Q6_K blk\.2\.ffn_down\.weight=Q6_K blk\.2\.ffn_gate\.weight=Q6_K blk\.2\.ffn_up\.weight=Q6_K blk\.2\.post_attention_norm\.weight=BF16 blk\.2\.ssm_a=BF16 blk\.2\.ssm_alpha\.weight=BF16 blk\.2\.ssm_beta\.weight=BF16 blk\.2\.ssm_conv1d\.weight=BF16 blk\.2\.ssm_dt\.bias=BF16 blk\.2\.ssm_norm\.weight=BF16 blk\.2\.ssm_out\.weight=BF16 blk\.3\.attn_k\.weight=BF16 blk\.3\.attn_k_norm\.weight=BF16 blk\.3\.attn_norm\.weight=BF16 blk\.3\.attn_q\.weight=Q8_0 blk\.3\.attn_q_norm\.weight=BF16 blk\.3\.attn_v\.weight=BF16 blk\.3\.ffn_down\.weight=Q6_K blk\.3\.ffn_gate\.weight=Q6_K blk\.3\.ffn_up\.weight=Q6_K blk\.3\.post_attention_norm\.weight=BF16 blk\.3\.attn_output\.weight=Q6_K blk\.4\.attn_gate\.weight=Q8_0 blk\.4\.attn_norm\.weight=BF16 blk\.4\.attn_qkv\.weight=Q6_K blk\.4\.ffn_down\.weight=Q6_K blk\.4\.ffn_gate\.weight=Q6_K blk\.4\.ffn_up\.weight=Q6_K blk\.4\.post_attention_norm\.weight=BF16 blk\.4\.ssm_a=BF16 blk\.4\.ssm_alpha\.weight=BF16 blk\.4\.ssm_beta\.weight=BF16 blk\.4\.ssm_conv1d\.weight=BF16 blk\.4\.ssm_dt\.bias=BF16 blk\.4\.ssm_norm\.weight=BF16 
blk\.4\.ssm_out\.weight=BF16 blk\.5\.attn_gate\.weight=Q8_0 blk\.5\.attn_norm\.weight=BF16 blk\.5\.attn_qkv\.weight=Q6_K blk\.5\.ffn_down\.weight=Q6_K blk\.5\.ffn_gate\.weight=Q6_K blk\.5\.ffn_up\.weight=Q6_K blk\.5\.post_attention_norm\.weight=BF16 blk\.5\.ssm_a=BF16 blk\.5\.ssm_alpha\.weight=BF16 blk\.5\.ssm_beta\.weight=BF16 blk\.5\.ssm_conv1d\.weight=BF16 blk\.5\.ssm_dt\.bias=BF16 blk\.5\.ssm_norm\.weight=BF16 blk\.5\.ssm_out\.weight=BF16 blk\.6\.attn_gate\.weight=Q8_0 blk\.6\.attn_norm\.weight=BF16 blk\.6\.attn_qkv\.weight=Q6_K blk\.6\.ffn_down\.weight=Q6_K blk\.6\.ffn_gate\.weight=Q6_K blk\.6\.ffn_up\.weight=Q6_K blk\.6\.post_attention_norm\.weight=BF16 blk\.6\.ssm_a=BF16 blk\.6\.ssm_alpha\.weight=BF16 blk\.6\.ssm_beta\.weight=BF16 blk\.6\.ssm_conv1d\.weight=BF16 blk\.6\.ssm_dt\.bias=BF16 blk\.6\.ssm_norm\.weight=BF16 blk\.6\.ssm_out\.weight=BF16 blk\.7\.attn_k\.weight=BF16 blk\.7\.attn_k_norm\.weight=BF16 blk\.7\.attn_norm\.weight=BF16 blk\.7\.attn_q\.weight=Q8_0 blk\.7\.attn_q_norm\.weight=BF16 blk\.7\.attn_v\.weight=BF16 blk\.7\.ffn_down\.weight=Q6_K blk\.7\.ffn_gate\.weight=Q6_K blk\.7\.ffn_up\.weight=Q6_K blk\.7\.post_attention_norm\.weight=BF16 blk\.7\.attn_output\.weight=Q6_K blk\.8\.attn_gate\.weight=Q8_0 blk\.8\.attn_norm\.weight=BF16 blk\.8\.attn_qkv\.weight=Q6_K blk\.8\.ffn_down\.weight=Q6_K blk\.8\.ffn_gate\.weight=Q6_K blk\.8\.ffn_up\.weight=Q6_K blk\.8\.post_attention_norm\.weight=BF16 blk\.8\.ssm_a=BF16 blk\.8\.ssm_alpha\.weight=BF16 blk\.8\.ssm_beta\.weight=BF16 blk\.8\.ssm_conv1d\.weight=BF16 blk\.8\.ssm_dt\.bias=BF16 blk\.8\.ssm_norm\.weight=BF16 blk\.8\.ssm_out\.weight=BF16 blk\.9\.attn_gate\.weight=Q8_0 blk\.9\.attn_norm\.weight=BF16 blk\.9\.attn_qkv\.weight=Q6_K blk\.9\.ffn_down\.weight=Q6_K blk\.9\.ffn_gate\.weight=Q6_K blk\.9\.ffn_up\.weight=Q6_K blk\.9\.post_attention_norm\.weight=BF16 blk\.9\.ssm_a=BF16 blk\.9\.ssm_alpha\.weight=BF16 blk\.9\.ssm_beta\.weight=BF16 blk\.9\.ssm_conv1d\.weight=BF16 blk\.9\.ssm_dt\.bias=BF16 
blk\.9\.ssm_norm\.weight=BF16 blk\.9\.ssm_out\.weight=BF16 blk\.10\.attn_gate\.weight=Q8_0 blk\.10\.attn_norm\.weight=BF16 blk\.10\.attn_qkv\.weight=Q6_K blk\.10\.ffn_down\.weight=Q6_K blk\.10\.ffn_gate\.weight=Q6_K blk\.10\.ffn_up\.weight=Q6_K blk\.10\.post_attention_norm\.weight=BF16 blk\.10\.ssm_a=BF16 blk\.10\.ssm_alpha\.weight=BF16 blk\.10\.ssm_beta\.weight=BF16 blk\.10\.ssm_conv1d\.weight=BF16 blk\.10\.ssm_dt\.bias=BF16 blk\.10\.ssm_norm\.weight=BF16 blk\.10\.ssm_out\.weight=BF16 blk\.11\.attn_k\.weight=BF16 blk\.11\.attn_k_norm\.weight=BF16 blk\.11\.attn_norm\.weight=BF16 blk\.11\.attn_q\.weight=Q8_0 blk\.11\.attn_q_norm\.weight=BF16 blk\.11\.attn_v\.weight=BF16 blk\.11\.ffn_down\.weight=Q8_0 blk\.11\.ffn_gate\.weight=Q6_K blk\.11\.ffn_up\.weight=Q6_K blk\.11\.post_attention_norm\.weight=BF16 blk\.11\.attn_output\.weight=Q6_K blk\.12\.attn_gate\.weight=Q8_0 blk\.12\.attn_norm\.weight=BF16 blk\.12\.attn_qkv\.weight=Q6_K blk\.12\.ffn_down\.weight=Q8_0 blk\.12\.ffn_gate\.weight=Q6_K blk\.12\.ffn_up\.weight=Q6_K blk\.12\.post_attention_norm\.weight=BF16 blk\.12\.ssm_a=BF16 blk\.12\.ssm_alpha\.weight=BF16 blk\.12\.ssm_beta\.weight=BF16 blk\.12\.ssm_conv1d\.weight=BF16 blk\.12\.ssm_dt\.bias=BF16 blk\.12\.ssm_norm\.weight=BF16 blk\.12\.ssm_out\.weight=BF16 blk\.13\.attn_gate\.weight=Q8_0 blk\.13\.attn_norm\.weight=BF16 blk\.13\.attn_qkv\.weight=Q6_K blk\.13\.ffn_down\.weight=Q8_0 blk\.13\.ffn_gate\.weight=Q6_K blk\.13\.ffn_up\.weight=Q6_K blk\.13\.post_attention_norm\.weight=BF16 blk\.13\.ssm_a=BF16 blk\.13\.ssm_alpha\.weight=BF16 blk\.13\.ssm_beta\.weight=BF16 blk\.13\.ssm_conv1d\.weight=BF16 blk\.13\.ssm_dt\.bias=BF16 blk\.13\.ssm_norm\.weight=BF16 blk\.13\.ssm_out\.weight=BF16 blk\.14\.attn_gate\.weight=Q8_0 blk\.14\.attn_norm\.weight=BF16 blk\.14\.attn_qkv\.weight=Q6_K blk\.14\.ffn_down\.weight=Q8_0 blk\.14\.ffn_gate\.weight=Q6_K blk\.14\.ffn_up\.weight=Q6_K blk\.14\.post_attention_norm\.weight=BF16 blk\.14\.ssm_a=BF16 blk\.14\.ssm_alpha\.weight=BF16 
blk\.14\.ssm_beta\.weight=BF16 blk\.14\.ssm_conv1d\.weight=BF16 blk\.14\.ssm_dt\.bias=BF16 blk\.14\.ssm_norm\.weight=BF16 blk\.14\.ssm_out\.weight=BF16 blk\.15\.attn_k\.weight=BF16 blk\.15\.attn_k_norm\.weight=BF16 blk\.15\.attn_norm\.weight=BF16 blk\.15\.attn_q\.weight=Q8_0 blk\.15\.attn_q_norm\.weight=BF16 blk\.15\.attn_v\.weight=BF16 blk\.15\.ffn_down\.weight=Q8_0 blk\.15\.ffn_gate\.weight=Q6_K blk\.15\.ffn_up\.weight=Q6_K blk\.15\.post_attention_norm\.weight=BF16 blk\.15\.attn_output\.weight=Q6_K blk\.16\.attn_gate\.weight=Q8_0 blk\.16\.attn_norm\.weight=BF16 blk\.16\.attn_qkv\.weight=Q6_K blk\.16\.ffn_down\.weight=Q8_0 blk\.16\.ffn_gate\.weight=Q6_K blk\.16\.ffn_up\.weight=Q6_K blk\.16\.post_attention_norm\.weight=BF16 blk\.16\.ssm_a=BF16 blk\.16\.ssm_alpha\.weight=BF16 blk\.16\.ssm_beta\.weight=BF16 blk\.16\.ssm_conv1d\.weight=BF16 blk\.16\.ssm_dt\.bias=BF16 blk\.16\.ssm_norm\.weight=BF16 blk\.16\.ssm_out\.weight=BF16 blk\.17\.attn_gate\.weight=Q8_0 blk\.17\.attn_norm\.weight=BF16 blk\.17\.attn_qkv\.weight=Q6_K blk\.17\.ffn_down\.weight=Q6_K blk\.17\.ffn_gate\.weight=Q6_K blk\.17\.ffn_up\.weight=Q6_K blk\.17\.post_attention_norm\.weight=BF16 blk\.17\.ssm_a=BF16 blk\.17\.ssm_alpha\.weight=BF16 blk\.17\.ssm_beta\.weight=BF16 blk\.17\.ssm_conv1d\.weight=BF16 blk\.17\.ssm_dt\.bias=BF16 blk\.17\.ssm_norm\.weight=BF16 blk\.17\.ssm_out\.weight=BF16 blk\.18\.attn_gate\.weight=Q8_0 blk\.18\.attn_norm\.weight=BF16 blk\.18\.attn_qkv\.weight=Q6_K blk\.18\.ffn_down\.weight=Q6_K blk\.18\.ffn_gate\.weight=Q6_K blk\.18\.ffn_up\.weight=Q6_K blk\.18\.post_attention_norm\.weight=BF16 blk\.18\.ssm_a=BF16 blk\.18\.ssm_alpha\.weight=BF16 blk\.18\.ssm_beta\.weight=BF16 blk\.18\.ssm_conv1d\.weight=BF16 blk\.18\.ssm_dt\.bias=BF16 blk\.18\.ssm_norm\.weight=BF16 blk\.18\.ssm_out\.weight=BF16 blk\.19\.attn_k\.weight=BF16 blk\.19\.attn_k_norm\.weight=BF16 blk\.19\.attn_norm\.weight=BF16 blk\.19\.attn_q\.weight=Q8_0 blk\.19\.attn_q_norm\.weight=BF16 blk\.19\.attn_v\.weight=BF16 
blk\.19\.ffn_down\.weight=Q6_K blk\.19\.ffn_gate\.weight=Q6_K blk\.19\.ffn_up\.weight=Q6_K blk\.19\.post_attention_norm\.weight=BF16 blk\.19\.attn_output\.weight=Q6_K blk\.20\.attn_gate\.weight=Q8_0 blk\.20\.attn_norm\.weight=BF16 blk\.20\.attn_qkv\.weight=Q6_K blk\.20\.ffn_down\.weight=Q6_K blk\.20\.ffn_gate\.weight=Q6_K blk\.20\.ffn_up\.weight=Q6_K blk\.20\.post_attention_norm\.weight=BF16 blk\.20\.ssm_a=BF16 blk\.20\.ssm_alpha\.weight=BF16 blk\.20\.ssm_beta\.weight=BF16 blk\.20\.ssm_conv1d\.weight=BF16 blk\.20\.ssm_dt\.bias=BF16 blk\.20\.ssm_norm\.weight=BF16 blk\.20\.ssm_out\.weight=BF16 blk\.21\.attn_gate\.weight=Q8_0 blk\.21\.attn_norm\.weight=BF16 blk\.21\.attn_qkv\.weight=Q6_K blk\.21\.ffn_down\.weight=Q6_K blk\.21\.ffn_gate\.weight=Q6_K blk\.21\.ffn_up\.weight=Q6_K blk\.21\.post_attention_norm\.weight=BF16 blk\.21\.ssm_a=BF16 blk\.21\.ssm_alpha\.weight=BF16 blk\.21\.ssm_beta\.weight=BF16 blk\.21\.ssm_conv1d\.weight=BF16 blk\.21\.ssm_dt\.bias=BF16 blk\.21\.ssm_norm\.weight=BF16 blk\.21\.ssm_out\.weight=BF16 blk\.22\.attn_gate\.weight=Q8_0 blk\.22\.attn_norm\.weight=BF16 blk\.22\.attn_qkv\.weight=Q6_K blk\.22\.ffn_down\.weight=Q6_K blk\.22\.ffn_gate\.weight=Q6_K blk\.22\.ffn_up\.weight=Q6_K blk\.22\.post_attention_norm\.weight=BF16 blk\.22\.ssm_a=BF16 blk\.22\.ssm_alpha\.weight=BF16 blk\.22\.ssm_beta\.weight=BF16 blk\.22\.ssm_conv1d\.weight=BF16 blk\.22\.ssm_dt\.bias=BF16 blk\.22\.ssm_norm\.weight=BF16 blk\.22\.ssm_out\.weight=BF16 blk\.23\.attn_k\.weight=BF16 blk\.23\.attn_k_norm\.weight=BF16 blk\.23\.attn_norm\.weight=BF16 blk\.23\.attn_q\.weight=Q8_0 blk\.23\.attn_q_norm\.weight=BF16 blk\.23\.attn_v\.weight=BF16 blk\.23\.ffn_down\.weight=Q6_K blk\.23\.ffn_gate\.weight=Q6_K blk\.23\.ffn_up\.weight=Q6_K blk\.23\.post_attention_norm\.weight=BF16 blk\.23\.attn_output\.weight=Q6_K blk\.24\.attn_gate\.weight=Q8_0 blk\.24\.attn_norm\.weight=BF16 blk\.24\.attn_qkv\.weight=Q6_K blk\.24\.ffn_down\.weight=Q6_K blk\.24\.ffn_gate\.weight=Q6_K 
blk\.24\.ffn_up\.weight=Q6_K blk\.24\.post_attention_norm\.weight=BF16 blk\.24\.ssm_a=BF16 blk\.24\.ssm_alpha\.weight=BF16 blk\.24\.ssm_beta\.weight=BF16 blk\.24\.ssm_conv1d\.weight=BF16 blk\.24\.ssm_dt\.bias=BF16 blk\.24\.ssm_norm\.weight=BF16 blk\.24\.ssm_out\.weight=BF16 blk\.25\.attn_gate\.weight=Q8_0 blk\.25\.attn_norm\.weight=BF16 blk\.25\.attn_qkv\.weight=Q6_K blk\.25\.ffn_down\.weight=Q6_K blk\.25\.ffn_gate\.weight=Q6_K blk\.25\.ffn_up\.weight=Q6_K blk\.25\.post_attention_norm\.weight=BF16 blk\.25\.ssm_a=BF16 blk\.25\.ssm_alpha\.weight=BF16 blk\.25\.ssm_beta\.weight=BF16 blk\.25\.ssm_conv1d\.weight=BF16 blk\.25\.ssm_dt\.bias=BF16 blk\.25\.ssm_norm\.weight=BF16 blk\.25\.ssm_out\.weight=BF16 blk\.26\.attn_gate\.weight=Q8_0 blk\.26\.attn_norm\.weight=BF16 blk\.26\.attn_qkv\.weight=Q6_K blk\.26\.ffn_down\.weight=Q6_K blk\.26\.ffn_gate\.weight=Q6_K blk\.26\.ffn_up\.weight=Q6_K blk\.26\.post_attention_norm\.weight=BF16 blk\.26\.ssm_a=BF16 blk\.26\.ssm_alpha\.weight=BF16 blk\.26\.ssm_beta\.weight=BF16 blk\.26\.ssm_conv1d\.weight=BF16 blk\.26\.ssm_dt\.bias=BF16 blk\.26\.ssm_norm\.weight=BF16 blk\.26\.ssm_out\.weight=BF16 blk\.27\.attn_k\.weight=BF16 blk\.27\.attn_k_norm\.weight=BF16 blk\.27\.attn_norm\.weight=BF16 blk\.27\.attn_q\.weight=Q8_0 blk\.27\.attn_q_norm\.weight=BF16 blk\.27\.attn_v\.weight=BF16 blk\.27\.ffn_down\.weight=Q6_K blk\.27\.ffn_gate\.weight=Q6_K blk\.27\.ffn_up\.weight=Q6_K blk\.27\.post_attention_norm\.weight=BF16 blk\.27\.attn_output\.weight=Q6_K blk\.28\.attn_gate\.weight=Q8_0 blk\.28\.attn_norm\.weight=BF16 blk\.28\.attn_qkv\.weight=Q6_K blk\.28\.ffn_down\.weight=Q6_K blk\.28\.ffn_gate\.weight=Q6_K blk\.28\.ffn_up\.weight=Q6_K blk\.28\.post_attention_norm\.weight=BF16 blk\.28\.ssm_a=BF16 blk\.28\.ssm_alpha\.weight=BF16 blk\.28\.ssm_beta\.weight=BF16 blk\.28\.ssm_conv1d\.weight=BF16 blk\.28\.ssm_dt\.bias=BF16 blk\.28\.ssm_norm\.weight=BF16 blk\.28\.ssm_out\.weight=BF16 blk\.29\.attn_gate\.weight=Q8_0 blk\.29\.attn_norm\.weight=BF16 
blk\.29\.attn_qkv\.weight=Q6_K blk\.29\.ffn_down\.weight=Q6_K blk\.29\.ffn_gate\.weight=Q6_K blk\.29\.ffn_up\.weight=Q6_K blk\.29\.post_attention_norm\.weight=BF16 blk\.29\.ssm_a=BF16 blk\.29\.ssm_alpha\.weight=BF16 blk\.29\.ssm_beta\.weight=BF16 blk\.29\.ssm_conv1d\.weight=BF16 blk\.29\.ssm_dt\.bias=BF16 blk\.29\.ssm_norm\.weight=BF16 blk\.29\.ssm_out\.weight=BF16 blk\.30\.attn_gate\.weight=Q8_0 blk\.30\.attn_norm\.weight=BF16 blk\.30\.attn_qkv\.weight=Q6_K blk\.30\.ffn_down\.weight=Q6_K blk\.30\.ffn_gate\.weight=Q6_K blk\.30\.ffn_up\.weight=Q6_K blk\.30\.post_attention_norm\.weight=BF16 blk\.30\.ssm_a=BF16 blk\.30\.ssm_alpha\.weight=BF16 blk\.30\.ssm_beta\.weight=BF16 blk\.30\.ssm_conv1d\.weight=BF16 blk\.30\.ssm_dt\.bias=BF16 blk\.30\.ssm_norm\.weight=BF16 blk\.30\.ssm_out\.weight=BF16 blk\.31\.attn_k\.weight=BF16 blk\.31\.attn_k_norm\.weight=BF16 blk\.31\.attn_norm\.weight=BF16 blk\.31\.attn_q\.weight=Q8_0 blk\.31\.attn_q_norm\.weight=BF16 blk\.31\.attn_v\.weight=BF16 blk\.31\.ffn_down\.weight=Q6_K blk\.31\.ffn_gate\.weight=Q6_K blk\.31\.ffn_up\.weight=Q6_K blk\.31\.post_attention_norm\.weight=BF16 blk\.31\.attn_output\.weight=Q6_K blk\.32\.attn_gate\.weight=Q8_0 blk\.32\.attn_norm\.weight=BF16 blk\.32\.attn_qkv\.weight=Q6_K blk\.32\.ffn_down\.weight=Q6_K blk\.32\.ffn_gate\.weight=Q6_K blk\.32\.ffn_up\.weight=Q6_K blk\.32\.post_attention_norm\.weight=BF16 blk\.32\.ssm_a=BF16 blk\.32\.ssm_alpha\.weight=BF16 blk\.32\.ssm_beta\.weight=BF16 blk\.32\.ssm_conv1d\.weight=BF16 blk\.32\.ssm_dt\.bias=BF16 blk\.32\.ssm_norm\.weight=BF16 blk\.32\.ssm_out\.weight=BF16 blk\.33\.attn_gate\.weight=Q8_0 blk\.33\.attn_norm\.weight=BF16 blk\.33\.attn_qkv\.weight=Q6_K blk\.33\.ffn_down\.weight=Q6_K blk\.33\.ffn_gate\.weight=Q6_K blk\.33\.ffn_up\.weight=Q6_K blk\.33\.post_attention_norm\.weight=BF16 blk\.33\.ssm_a=BF16 blk\.33\.ssm_alpha\.weight=BF16 blk\.33\.ssm_beta\.weight=BF16 blk\.33\.ssm_conv1d\.weight=BF16 blk\.33\.ssm_dt\.bias=BF16 blk\.33\.ssm_norm\.weight=BF16 
blk\.33\.ssm_out\.weight=BF16 blk\.34\.attn_gate\.weight=Q8_0 blk\.34\.attn_norm\.weight=BF16 blk\.34\.attn_qkv\.weight=Q6_K blk\.34\.ffn_down\.weight=Q6_K blk\.34\.ffn_gate\.weight=Q6_K blk\.34\.ffn_up\.weight=Q6_K blk\.34\.post_attention_norm\.weight=BF16 blk\.34\.ssm_a=BF16 blk\.34\.ssm_alpha\.weight=BF16 blk\.34\.ssm_beta\.weight=BF16 blk\.34\.ssm_conv1d\.weight=BF16 blk\.34\.ssm_dt\.bias=BF16 blk\.34\.ssm_norm\.weight=BF16 blk\.34\.ssm_out\.weight=BF16 blk\.35\.attn_k\.weight=BF16 blk\.35\.attn_k_norm\.weight=BF16 blk\.35\.attn_norm\.weight=BF16 blk\.35\.attn_q\.weight=Q8_0 blk\.35\.attn_q_norm\.weight=BF16 blk\.35\.attn_v\.weight=BF16 blk\.35\.ffn_down\.weight=Q6_K blk\.35\.ffn_gate\.weight=Q6_K blk\.35\.ffn_up\.weight=Q6_K blk\.35\.post_attention_norm\.weight=BF16 blk\.35\.attn_output\.weight=Q6_K blk\.36\.attn_gate\.weight=Q8_0 blk\.36\.attn_norm\.weight=BF16 blk\.36\.attn_qkv\.weight=Q6_K blk\.36\.ffn_down\.weight=Q6_K blk\.36\.ffn_gate\.weight=Q6_K blk\.36\.ffn_up\.weight=Q6_K blk\.36\.post_attention_norm\.weight=BF16 blk\.36\.ssm_a=BF16 blk\.36\.ssm_alpha\.weight=BF16 blk\.36\.ssm_beta\.weight=BF16 blk\.36\.ssm_conv1d\.weight=BF16 blk\.36\.ssm_dt\.bias=BF16 blk\.36\.ssm_norm\.weight=BF16 blk\.36\.ssm_out\.weight=BF16 blk\.37\.attn_gate\.weight=Q8_0 blk\.37\.attn_norm\.weight=BF16 blk\.37\.attn_qkv\.weight=Q6_K blk\.37\.ffn_down\.weight=Q6_K blk\.37\.ffn_gate\.weight=Q6_K blk\.37\.ffn_up\.weight=Q6_K blk\.37\.post_attention_norm\.weight=BF16 blk\.37\.ssm_a=BF16 blk\.37\.ssm_alpha\.weight=BF16 blk\.37\.ssm_beta\.weight=BF16 blk\.37\.ssm_conv1d\.weight=BF16 blk\.37\.ssm_dt\.bias=BF16 blk\.37\.ssm_norm\.weight=BF16 blk\.37\.ssm_out\.weight=BF16 blk\.38\.attn_gate\.weight=Q8_0 blk\.38\.attn_norm\.weight=BF16 blk\.38\.attn_qkv\.weight=Q6_K blk\.38\.ffn_down\.weight=Q6_K blk\.38\.ffn_gate\.weight=Q6_K blk\.38\.ffn_up\.weight=Q6_K blk\.38\.post_attention_norm\.weight=BF16 blk\.38\.ssm_a=BF16 blk\.38\.ssm_alpha\.weight=BF16 blk\.38\.ssm_beta\.weight=BF16 
blk\.38\.ssm_conv1d\.weight=BF16 blk\.38\.ssm_dt\.bias=BF16 blk\.38\.ssm_norm\.weight=BF16 blk\.38\.ssm_out\.weight=BF16 blk\.39\.attn_k\.weight=BF16 blk\.39\.attn_k_norm\.weight=BF16 blk\.39\.attn_norm\.weight=BF16 blk\.39\.attn_q\.weight=Q8_0 blk\.39\.attn_q_norm\.weight=BF16 blk\.39\.attn_v\.weight=BF16 blk\.39\.ffn_down\.weight=Q6_K blk\.39\.ffn_gate\.weight=Q6_K blk\.39\.ffn_up\.weight=Q6_K blk\.39\.post_attention_norm\.weight=BF16 blk\.39\.attn_output\.weight=Q6_K blk\.40\.attn_gate\.weight=Q8_0 blk\.40\.attn_norm\.weight=BF16 blk\.40\.attn_qkv\.weight=Q6_K blk\.40\.ffn_down\.weight=Q6_K blk\.40\.ffn_gate\.weight=Q6_K blk\.40\.ffn_up\.weight=Q6_K blk\.40\.post_attention_norm\.weight=BF16 blk\.40\.ssm_a=BF16 blk\.40\.ssm_alpha\.weight=BF16 blk\.40\.ssm_beta\.weight=BF16 blk\.40\.ssm_conv1d\.weight=BF16 blk\.40\.ssm_dt\.bias=BF16 blk\.40\.ssm_norm\.weight=BF16 blk\.40\.ssm_out\.weight=BF16 blk\.41\.attn_gate\.weight=Q8_0 blk\.41\.attn_norm\.weight=BF16 blk\.41\.attn_qkv\.weight=Q6_K blk\.41\.ffn_down\.weight=Q6_K blk\.41\.ffn_gate\.weight=Q6_K blk\.41\.ffn_up\.weight=Q6_K blk\.41\.post_attention_norm\.weight=BF16 blk\.41\.ssm_a=BF16 blk\.41\.ssm_alpha\.weight=BF16 blk\.41\.ssm_beta\.weight=BF16 blk\.41\.ssm_conv1d\.weight=BF16 blk\.41\.ssm_dt\.bias=BF16 blk\.41\.ssm_norm\.weight=BF16 blk\.41\.ssm_out\.weight=BF16 blk\.42\.attn_gate\.weight=Q8_0 blk\.42\.attn_norm\.weight=BF16 blk\.42\.attn_qkv\.weight=Q6_K blk\.42\.ffn_down\.weight=Q6_K blk\.42\.ffn_gate\.weight=Q6_K blk\.42\.ffn_up\.weight=Q6_K blk\.42\.post_attention_norm\.weight=BF16 blk\.42\.ssm_a=BF16 blk\.42\.ssm_alpha\.weight=BF16 blk\.42\.ssm_beta\.weight=BF16 blk\.42\.ssm_conv1d\.weight=BF16 blk\.42\.ssm_dt\.bias=BF16 blk\.42\.ssm_norm\.weight=BF16 blk\.42\.ssm_out\.weight=BF16 blk\.43\.attn_k\.weight=BF16 blk\.43\.attn_k_norm\.weight=BF16 blk\.43\.attn_norm\.weight=BF16 blk\.43\.attn_q\.weight=Q8_0 blk\.43\.attn_q_norm\.weight=BF16 blk\.43\.attn_v\.weight=BF16 blk\.43\.ffn_down\.weight=Q6_K 
blk\.43\.ffn_gate\.weight=Q6_K blk\.43\.ffn_up\.weight=Q6_K blk\.43\.post_attention_norm\.weight=BF16 blk\.43\.attn_output\.weight=Q6_K blk\.44\.attn_gate\.weight=Q8_0 blk\.44\.attn_norm\.weight=BF16 blk\.44\.attn_qkv\.weight=Q6_K blk\.44\.ffn_down\.weight=Q6_K blk\.44\.ffn_gate\.weight=Q6_K blk\.44\.ffn_up\.weight=Q6_K blk\.44\.post_attention_norm\.weight=BF16 blk\.44\.ssm_a=BF16 blk\.44\.ssm_alpha\.weight=BF16 blk\.44\.ssm_beta\.weight=BF16 blk\.44\.ssm_conv1d\.weight=BF16 blk\.44\.ssm_dt\.bias=BF16 blk\.44\.ssm_norm\.weight=BF16 blk\.44\.ssm_out\.weight=BF16 blk\.45\.attn_gate\.weight=Q8_0 blk\.45\.attn_norm\.weight=BF16 blk\.45\.attn_qkv\.weight=Q6_K blk\.45\.ffn_down\.weight=Q6_K blk\.45\.ffn_gate\.weight=Q6_K blk\.45\.ffn_up\.weight=Q6_K blk\.45\.post_attention_norm\.weight=BF16 blk\.45\.ssm_a=BF16 blk\.45\.ssm_alpha\.weight=BF16 blk\.45\.ssm_beta\.weight=BF16 blk\.45\.ssm_conv1d\.weight=BF16 blk\.45\.ssm_dt\.bias=BF16 blk\.45\.ssm_norm\.weight=BF16 blk\.45\.ssm_out\.weight=BF16 blk\.46\.attn_gate\.weight=Q8_0 blk\.46\.attn_norm\.weight=BF16 blk\.46\.attn_qkv\.weight=Q6_K blk\.46\.ffn_down\.weight=Q6_K blk\.46\.ffn_gate\.weight=Q6_K blk\.46\.ffn_up\.weight=Q6_K blk\.46\.post_attention_norm\.weight=BF16 blk\.46\.ssm_a=BF16 blk\.46\.ssm_alpha\.weight=BF16 blk\.46\.ssm_beta\.weight=BF16 blk\.46\.ssm_conv1d\.weight=BF16 blk\.46\.ssm_dt\.bias=BF16 blk\.46\.ssm_norm\.weight=BF16 blk\.46\.ssm_out\.weight=BF16 blk\.47\.attn_k\.weight=BF16 blk\.47\.attn_k_norm\.weight=BF16 blk\.47\.attn_norm\.weight=BF16 blk\.47\.attn_q\.weight=Q8_0 blk\.47\.attn_q_norm\.weight=BF16 blk\.47\.attn_v\.weight=BF16 blk\.47\.ffn_down\.weight=Q6_K blk\.47\.ffn_gate\.weight=Q6_K blk\.47\.ffn_up\.weight=Q6_K blk\.47\.post_attention_norm\.weight=BF16 blk\.47\.attn_output\.weight=Q6_K blk\.48\.attn_gate\.weight=Q8_0 blk\.48\.attn_norm\.weight=BF16 blk\.48\.attn_qkv\.weight=Q6_K blk\.48\.ffn_down\.weight=Q6_K blk\.48\.ffn_gate\.weight=Q6_K blk\.48\.ffn_up\.weight=Q6_K 
blk\.48\.post_attention_norm\.weight=BF16 blk\.48\.ssm_a=BF16 blk\.48\.ssm_alpha\.weight=BF16 blk\.48\.ssm_beta\.weight=BF16 blk\.48\.ssm_conv1d\.weight=BF16 blk\.48\.ssm_dt\.bias=BF16 blk\.48\.ssm_norm\.weight=BF16 blk\.48\.ssm_out\.weight=BF16 blk\.49\.attn_gate\.weight=Q8_0 blk\.49\.attn_norm\.weight=BF16 blk\.49\.attn_qkv\.weight=Q6_K blk\.49\.ffn_down\.weight=Q6_K blk\.49\.ffn_gate\.weight=Q6_K blk\.49\.ffn_up\.weight=Q6_K blk\.49\.post_attention_norm\.weight=BF16 blk\.49\.ssm_a=BF16 blk\.49\.ssm_alpha\.weight=BF16 blk\.49\.ssm_beta\.weight=BF16 blk\.49\.ssm_conv1d\.weight=BF16 blk\.49\.ssm_dt\.bias=BF16 blk\.49\.ssm_norm\.weight=BF16 blk\.49\.ssm_out\.weight=BF16 blk\.50\.attn_gate\.weight=Q8_0 blk\.50\.attn_norm\.weight=BF16 blk\.50\.attn_qkv\.weight=Q6_K blk\.50\.ffn_down\.weight=Q6_K blk\.50\.ffn_gate\.weight=Q6_K blk\.50\.ffn_up\.weight=Q6_K blk\.50\.post_attention_norm\.weight=BF16 blk\.50\.ssm_a=BF16 blk\.50\.ssm_alpha\.weight=BF16 blk\.50\.ssm_beta\.weight=BF16 blk\.50\.ssm_conv1d\.weight=BF16 blk\.50\.ssm_dt\.bias=BF16 blk\.50\.ssm_norm\.weight=BF16 blk\.50\.ssm_out\.weight=BF16 blk\.51\.attn_k\.weight=BF16 blk\.51\.attn_k_norm\.weight=BF16 blk\.51\.attn_norm\.weight=BF16 blk\.51\.attn_q\.weight=Q8_0 blk\.51\.attn_q_norm\.weight=BF16 blk\.51\.attn_v\.weight=BF16 blk\.51\.ffn_down\.weight=Q8_0 blk\.51\.ffn_gate\.weight=Q8_0 blk\.51\.ffn_up\.weight=Q8_0 blk\.51\.post_attention_norm\.weight=BF16 blk\.51\.attn_output\.weight=Q6_K blk\.52\.attn_gate\.weight=Q8_0 blk\.52\.attn_norm\.weight=BF16 blk\.52\.attn_qkv\.weight=Q6_K blk\.52\.ffn_down\.weight=Q6_K blk\.52\.ffn_gate\.weight=Q6_K blk\.52\.ffn_up\.weight=Q6_K blk\.52\.post_attention_norm\.weight=BF16 blk\.52\.ssm_a=BF16 blk\.52\.ssm_alpha\.weight=BF16 blk\.52\.ssm_beta\.weight=BF16 blk\.52\.ssm_conv1d\.weight=BF16 blk\.52\.ssm_dt\.bias=BF16 blk\.52\.ssm_norm\.weight=BF16 blk\.52\.ssm_out\.weight=BF16 blk\.53\.attn_gate\.weight=Q8_0 blk\.53\.attn_norm\.weight=BF16 blk\.53\.attn_qkv\.weight=Q6_K 
blk\.53\.ffn_down\.weight=Q6_K blk\.53\.ffn_gate\.weight=Q6_K blk\.53\.ffn_up\.weight=Q6_K blk\.53\.post_attention_norm\.weight=BF16 blk\.53\.ssm_a=BF16 blk\.53\.ssm_alpha\.weight=BF16 blk\.53\.ssm_beta\.weight=BF16 blk\.53\.ssm_conv1d\.weight=BF16 blk\.53\.ssm_dt\.bias=BF16 blk\.53\.ssm_norm\.weight=BF16 blk\.53\.ssm_out\.weight=BF16 blk\.54\.attn_gate\.weight=Q8_0 blk\.54\.attn_norm\.weight=BF16 blk\.54\.attn_qkv\.weight=Q6_K blk\.54\.ffn_down\.weight=Q6_K blk\.54\.ffn_gate\.weight=Q6_K blk\.54\.ffn_up\.weight=Q6_K blk\.54\.post_attention_norm\.weight=BF16 blk\.54\.ssm_a=BF16 blk\.54\.ssm_alpha\.weight=BF16 blk\.54\.ssm_beta\.weight=BF16 blk\.54\.ssm_conv1d\.weight=BF16 blk\.54\.ssm_dt\.bias=BF16 blk\.54\.ssm_norm\.weight=BF16 blk\.54\.ssm_out\.weight=BF16 blk\.55\.attn_k\.weight=BF16 blk\.55\.attn_k_norm\.weight=BF16 blk\.55\.attn_norm\.weight=BF16 blk\.55\.attn_q\.weight=Q8_0 blk\.55\.attn_q_norm\.weight=BF16 blk\.55\.attn_v\.weight=BF16 blk\.55\.ffn_down\.weight=Q6_K blk\.55\.ffn_gate\.weight=Q6_K blk\.55\.ffn_up\.weight=Q6_K blk\.55\.post_attention_norm\.weight=BF16 blk\.55\.attn_output\.weight=Q6_K blk\.56\.attn_gate\.weight=Q8_0 blk\.56\.attn_norm\.weight=BF16 blk\.56\.attn_qkv\.weight=Q6_K blk\.56\.ffn_down\.weight=Q6_K blk\.56\.ffn_gate\.weight=Q6_K blk\.56\.ffn_up\.weight=Q6_K blk\.56\.post_attention_norm\.weight=BF16 blk\.56\.ssm_a=BF16 blk\.56\.ssm_alpha\.weight=BF16 blk\.56\.ssm_beta\.weight=BF16 blk\.56\.ssm_conv1d\.weight=BF16 blk\.56\.ssm_dt\.bias=BF16 blk\.56\.ssm_norm\.weight=BF16 blk\.56\.ssm_out\.weight=BF16 blk\.57\.attn_gate\.weight=Q8_0 blk\.57\.attn_norm\.weight=BF16 blk\.57\.attn_qkv\.weight=Q6_K blk\.57\.ffn_down\.weight=Q6_K blk\.57\.ffn_gate\.weight=Q6_K blk\.57\.ffn_up\.weight=Q6_K blk\.57\.post_attention_norm\.weight=BF16 blk\.57\.ssm_a=BF16 blk\.57\.ssm_alpha\.weight=BF16 blk\.57\.ssm_beta\.weight=BF16 blk\.57\.ssm_conv1d\.weight=BF16 blk\.57\.ssm_dt\.bias=BF16 blk\.57\.ssm_norm\.weight=BF16 blk\.57\.ssm_out\.weight=BF16 
blk\.58\.attn_gate\.weight=Q8_0 blk\.58\.attn_norm\.weight=BF16 blk\.58\.attn_qkv\.weight=Q6_K blk\.58\.ffn_down\.weight=Q6_K blk\.58\.ffn_gate\.weight=Q6_K blk\.58\.ffn_up\.weight=Q6_K blk\.58\.post_attention_norm\.weight=BF16 blk\.58\.ssm_a=BF16 blk\.58\.ssm_alpha\.weight=BF16 blk\.58\.ssm_beta\.weight=BF16 blk\.58\.ssm_conv1d\.weight=BF16 blk\.58\.ssm_dt\.bias=BF16 blk\.58\.ssm_norm\.weight=BF16 blk\.58\.ssm_out\.weight=BF16 blk\.59\.attn_k\.weight=BF16 blk\.59\.attn_k_norm\.weight=BF16 blk\.59\.attn_norm\.weight=BF16 blk\.59\.attn_q\.weight=Q8_0 blk\.59\.attn_q_norm\.weight=BF16 blk\.59\.attn_v\.weight=BF16 blk\.59\.ffn_down\.weight=Q8_0 blk\.59\.ffn_gate\.weight=Q8_0 blk\.59\.ffn_up\.weight=Q8_0 blk\.59\.post_attention_norm\.weight=BF16 blk\.59\.attn_output\.weight=Q6_K blk\.60\.attn_gate\.weight=Q8_0 blk\.60\.attn_norm\.weight=BF16 blk\.60\.attn_qkv\.weight=Q6_K blk\.60\.ffn_down\.weight=Q8_0 blk\.60\.ffn_gate\.weight=Q8_0 blk\.60\.ffn_up\.weight=Q8_0 blk\.60\.post_attention_norm\.weight=BF16 blk\.60\.ssm_a=BF16 blk\.60\.ssm_alpha\.weight=BF16 blk\.60\.ssm_beta\.weight=BF16 blk\.60\.ssm_conv1d\.weight=BF16 blk\.60\.ssm_dt\.bias=BF16 blk\.60\.ssm_norm\.weight=BF16 blk\.60\.ssm_out\.weight=BF16 blk\.61\.attn_gate\.weight=Q8_0 blk\.61\.attn_norm\.weight=BF16 blk\.61\.attn_qkv\.weight=Q6_K blk\.61\.ffn_down\.weight=Q8_0 blk\.61\.ffn_gate\.weight=Q8_0 blk\.61\.ffn_up\.weight=Q8_0 blk\.61\.post_attention_norm\.weight=BF16 blk\.61\.ssm_a=BF16 blk\.61\.ssm_alpha\.weight=BF16 blk\.61\.ssm_beta\.weight=BF16 blk\.61\.ssm_conv1d\.weight=BF16 blk\.61\.ssm_dt\.bias=BF16 blk\.61\.ssm_norm\.weight=BF16 blk\.61\.ssm_out\.weight=BF16 blk\.62\.attn_gate\.weight=Q8_0 blk\.62\.attn_norm\.weight=BF16 blk\.62\.attn_qkv\.weight=Q6_K blk\.62\.ffn_down\.weight=Q8_0 blk\.62\.ffn_gate\.weight=Q8_0 blk\.62\.ffn_up\.weight=Q8_0 blk\.62\.post_attention_norm\.weight=BF16 blk\.62\.ssm_a=BF16 blk\.62\.ssm_alpha\.weight=BF16 blk\.62\.ssm_beta\.weight=BF16 blk\.62\.ssm_conv1d\.weight=BF16 
blk\.62\.ssm_dt\.bias=BF16 blk\.62\.ssm_norm\.weight=BF16 blk\.62\.ssm_out\.weight=BF16 blk\.63\.attn_k\.weight=BF16 blk\.63\.attn_k_norm\.weight=BF16 blk\.63\.attn_norm\.weight=BF16 blk\.63\.attn_q\.weight=Q8_0 blk\.63\.attn_q_norm\.weight=BF16 blk\.63\.attn_v\.weight=BF16 blk\.63\.ffn_down\.weight=Q8_0 blk\.63\.ffn_gate\.weight=Q8_0 blk\.63\.ffn_up\.weight=Q8_0 blk\.63\.post_attention_norm\.weight=BF16 blk\.63\.attn_output\.weight=Q6_K
```

The quantization recipe is copied from Unsloth’s Q6_K_XL, with all F16 weights replaced by BF16. The mmproj is in BF16.

@magikRUKKOLA

magikRUKKOLA commented Mar 9, 2026

ubergarm/Qwen3.5-397B-A17B-GGUF/IQ2_KL; 8x3090

[Graphs: prefill-qwen35-iq2_kl, decode-qwen35-iq2_kl]

At the same time, I also got better results for the AesSedai IQ3_S quant. It's 60 t/s:

main: n_kv_max = 262144, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, n_gpu_layers = 99, n_threads = 1, n_threads_batch = 1

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
|  2048 |    512 |      0 |    1.735 |  1180.60 |    8.429 |    60.74 |
|  2048 |    512 |   2048 |    1.770 |  1157.36 |    8.410 |    60.88 |
|  2048 |    512 |   4096 |    1.767 |  1159.04 |    8.461 |    60.51 |

@magikRUKKOLA

I'm getting a segfault.

```
======== Cache: cache_size = 0, n_past0 =  0, n_past1 =  0, n_past_prompt1 = 0,  n_past2 =  0, n_past_prompt2 =  0
INFO [    batch_pending_prompt] kv cache rm [p0, end) | tid="140737106661376" timestamp=1773067009 id_slot=0 id_task=0 p0=0
VERB [    batch_pending_prompt] prompt processing progress | tid="140737106661376" timestamp=1773067009 id_slot=0 n_past=17 n_ctx=262144 n_tokens=17 progress=0.7727272510528564
VERB [            update_slots] decoding batch | tid="140737106661376" timestamp=1773067009 n_tokens=17

Thread 1 "llama-server" received signal SIGSEGV, Segmentation fault.
0x00007ffff7cac3b6 in llama_data_write_buffer::get_tensor_data_split(unsigned char*, ggml_tensor const*, ggml_tensor const*, std::vector<unsigned char, std::allocator<unsigned char> >&, unsigned long, unsigned long) ()
   from /opt/ik_llama.cpp/ik_llama.cpp/build/src/libllama.so
(gdb) bt full
#0  0x00007ffff7cac3b6 in llama_data_write_buffer::get_tensor_data_split(unsigned char*, ggml_tensor const*, ggml_tensor const*, std::vector<unsigned char, std::allocator<unsigned char> >&, unsigned long, unsigned long) ()
   from /opt/ik_llama.cpp/ik_llama.cpp/build/src/libllama.so
No symbol table info available.
#1  0x00007ffff7c99ce3 in llama_data_write::write_kv_cache(llama_context const*, int, unsigned int) [clone .constprop.1] ()
   from /opt/ik_llama.cpp/ik_llama.cpp/build/src/libllama.so
No symbol table info available.
#2  0x00007ffff7c9ac63 in llama_state_seq_get_data () from /opt/ik_llama.cpp/ik_llama.cpp/build/src/libllama.so
No symbol table info available.
#3  0x00005555556a703f in server_context::create_checkpoint(server_slot&) ()
No symbol table info available.
#4  0x00005555556d31aa in server_context::process_batch_tokens(int&) ()
No symbol table info available.
#5  0x00005555556d4bcf in server_context::update_slots() ()
No symbol table info available.
#6  0x00005555556708c4 in server_queue::start_loop() ()
No symbol table info available.
#7  0x00005555555e82ff in main ()
No symbol table info available.
```

@ubergarm
Contributor

ubergarm commented Mar 9, 2026

@hksdpc255 ik updated the notes pointing out that -sm graph does not currently work with --mmproj.

@magikRUKKOLA I'll do some more testing to see if I get any faults. Were you using it interactively or sweep-benching when that happened?

Looking great on 2xA6000 GPUs:

sweep-bench-Qwen3 5-122B-A10B-IQ4_KSS-PR1388
👈 Details

title: "ik_llama.cpp PR1388 full GPU offload"
subtitle: "ubergarm/Qwen3.5-122B-A10B IQ4_KSS 61.219 GiB (4.306 BPW)"
hardware: "2x RTX A6000 (48GB VRAM each) Driver: 580.105.08 CUDA: 13.0 P2P: OK NCCL found!\n"

ik_llama.cpp main@344688ce

model=/mnt/raid/models/ubergarm/Qwen3.5-122B-A10B-GGUF/Qwen3.5-122B-A10B-IQ4_KSS.gguf
./build/bin/llama-sweep-bench \
  --model "$model" \
  -c 135168 \
  -sm graph \
  -ngl 999 \
  -ub 4096 -b 4096 \
  --threads 1 \
  --no-mmap \
  -n 128 \
  --warmup-batch
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|----|----|------|--------|----------|--------|----------|
| 4096 | 128 | 0 | 1.995 | 2052.91 | 2.405 | 53.21 |
| 4096 | 128 | 4096 | 2.038 | 2010.06 | 2.366 | 54.10 |
| 4096 | 128 | 8192 | 2.093 | 1957.31 | 2.381 | 53.75 |
| 4096 | 128 | 12288 | 2.142 | 1912.16 | 2.411 | 53.09 |
| 4096 | 128 | 16384 | 2.213 | 1851.03 | 2.415 | 53.01 |
| 4096 | 128 | 20480 | 2.262 | 1810.90 | 2.428 | 52.72 |
| 4096 | 128 | 24576 | 2.323 | 1763.04 | 2.450 | 52.25 |
| 4096 | 128 | 28672 | 2.374 | 1725.62 | 2.458 | 52.09 |
| 4096 | 128 | 32768 | 2.436 | 1681.53 | 2.482 | 51.57 |
| 4096 | 128 | 36864 | 2.483 | 1649.42 | 2.485 | 51.51 |
| 4096 | 128 | 40960 | 2.539 | 1613.18 | 2.494 | 51.32 |
| 4096 | 128 | 45056 | 2.590 | 1581.39 | 2.517 | 50.85 |
| 4096 | 128 | 49152 | 2.647 | 1547.36 | 2.521 | 50.76 |
| 4096 | 128 | 53248 | 2.695 | 1520.04 | 2.528 | 50.64 |
| 4096 | 128 | 57344 | 2.757 | 1485.82 | 2.549 | 50.22 |
| 4096 | 128 | 61440 | 2.803 | 1461.32 | 2.553 | 50.13 |
| 4096 | 128 | 65536 | 2.847 | 1438.63 | 2.576 | 49.68 |
| 4096 | 128 | 69632 | 2.907 | 1408.89 | 2.584 | 49.54 |
| 4096 | 128 | 73728 | 2.954 | 1386.38 | 2.588 | 49.47 |
| 4096 | 128 | 77824 | 3.009 | 1361.35 | 2.608 | 49.07 |
| 4096 | 128 | 81920 | 3.071 | 1333.92 | 2.614 | 48.97 |
| 4096 | 128 | 86016 | 3.112 | 1316.26 | 2.627 | 48.73 |
| 4096 | 128 | 90112 | 3.171 | 1291.56 | 2.644 | 48.42 |
| 4096 | 128 | 94208 | 3.232 | 1267.21 | 2.657 | 48.18 |
| 4096 | 128 | 98304 | 3.274 | 1251.13 | 2.678 | 47.79 |
| 4096 | 128 | 102400 | 3.329 | 1230.35 | 2.682 | 47.72 |
| 4096 | 128 | 106496 | 3.381 | 1211.61 | 2.691 | 47.57 |
| 4096 | 128 | 110592 | 3.435 | 1192.53 | 2.712 | 47.20 |
| 4096 | 128 | 114688 | 3.486 | 1175.00 | 2.713 | 47.18 |
| 4096 | 128 | 118784 | 3.547 | 1154.73 | 2.737 | 46.77 |
| 4096 | 128 | 122880 | 3.589 | 1141.39 | 2.743 | 46.66 |
| 4096 | 128 | 126976 | 3.645 | 1123.86 | 2.744 | 46.64 |
| 4096 | 128 | 131072 | 3.691 | 1109.67 | 2.769 | 46.22 |

ik_llama.cpp PR1388 ik/sm_graph_delta_net@d1c0acb

model=/mnt/raid/models/ubergarm/Qwen3.5-122B-A10B-GGUF/Qwen3.5-122B-A10B-IQ4_KSS.gguf
./build/bin/llama-sweep-bench \
  --model "$model" \
  -c 135168 \
  -sm graph \
  -ngl 999 \
  -ub 4096 -b 4096 \
  --threads 1 \
  --no-mmap \
  -n 128 \
  --warmup-batch
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|----|----|------|--------|----------|--------|----------|
| 4096 | 128 | 0 | 1.341 | 3055.34 | 1.657 | 77.25 |
| 4096 | 128 | 4096 | 1.399 | 2927.48 | 1.671 | 76.61 |
| 4096 | 128 | 8192 | 1.447 | 2830.55 | 1.674 | 76.44 |
| 4096 | 128 | 12288 | 1.503 | 2725.59 | 1.706 | 75.05 |
| 4096 | 128 | 16384 | 1.567 | 2613.55 | 1.710 | 74.85 |
| 4096 | 128 | 20480 | 1.629 | 2514.75 | 1.719 | 74.46 |
| 4096 | 128 | 24576 | 1.683 | 2433.20 | 1.747 | 73.25 |
| 4096 | 128 | 28672 | 1.734 | 2361.91 | 1.752 | 73.07 |
| 4096 | 128 | 32768 | 1.789 | 2289.76 | 1.779 | 71.95 |
| 4096 | 128 | 36864 | 1.850 | 2214.35 | 1.783 | 71.77 |
| 4096 | 128 | 40960 | 1.911 | 2143.73 | 1.794 | 71.36 |
| 4096 | 128 | 45056 | 1.955 | 2094.78 | 1.816 | 70.50 |
| 4096 | 128 | 49152 | 2.014 | 2033.58 | 1.819 | 70.37 |
| 4096 | 128 | 53248 | 2.058 | 1990.10 | 1.828 | 70.04 |
| 4096 | 128 | 57344 | 2.110 | 1941.10 | 1.850 | 69.20 |
| 4096 | 128 | 61440 | 2.174 | 1883.73 | 1.855 | 68.99 |
| 4096 | 128 | 65536 | 2.209 | 1853.83 | 1.880 | 68.08 |
| 4096 | 128 | 69632 | 2.275 | 1800.64 | 1.887 | 67.84 |
| 4096 | 128 | 73728 | 2.322 | 1764.14 | 1.889 | 67.76 |
| 4096 | 128 | 77824 | 2.374 | 1725.17 | 1.914 | 66.88 |
| 4096 | 128 | 81920 | 2.432 | 1684.17 | 1.920 | 66.66 |
| 4096 | 128 | 86016 | 2.483 | 1649.71 | 1.939 | 66.02 |
| 4096 | 128 | 90112 | 2.540 | 1612.81 | 1.957 | 65.39 |
| 4096 | 128 | 94208 | 2.581 | 1587.01 | 1.963 | 65.21 |
| 4096 | 128 | 98304 | 2.646 | 1548.28 | 1.984 | 64.51 |
| 4096 | 128 | 102400 | 2.694 | 1520.29 | 1.993 | 64.21 |
| 4096 | 128 | 106496 | 2.743 | 1493.35 | 2.002 | 63.92 |
| 4096 | 128 | 110592 | 2.796 | 1465.17 | 2.023 | 63.27 |
| 4096 | 128 | 114688 | 2.859 | 1432.56 | 2.027 | 63.15 |
| 4096 | 128 | 118784 | 2.897 | 1413.94 | 2.053 | 62.35 |
| 4096 | 128 | 122880 | 2.964 | 1381.86 | 2.058 | 62.20 |
| 4096 | 128 | 126976 | 3.006 | 1362.45 | 2.062 | 62.06 |
| 4096 | 128 | 131072 | 3.049 | 1343.55 | 2.086 | 61.37 |

@ikawrakow
Owner Author

ikawrakow commented Mar 9, 2026

Yes, I updated the description with a third caveat: reading/writing of the cached recurrent state is not yet implemented, so it crashes.

@ubergarm
Contributor

ubergarm commented Mar 9, 2026

Hrm, when running llama-server it is immediately segfaulting after I send a prompt... Is there a way to explicitly disable reading/writing the cached state (I tried --cache-ram 0)?

Here is the gdb output after recompiling in debug mode:

👈 Details
gdb -q --args \
./build/bin/llama-server \
  --model "$model" \
  --alias Qwen3.5-122B-A10B \
  -c 262144 \
  -sm graph \
  -ngl 99 \
  -ub 4096 -b 4096 \
  --parallel 1 \
  --threads 1 \
  --host 127.0.0.1 \
  --port 8080 \
  --jinja \
  --no-mmap \
  --cache-ram 0

(gdb) set print thread-events off
(gdb) run
...

No tensors in buffer type CUDA0
llm_load_tensors: offloading 48 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 49/49 layers to GPU
llm_load_tensors:  CUDA_Host buffer size =   602.46 MiB
llm_load_tensors:      CUDA1 buffer size =   602.47 MiB
llm_load_tensors: CUDA_Split buffer size = 61773.38 MiB
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 1->0
~ggml_backend_cuda_context: have 0 graphs
....................................................................................................
llama_init_from_model: n_ctx         = 262144
llama_init_from_model: n_batch       = 4096
llama_init_from_model: n_ubatch      = 4096
llama_init_from_model: flash_attn    = 1
llama_init_from_model: attn_max_b    = 0
llama_init_from_model: fused_moe     = 1
llama_init_from_model: grouped er    = 0
llama_init_from_model: fused_up_gate = 1
llama_init_from_model: fused_mmad    = 1
llama_init_from_model: rope_cache    = 0
llama_init_from_model: graph_reuse   = 1
llama_init_from_model: k_cache_hadam = 0
llama_init_from_model: split_mode_graph_scheduling = 0
llama_init_from_model: reduce_type   = f16
llama_init_from_model: sched_async   = 0
llama_init_from_model: ser           = -1, 0
llama_init_from_model: freq_base     = 10000000.0
llama_init_from_model: freq_scale    = 1
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 0->1
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 1->0
=== Created recurrent cache cache_s_l0 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l1 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l2 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l4 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l5 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l6 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l8 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l9 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l10 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l12 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l13 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l14 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l16 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l17 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l18 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l20 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l21 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l22 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l24 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l25 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l26 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l28 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l29 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l30 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l32 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l33 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l34 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l36 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l37 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l38 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l40 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l41 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l42 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l44 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l45 as 1085440 x 1 x 1 x 1
=== Created recurrent cache cache_s_l46 as 1085440 x 1 x 1 x 1
llama_kv_cache_init: CUDA_Split KV buffer size =  6293.07 MiB
llama_kv_cache_init: KV cache size per device:
    Device 0:  3146.53 MiB
    Device 1:  3146.53 MiB
llama_init_from_model: KV self size  = 6144.00 MiB, K (f16): 3072.00 MiB, V (f16): 3072.00 MiB
llama_init_from_model:  CUDA_Host  output buffer size =     0.95 MiB
ggml_gallocr_reserve_n: reallocating CUDA0 buffer from size 0.00 MiB to 2592.02 MiB
ggml_gallocr_reserve_n: reallocating CUDA1 buffer from size 0.00 MiB to 3976.00 MiB
ggml_gallocr_reserve_n: reallocating CUDA_Host buffer from size 0.00 MiB to 2096.11 MiB
llama_init_from_model:      CUDA0 compute buffer size =  2592.02 MiB
llama_init_from_model:      CUDA1 compute buffer size =  3976.00 MiB
llama_init_from_model:  CUDA_Host compute buffer size =  2096.11 MiB
llama_init_from_model: graph nodes  = 6822
llama_init_from_model: graph splits = 289
llama_init_from_model: enabling only_active_experts scheduling
ggml_backend_sched_alloc_splits: failed to allocate graph, reserving (backend_ids_changed = 1)
INFO [                    init] initializing slots | tid="140737350545408" timestamp=1773068930 n_slots=1
INFO [                    init] new slot | tid="140737350545408" timestamp=1773068930 id_slot=0 n_ctx_slot=262144
srv          init: Exclude reasoning tokens when selecting slot based on similarity: start: <think>, end: </think>
use `--reasoning-tokens none` to disable.
ggml_backend_sched_alloc_splits: failed to allocate graph, reserving (backend_ids_changed = 1)
update_cuda_graph_executable: CUDA graph update failed
no implementations specified for speculative decoding
slot         init: id  0 | task -1 | speculative decoding context not initialized
prompt cache is disabled - use `--cache-ram N` to enable it
INFO [                    main] model loaded | tid="140737350545408" timestamp=1773068930
INFO [                    main] chat template | tid="140737350545408" timestamp=1773068930 chat_template="..."
INFO [                    main] HTTP server listening | tid="140737350545408" timestamp=1773068930 n_threads_http="47" port="8080" hostname="127.0.0.1"
INFO [              slots_idle] all slots are idle | tid="140737350545408" timestamp=1773068930
======== Prompt cache: cache size: 0, n_keep: 0, n_discarded_prompt: 0, cache_ram_n_min: 0, f_keep: 0.00, cache_ram_similarity: 0.50
Grammar: bash-any-arg ::= func-bash-kv-command | func-bash-kv-timeout | func-bash-kv-workdir | func-bash-kv-description
bash-any-arg-with-end ::= bash-any-arg bash-last-arg-end
bash-args-relaxed ::= ( bash-any-arg-with-end )*
bash-call ::= "\n<tool_call>\n<function=" "bash" ">\n" bash-args-relaxed
bash-last-arg-end ::= "\n</parameter>\n"
boolean ::= ("true" | "false") space
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
decimal-part ::= [0-9]{1,16}
edit-any-arg ::= func-edit-kv-filePath | func-edit-kv-oldString | func-edit-kv-newString | func-edit-kv-replaceAll
edit-any-arg-with-end ::= edit-any-arg edit-last-arg-end
edit-args-relaxed ::= ( edit-any-arg-with-end )*
edit-call ::= "\n<tool_call>\n<function=" "edit" ">\n" edit-args-relaxed
...
write-last-arg-end ::= "\n</parameter>\n"

Grammar lazy: true
Chat format: Qwen3 Coder
INFO [   launch_slot_with_task] slot is processing task | tid="140737350545408" timestamp=1773069254 id_slot=0 id_task=0
======== Cache: cache_size = 0, n_past0 =  0, n_past1 =  0, n_past_prompt1 = 0,  n_past2 =  0, n_past_prompt2 =  0
INFO [    batch_pending_prompt] kv cache rm [p0, end) | tid="140737350545408" timestamp=1773069254 id_slot=0 id_task=0 p0=0
ggml_backend_sched_alloc_splits: failed to allocate graph, reserving (backend_ids_changed = 1)
update_cuda_graph_executable: CUDA graph update failed

Thread 1 "llama-server" received signal SIGSEGV, Segmentation fault.
0x00007ffff7916acb in llama_data_write_buffer::get_tensor_data_split (ptr=0x7feabeaef4b0 "", tensor=0x5555594c0420, kv=0x0,
    aux_buffer=std::vector of length 0, capacity 0, offset=0, size=4341760) at /home/w/projects/ik_llama.cpp/src/llama.cpp:6665
6665            auto ne = kv->ne[1];
(gdb) bt
#0  0x00007ffff7916acb in llama_data_write_buffer::get_tensor_data_split (ptr=0x7feabeaef4b0 "", tensor=0x5555594c0420, kv=0x0,
    aux_buffer=std::vector of length 0, capacity 0, offset=0, size=4341760) at /home/w/projects/ik_llama.cpp/src/llama.cpp:6665
#1  0x00007ffff79169e9 in llama_data_write_buffer::get_tensor_data_split (this=0x7fffffff91c0, tensor=0x5555594c0420, offset=0, size=4341760, il=0)
    at /home/w/projects/ik_llama.cpp/src/llama.cpp:6660
#2  0x00007ffff7916716 in llama_data_write_buffer::write_tensor_data (this=0x7fffffff91c0, tensor=0x5555594c0420, offset=0, size=4341760, il=0)
    at /home/w/projects/ik_llama.cpp/src/llama.cpp:6644
#3  0x00007ffff7913427 in llama_data_write::write_kv_cache_data (this=0x7fffffff91c0, ctx=0x55555d0d5340,
    cell_ranges=std::vector of length 1, capacity 1 = {...}, seq_id=0, flags=1) at /home/w/projects/ik_llama.cpp/src/llama.cpp:6067
#4  0x00007ffff7913764 in llama_data_write::write_kv_cache (this=0x7fffffff91c0, ctx=0x55555d0d5340, seq_id=0, flags=1)
    at /home/w/projects/ik_llama.cpp/src/llama.cpp:6109
#5  0x00007ffff78fec35 in llama_state_seq_get_data_internal (ctx=0x55555d0d5340, data_ctx=..., seq_id=0, flags=1)
    at /home/w/projects/ik_llama.cpp/src/llama.cpp:6952
#6  0x00007ffff78fed6c in llama_state_seq_get_data (ctx=0x55555d0d5340, dst=0x7feabeae7010 "", size=156338064, seq_id=0, flags=1)
    at /home/w/projects/ik_llama.cpp/src/llama.cpp:6965
#7  0x000055555576bac3 in server_context::create_checkpoint (this=0x7fffffffca20, slot=...)
    at /home/w/projects/ik_llama.cpp/examples/server/server-context.cpp:2754
#8  0x000055555576ae3e in server_context::create_checkpoint_at_interval (this=0x7fffffffca20, slot=..., params_base=...)
    at /home/w/projects/ik_llama.cpp/examples/server/server-context.cpp:2652
#9  0x0000555555773450 in server_context::process_batch_tokens (this=0x7fffffffca20, n_batch=@0x7fffffff9664: 4096)
    at /home/w/projects/ik_llama.cpp/examples/server/server-context.cpp:3432
#10 0x000055555577451e in server_context::update_slots (this=0x7fffffffca20) at /home/w/projects/ik_llama.cpp/examples/server/server-context.cpp:3582
#11 0x00005555556af0f9 in std::__invoke_impl<void, void (server_context::*&)(), server_context*&> (
    __f=@0x555568e717c0: (void (server_context::*)(server_context * const)) 0x5555557740e2 <server_context::update_slots()>,
    __t=@0x555568e717d0: 0x7fffffffca20) at /usr/include/c++/13/bits/invoke.h:74
#12 0x00005555556a85b7 in std::__invoke<void (server_context::*&)(), server_context*&> (
    __fn=@0x555568e717c0: (void (server_context::*)(server_context * const)) 0x5555557740e2 <server_context::update_slots()>)
    at /usr/include/c++/13/bits/invoke.h:96
#13 0x000055555569dd89 in std::_Bind<void (server_context::*(server_context*))()>::__call<void, , 0ul>(std::tuple<>&&, std::_Index_tuple<0ul>) (
    this=0x555568e717c0, __args=...) at /usr/include/c++/13/functional:506
#14 0x0000555555696095 in std::_Bind<void (server_context::*(server_context*))()>::operator()<, void>() (this=0x555568e717c0)
    at /usr/include/c++/13/functional:591
#15 0x0000555555689298 in std::__invoke_impl<void, std::_Bind<void (server_context::*(server_context*))()>&>(std::__invoke_other, std::_Bind<void (server_context::*(server_context*))()>&) (__f=...) at /usr/include/c++/13/bits/invoke.h:61
#16 0x000055555567d530 in std::__invoke_r<void, std::_Bind<void (server_context::*(server_context*))()>&>(std::_Bind<void (server_context::*(server_context*))()>&) (__fn=...) at /usr/include/c++/13/bits/invoke.h:111
#17 0x0000555555669959 in std::_Function_handler<void (), std::_Bind<void (server_context::*(server_context*))()> >::_M_invoke(std::_Any_data const&)
    (__functor=...) at /usr/include/c++/13/bits/std_function.h:290
--Type <RET> for more, q to quit, c to continue without paging--
#18 0x000055555564c11a in std::function<void ()>::operator()() const (this=0x7fffffffdcc0) at /usr/include/c++/13/bits/std_function.h:591
#19 0x00005555556f94c1 in server_queue::start_loop (this=0x7fffffffdb68) at /home/w/projects/ik_llama.cpp/examples/server/server-queue.cpp:133
#20 0x000055555562c72a in main (argc=27, argv=0x7fffffffdfd8) at /home/w/projects/ik_llama.cpp/examples/server/server.cpp:2139
(gdb) info args
ptr = 0x7feabeaef4b0 ""
tensor = 0x5555594c0420
kv = 0x0
aux_buffer = std::vector of length 0, capacity 0
offset = 0
size = 4341760
(gdb) info locals
ne = 0
full_row_size = 0
first_row = 0
num_rows = 0
extra = 0xec2b22222d21282b
kv_extra = 0xbfbf2e32ed352d24
split_offset = 0
total_size = 0

@ikawrakow
Owner Author

@magikRUKKOLA @ubergarm

I just pushed a change. Does it solve the segfault?

@ubergarm
Contributor

ubergarm commented Mar 9, 2026

Still faulting immediately after sending a short test prompt from the web UI.

ik_llama.cpp/ggml/src/ggml-backend.cpp:262: GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds") failed

gdb output below:

👈 Details
INFO [                    main] HTTP server listening | tid="140737350545408" timestamp=1773070501 n_threads_http="47" port="8080" hostname="127.0.0.1"
INFO [              slots_idle] all slots are idle | tid="140737350545408" timestamp=1773070501
INFO [      log_server_request] request | tid="140733403643904" timestamp=1773070535 remote_addr="127.0.0.1" remote_port=46624 status=200 method="GET" path="/" params={}
INFO [      log_server_request] request | tid="140733403643904" timestamp=1773070535 remote_addr="127.0.0.1" remote_port=46624 status=200 method="GET" path="/v1/props" params={}
INFO [      log_server_request] request | tid="140733403643904" timestamp=1773070539 remote_addr="127.0.0.1" remote_port=46624 status=200 method="GET" path="/v1/props" params={}
INFO [      log_server_request] request | tid="140733403643904" timestamp=1773070539 remote_addr="127.0.0.1" remote_port=46624 status=200 method="GET" path="/v1/props" params={}
======== Prompt cache: cache size: 0, n_keep: 0, n_discarded_prompt: 0, cache_ram_n_min: 0, f_keep: 0.00, cache_ram_similarity: 0.50
Grammar:
Grammar lazy: false
Chat format: Qwen3 Coder
INFO [   launch_slot_with_task] slot is processing task | tid="140737350545408" timestamp=1773070540 id_slot=0 id_task=0
======== Cache: cache_size = 0, n_past0 =  0, n_past1 =  0, n_past_prompt1 = 0,  n_past2 =  0, n_past_prompt2 =  0
INFO [    batch_pending_prompt] kv cache rm [p0, end) | tid="140737350545408" timestamp=1773070540 id_slot=0 id_task=0 p0=0
ggml_backend_sched_alloc_splits: failed to allocate graph, reserving (backend_ids_changed = 1)
update_cuda_graph_executable: CUDA graph update failed
update_cuda_graph_executable: CUDA graph update failed
...
update_cuda_graph_executable: CUDA graph update failed
update_cuda_graph_executable: CUDA graph update failed
/home/w/projects/ik_llama.cpp/ggml/src/ggml-backend.cpp:262: GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds") failed
[Detaching after fork from child process 1087374]
Could not attach to process.  If your uid matches the uid of the target
process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try
again as the root user.  For more details, see /etc/sysctl.d/10-ptrace.conf
warning: process 1087250 is already traced by process 1087240
ptrace: Operation not permitted.
No stack.
The program is not being run.

Thread 1 "llama-server" received signal SIGABRT, Aborted.
Download failed: Invalid argument.  Continuing without source file ./nptl/./nptl/pthread_kill.c.
__pthread_kill_implementation (no_tid=0, signo=6, threadid=<optimized out>) at ./nptl/pthread_kill.c:44
warning: 44     ./nptl/pthread_kill.c: No such file or directory
(gdb) bt
#0  __pthread_kill_implementation (no_tid=0, signo=6, threadid=<optimized out>) at ./nptl/pthread_kill.c:44
#1  __pthread_kill_internal (signo=6, threadid=<optimized out>) at ./nptl/pthread_kill.c:78
#2  __GI___pthread_kill (threadid=<optimized out>, signo=signo@entry=6) at ./nptl/pthread_kill.c:89
#3  0x00007fffc264527e in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007fffc26288ff in __GI_abort () at ./stdlib/abort.c:79
#5  0x00007fffc2ea906c in ggml_abort (file=0x7fffc4e23b80 "/home/w/projects/ik_llama.cpp/ggml/src/ggml-backend.cpp", line=262,
    fmt=0x7fffc4e23b65 "GGML_ASSERT(%s) failed") at /home/w/projects/ik_llama.cpp/ggml/src/ggml.c:264
#6  0x00007fffc2f150ed in ggml_backend_tensor_get (tensor=0x5555594c0590, data=0x55556cc87c40, offset=0, size=4341760)
    at /home/w/projects/ik_llama.cpp/ggml/src/ggml-backend.cpp:262
#7  0x00007ffff7917200 in llama_data_write_buffer::get_tensor_data_split (ptr=0x7fead32c4588 "", tensor=0x5555594c0420,
    aux_buffer=std::vector of length 4341760, capacity 4341760 = {...}, offset=0, size=4341760) at /home/w/projects/ik_llama.cpp/src/llama.cpp:6723
#8  0x00007ffff7916aa0 in llama_data_write_buffer::get_tensor_data_split (this=0x7fffffff9200, tensor=0x5555594c0420, offset=0, size=4341760, il=0)
    at /home/w/projects/ik_llama.cpp/src/llama.cpp:6662
#9  0x00007ffff79167e6 in llama_data_write_buffer::write_tensor_data (this=0x7fffffff9200, tensor=0x5555594c0420, offset=0, size=4341760, il=0)
    at /home/w/projects/ik_llama.cpp/src/llama.cpp:6646
#10 0x00007ffff7913447 in llama_data_write::write_kv_cache_data (this=0x7fffffff9200, ctx=0x55555d0d5340,
    cell_ranges=std::vector of length 1, capacity 1 = {...}, seq_id=0, flags=1) at /home/w/projects/ik_llama.cpp/src/llama.cpp:6067
#11 0x00007ffff7913784 in llama_data_write::write_kv_cache (this=0x7fffffff9200, ctx=0x55555d0d5340, seq_id=0, flags=1)
    at /home/w/projects/ik_llama.cpp/src/llama.cpp:6109
#12 0x00007ffff78fec55 in llama_state_seq_get_data_internal (ctx=0x55555d0d5340, data_ctx=..., seq_id=0, flags=1)
    at /home/w/projects/ik_llama.cpp/src/llama.cpp:6990
#13 0x00007ffff78fed8c in llama_state_seq_get_data (ctx=0x55555d0d5340, dst=0x7fead32c4010 "\033", size=156305512, seq_id=0, flags=1)
    at /home/w/projects/ik_llama.cpp/src/llama.cpp:7003
#14 0x000055555576bac3 in server_context::create_checkpoint (this=0x7fffffffca20, slot=...)
    at /home/w/projects/ik_llama.cpp/examples/server/server-context.cpp:2754
#15 0x0000555555773427 in server_context::process_batch_tokens (this=0x7fffffffca20, n_batch=@0x7fffffff9664: 4096)
    at /home/w/projects/ik_llama.cpp/examples/server/server-context.cpp:3430
#16 0x000055555577451e in server_context::update_slots (this=0x7fffffffca20) at /home/w/projects/ik_llama.cpp/examples/server/server-context.cpp:3582
#17 0x00005555556af0f9 in std::__invoke_impl<void, void (server_context::*&)(), server_context*&> (
    __f=@0x555568e717c0: (void (server_context::*)(struct server_context * const)) 0x5555557740e2 <server_context::update_slots()>,
    __t=@0x555568e717d0: 0x7fffffffca20) at /usr/include/c++/13/bits/invoke.h:74
#18 0x00005555556a85b7 in std::__invoke<void (server_context::*&)(), server_context*&> (
    __fn=@0x555568e717c0: (void (server_context::*)(struct server_context * const)) 0x5555557740e2 <server_context::update_slots()>)
    at /usr/include/c++/13/bits/invoke.h:96
#19 0x000055555569dd89 in std::_Bind<void (server_context::*(server_context*))()>::__call<void, , 0ul>(std::tuple<>&&, std::_Index_tuple<0ul>) (
    this=0x555568e717c0, __args=...) at /usr/include/c++/13/functional:506
#20 0x0000555555696095 in std::_Bind<void (server_context::*(server_context*))()>::operator()<, void>() (this=0x555568e717c0)
--Type <RET> for more, q to quit, c to continue without paging--
    at /usr/include/c++/13/functional:591
#21 0x0000555555689298 in std::__invoke_impl<void, std::_Bind<void (server_context::*(server_context*))()>&>(std::__invoke_other, std::_Bind<void (server_context::*(server_context*))()>&) (__f=...) at /usr/include/c++/13/bits/invoke.h:61
#22 0x000055555567d530 in std::__invoke_r<void, std::_Bind<void (server_context::*(server_context*))()>&>(std::_Bind<void (server_context::*(server_context*))()>&) (__fn=...) at /usr/include/c++/13/bits/invoke.h:111
#23 0x0000555555669959 in std::_Function_handler<void (), std::_Bind<void (server_context::*(server_context*))()> >::_M_invoke(std::_Any_data const&)
    (__functor=...) at /usr/include/c++/13/bits/std_function.h:290
#24 0x000055555564c11a in std::function<void ()>::operator()() const (this=0x7fffffffdcc0) at /usr/include/c++/13/bits/std_function.h:591
#25 0x00005555556f94c1 in server_queue::start_loop (this=0x7fffffffdb68) at /home/w/projects/ik_llama.cpp/examples/server/server-queue.cpp:133
#26 0x000055555562c72a in main (argc=27, argv=0x7fffffffdfd8) at /home/w/projects/ik_llama.cpp/examples/server/server.cpp:2139
(gdb) info args
no_tid = 0
signo = 6
threadid = <optimized out>
(gdb) info locals
tid = <optimized out>
ret = 0
pd = <optimized out>
old_mask = {__val = {4702111234474983745}}
ret = <optimized out>
pd = <optimized out>
old_mask = <optimized out>
ret = <optimized out>
tid = <optimized out>
ret = <optimized out>
resultvar = <optimized out>
resultvar = <optimized out>
__arg3 = <optimized out>
__arg2 = <optimized out>
__arg1 = <optimized out>
_a3 = <optimized out>
_a2 = <optimized out>
_a1 = <optimized out>
__futex = <optimized out>
resultvar = <optimized out>
__arg3 = <optimized out>
__arg2 = <optimized out>
__arg1 = <optimized out>
_a3 = <optimized out>
_a2 = <optimized out>
_a1 = <optimized out>
__futex = <optimized out>
__private = <optimized out>
__oldval = <optimized out>

(gdb) frame 6
#6  0x00007fffc2f150ed in ggml_backend_tensor_get (tensor=0x5555594c0590, data=0x55556cc87c40, offset=0, size=4341760)
    at /home/w/projects/ik_llama.cpp/ggml/src/ggml-backend.cpp:262
262         GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");

(gdb) info frame
Stack level 6, frame at 0x7fffffff8d50:
 rip = 0x7fffc2f150ed in ggml_backend_tensor_get (/home/w/projects/ik_llama.cpp/ggml/src/ggml-backend.cpp:262); saved rip = 0x7ffff7917200
 called by frame at 0x7fffffff8e10, caller of frame at 0x7fffffff8d00
 source language c++.
 Arglist at 0x7fffffff8d40, args: tensor=0x5555594c0590, data=0x55556cc87c40, offset=0, size=4341760
 Locals at 0x7fffffff8d40, Previous frame's sp is 0x7fffffff8d50
 Saved registers:
  rbx at 0x7fffffff8d38, rbp at 0x7fffffff8d40, rip at 0x7fffffff8d48

@magikRUKKOLA

@ikawrakow

                  
/opt/ik_llama.cpp/ik_llama.cpp/ggml/src/ggml-backend.cpp:262: GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds") failed

@hksdpc255
Contributor

Even setups without P2P GPU access will benefit from this. Nice work!

ubergarm/Qwen3.5-27B-GGUF/Qwen3.5-27B-smol-IQ4_NL.gguf, no P2P, 2x3090

without PR, -sm graph

Details
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|----|----|------|--------|----------|--------|----------|
| 2048 | 512 | 0 | 1.483 | 1381.14 | 11.781 | 43.46 |
| 2048 | 512 | 2048 | 1.463 | 1399.46 | 11.757 | 43.55 |
| 2048 | 512 | 4096 | 1.462 | 1400.62 | 11.927 | 42.93 |
| 2048 | 512 | 6144 | 1.490 | 1374.86 | 12.070 | 42.42 |
| 2048 | 512 | 8192 | 1.515 | 1352.18 | 12.301 | 41.62 |
| 2048 | 512 | 10240 | 1.549 | 1322.51 | 12.526 | 40.87 |
| 2048 | 512 | 12288 | 1.558 | 1314.59 | 12.627 | 40.55 |
| 2048 | 512 | 14336 | 1.575 | 1300.46 | 12.777 | 40.07 |
| 2048 | 512 | 16384 | 1.596 | 1283.43 | 13.057 | 39.21 |
| 2048 | 512 | 18432 | 1.608 | 1273.46 | 13.158 | 38.91 |
| 2048 | 512 | 20480 | 1.627 | 1258.62 | 13.234 | 38.69 |
| 2048 | 512 | 22528 | 1.648 | 1242.54 | 13.314 | 38.45 |
| 2048 | 512 | 24576 | 1.650 | 1240.85 | 13.425 | 38.14 |
| 2048 | 512 | 26624 | 1.692 | 1210.69 | 13.509 | 37.90 |
| 2048 | 512 | 28672 | 1.684 | 1215.95 | 13.621 | 37.59 |
| 2048 | 512 | 30720 | 1.708 | 1199.17 | 13.712 | 37.34 |
| 2048 | 512 | 32768 | 1.731 | 1182.89 | 13.888 | 36.87 |
| 2048 | 512 | 34816 | 1.733 | 1181.70 | 14.035 | 36.48 |
| 2048 | 512 | 36864 | 1.767 | 1159.29 | 14.189 | 36.08 |
| 2048 | 512 | 38912 | 1.755 | 1167.07 | 14.265 | 35.89 |
| 2048 | 512 | 40960 | 1.794 | 1141.69 | 14.360 | 35.65 |
| 2048 | 512 | 43008 | 1.801 | 1137.38 | 14.445 | 35.44 |
| 2048 | 512 | 45056 | 1.836 | 1115.40 | 14.517 | 35.27 |
| 2048 | 512 | 47104 | 1.837 | 1115.01 | 14.669 | 34.90 |
| 2048 | 512 | 49152 | 1.867 | 1096.90 | 14.869 | 34.43 |
| 2048 | 512 | 51200 | 1.857 | 1102.78 | 14.945 | 34.26 |
| 2048 | 512 | 53248 | 1.893 | 1082.11 | 15.090 | 33.93 |
| 2048 | 512 | 55296 | 1.888 | 1084.97 | 15.154 | 33.79 |
| 2048 | 512 | 57344 | 1.906 | 1074.43 | 15.263 | 33.55 |
| 2048 | 512 | 59392 | 1.930 | 1061.20 | 15.368 | 33.32 |
| 2048 | 512 | 61440 | 1.971 | 1039.16 | 15.468 | 33.10 |
| 2048 | 512 | 63488 | 1.971 | 1039.16 | 15.566 | 32.89 |
| 2048 | 512 | 65536 | 1.988 | 1030.08 | 15.744 | 32.52 |
| 2048 | 512 | 67584 | 2.012 | 1017.97 | 15.839 | 32.32 |
| 2048 | 512 | 69632 | 2.019 | 1014.19 | 15.963 | 32.07 |

without PR, -sm layer

Details
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|----|----|------|--------|----------|--------|----------|
| 2048 | 512 | 0 | 1.463 | 1400.22 | 11.952 | 42.84 |
| 2048 | 512 | 2048 | 1.453 | 1409.86 | 12.054 | 42.48 |
| 2048 | 512 | 4096 | 1.494 | 1371.24 | 12.351 | 41.46 |
| 2048 | 512 | 6144 | 1.536 | 1333.19 | 12.658 | 40.45 |
| 2048 | 512 | 8192 | 1.558 | 1314.77 | 12.988 | 39.42 |
| 2048 | 512 | 10240 | 1.584 | 1292.79 | 13.154 | 38.92 |
| 2048 | 512 | 12288 | 1.610 | 1271.88 | 13.311 | 38.46 |
| 2048 | 512 | 14336 | 1.637 | 1251.24 | 13.516 | 37.88 |
| 2048 | 512 | 16384 | 1.653 | 1239.09 | 13.787 | 37.14 |
| 2048 | 512 | 18432 | 1.691 | 1210.77 | 13.984 | 36.61 |
| 2048 | 512 | 20480 | 1.708 | 1199.20 | 14.161 | 36.16 |
| 2048 | 512 | 22528 | 1.732 | 1182.30 | 14.354 | 35.67 |
| 2048 | 512 | 24576 | 1.761 | 1162.78 | 14.629 | 35.00 |
| 2048 | 512 | 26624 | 1.790 | 1144.38 | 14.811 | 34.57 |
| 2048 | 512 | 28672 | 1.821 | 1124.35 | 15.013 | 34.10 |
| 2048 | 512 | 30720 | 1.844 | 1110.77 | 15.203 | 33.68 |
| 2048 | 512 | 32768 | 1.873 | 1093.68 | 15.488 | 33.06 |
| 2048 | 512 | 34816 | 1.903 | 1076.42 | 15.707 | 32.60 |
| 2048 | 512 | 36864 | 1.932 | 1060.29 | 15.858 | 32.29 |
| 2048 | 512 | 38912 | 1.964 | 1042.89 | 16.014 | 31.97 |
| 2048 | 512 | 40960 | 1.991 | 1028.54 | 16.325 | 31.36 |
| 2048 | 512 | 43008 | 2.027 | 1010.26 | 16.491 | 31.05 |
| 2048 | 512 | 45056 | 2.049 | 999.33 | 16.658 | 30.74 |
| 2048 | 512 | 47104 | 2.076 | 986.28 | 16.852 | 30.38 |

without PR, 1*3090

Details
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|----|----|------|--------|----------|--------|----------|
| 2048 | 512 | 0 | 1.457 | 1405.91 | 12.074 | 42.41 |
| 2048 | 512 | 2048 | 1.436 | 1426.65 | 12.176 | 42.05 |
| 2048 | 512 | 4096 | 1.456 | 1406.20 | 12.404 | 41.28 |
| 2048 | 512 | 6144 | 1.477 | 1386.53 | 12.637 | 40.52 |
| 2048 | 512 | 8192 | 1.503 | 1362.85 | 12.938 | 39.57 |
| 2048 | 512 | 10240 | 1.528 | 1339.99 | 13.117 | 39.03 |
| 2048 | 512 | 12288 | 1.556 | 1316.16 | 13.292 | 38.52 |
| 2048 | 512 | 14336 | 1.585 | 1292.32 | 13.487 | 37.96 |
| 2048 | 512 | 16384 | 1.605 | 1275.62 | 13.759 | 37.21 |
| 2048 | 512 | 18432 | 1.633 | 1253.88 | 13.948 | 36.71 |
| 2048 | 512 | 20480 | 1.658 | 1235.50 | 14.113 | 36.28 |
| 2048 | 512 | 22528 | 1.683 | 1217.02 | 14.299 | 35.81 |
| 2048 | 512 | 24576 | 1.711 | 1197.22 | 14.580 | 35.12 |
| 2048 | 512 | 26624 | 1.739 | 1177.93 | 14.759 | 34.69 |
| 2048 | 512 | 28672 | 1.760 | 1163.36 | 14.926 | 34.30 |
| 2048 | 512 | 30720 | 1.785 | 1147.38 | 15.115 | 33.87 |
| 2048 | 512 | 32768 | 1.814 | 1129.28 | 15.375 | 33.30 |
| 2048 | 512 | 34816 | 1.839 | 1113.82 | 15.589 | 32.84 |
| 2048 | 512 | 36864 | 1.868 | 1096.45 | 15.760 | 32.49 |
| 2048 | 512 | 38912 | 1.900 | 1077.83 | 15.933 | 32.13 |
| 2048 | 512 | 40960 | 1.933 | 1059.24 | 16.208 | 31.59 |
| 2048 | 512 | 43008 | 1.943 | 1054.23 | 16.407 | 31.21 |
| 2048 | 512 | 45056 | 1.972 | 1038.32 | 16.549 | 30.94 |
| 2048 | 512 | 47104 | 2.000 | 1023.83 | 16.732 | 30.60 |

with PR, -sm graph

Details
| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|----|----|------|--------|----------|--------|----------|
| 2048 | 512 | 0 | 1.119 | 1830.57 | 8.745 | 58.55 |
| 2048 | 512 | 2048 | 1.096 | 1868.54 | 8.822 | 58.04 |
| 2048 | 512 | 4096 | 1.086 | 1885.90 | 9.028 | 56.71 |
| 2048 | 512 | 6144 | 1.204 | 1701.15 | 9.309 | 55.00 |
| 2048 | 512 | 8192 | 1.220 | 1679.02 | 9.623 | 53.20 |
| 2048 | 512 | 10240 | 1.331 | 1538.94 | 9.778 | 52.36 |
| 2048 | 512 | 12288 | 1.273 | 1609.22 | 9.932 | 51.55 |
| 2048 | 512 | 14336 | 1.278 | 1602.06 | 10.059 | 50.90 |
| 2048 | 512 | 16384 | 1.306 | 1568.59 | 10.295 | 49.73 |
| 2048 | 512 | 18432 | 1.310 | 1562.84 | 10.408 | 49.19 |
| 2048 | 512 | 20480 | 1.345 | 1523.22 | 10.510 | 48.72 |
| 2048 | 512 | 22528 | 1.383 | 1480.78 | 10.587 | 48.36 |
| 2048 | 512 | 24576 | 1.401 | 1462.05 | 10.667 | 48.00 |
| 2048 | 512 | 26624 | 1.391 | 1472.61 | 10.761 | 47.58 |
| 2048 | 512 | 28672 | 1.425 | 1436.90 | 10.877 | 47.07 |
| 2048 | 512 | 30720 | 1.440 | 1422.57 | 10.975 | 46.65 |
| 2048 | 512 | 32768 | 1.442 | 1419.76 | 11.165 | 45.86 |
| 2048 | 512 | 34816 | 1.461 | 1401.73 | 11.271 | 45.43 |
| 2048 | 512 | 36864 | 1.507 | 1359.13 | 11.408 | 44.88 |
| 2048 | 512 | 38912 | 1.498 | 1367.12 | 11.483 | 44.59 |
| 2048 | 512 | 40960 | 1.526 | 1341.92 | 11.560 | 44.29 |
| 2048 | 512 | 43008 | 1.527 | 1341.22 | 11.750 | 43.57 |
| 2048 | 512 | 45056 | 1.604 | 1276.48 | 11.756 | 43.55 |
| 2048 | 512 | 47104 | 1.596 | 1283.53 | 11.834 | 43.26 |
| 2048 | 512 | 49152 | 1.611 | 1271.44 | 12.098 | 42.32 |
| 2048 | 512 | 51200 | 1.630 | 1256.78 | 12.195 | 41.98 |
| 2048 | 512 | 53248 | 1.610 | 1271.78 | 12.292 | 41.65 |

with PR, -sm layer

Details
PP TG N_KV T_PP s S_PP t/s T_TG s S_TG t/s
2048 512 0 1.474 1389.25 11.927 42.93
2048 512 2048 1.468 1394.84 12.098 42.32
2048 512 4096 1.506 1360.12 12.353 41.45
2048 512 6144 1.528 1340.00 12.608 40.61
2048 512 8192 1.555 1316.75 12.926 39.61
2048 512 10240 1.579 1297.10 13.111 39.05
2048 512 12288 1.604 1276.66 13.277 38.56
2048 512 14336 1.626 1259.63 13.466 38.02
2048 512 16384 1.646 1243.87 13.720 37.32
2048 512 18432 1.677 1220.94 13.915 36.79
2048 512 20480 1.705 1200.87 14.090 36.34
2048 512 22528 1.733 1181.89 14.331 35.73
2048 512 24576 1.756 1166.40 14.603 35.06
2048 512 26624 1.790 1144.33 14.751 34.71
2048 512 28672 1.815 1128.29 14.938 34.28
2048 512 30720 1.847 1108.66 15.132 33.84
2048 512 32768 1.871 1094.55 15.448 33.14
2048 512 34816 1.895 1080.84 15.664 32.69
2048 512 36864 1.925 1063.69 15.830 32.34
2048 512 38912 1.954 1047.93 16.022 31.96
2048 512 40960 1.986 1031.28 16.299 31.41
2048 512 43008 2.011 1018.25 16.432 31.16
2048 512 45056 2.048 1000.14 16.690 30.68
2048 512 47104 2.074 987.56 16.883 30.33
2048 512 49152 2.105 973.14 17.144 29.86
2048 512 51200 2.127 962.92 17.334 29.54
2048 512 53248 2.160 947.97 17.514 29.23
2048 512 55296 2.181 939.20 17.715 28.90
2048 512 57344 2.206 928.18 17.923 28.57
2048 512 59392 2.233 917.17 18.104 28.28
2048 512 61440 2.262 905.51 18.282 28.01
2048 512 63488 2.286 896.04 18.484 27.70

with PR, 1*3090

Details
PP TG N_KV T_PP s S_PP t/s T_TG s S_TG t/s
2048 512 0 1.438 1423.72 12.026 42.57
2048 512 2048 1.437 1424.83 12.150 42.14
2048 512 4096 1.453 1409.97 12.428 41.20
2048 512 6144 1.469 1394.55 12.608 40.61
2048 512 8192 1.499 1366.44 12.901 39.69
2048 512 10240 1.523 1344.53 13.087 39.12
2048 512 12288 1.549 1322.24 13.262 38.61
2048 512 14336 1.580 1296.44 13.459 38.04
2048 512 16384 1.598 1281.41 13.758 37.22
2048 512 18432 1.634 1253.28 13.962 36.67
2048 512 20480 1.659 1234.65 14.138 36.21
2048 512 22528 1.687 1214.16 14.327 35.74
2048 512 24576 1.710 1197.58 14.607 35.05
2048 512 26624 1.737 1179.37 14.791 34.61
2048 512 28672 1.764 1161.17 14.957 34.23
2048 512 30720 1.786 1146.65 15.133 33.83
2048 512 32768 1.814 1129.30 15.400 33.25
2048 512 34816 1.842 1111.66 15.585 32.85
2048 512 36864 1.866 1097.25 15.756 32.50
2048 512 38912 1.900 1077.76 15.933 32.14
2048 512 40960 1.930 1060.99 16.213 31.58
2048 512 43008 1.944 1053.52 16.397 31.23
2048 512 45056 1.973 1038.23 16.566 30.91
2048 512 47104 1.999 1024.51 16.743 30.58
2048 512 49152 2.030 1008.70 17.013 30.10
2048 512 51200 2.057 995.67 17.200 29.77
2048 512 53248 2.086 981.70 17.393 29.44

@ikawrakow
Owner Author

OK, another fix. That should not crash.

@ubergarm
Contributor

ubergarm commented Mar 9, 2026

Qwen3.5-122B-A10B-IQ4_KSS.gguf agrees with you 😅 , and yes it fixes it!

$ git diff
diff --git a/src/llama.cpp b/src/llama.cpp
index d93f8818..84beb7d4 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -6717,7 +6717,7 @@ struct llama_data_write_buffer : llama_data_write {
             auto split = extra->splits[id];
             if (!split) continue;
             GGML_ASSERT(split->type == tensor->type);
-            auto split_row_size = ggml_row_size(tensor->type, tensor->ne[0]);
+            auto split_row_size = ggml_row_size(tensor->type, split->ne[0]);
             auto split_size = split_row_size * num_rows;
             if (split_size > aux_buffer.size()) aux_buffer.resize(split_size);
             ggml_backend_tensor_get(split, aux_buffer.data(), first_row*split_row_size, split_size);
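The one-character fix is easy to miss: a tensor split column-wise across GPUs holds only `ne[0] / n_gpu` elements per row, so the row size used to address into the split buffer must come from the split itself, not from the full tensor. A minimal sketch of the addressing arithmetic (simplified stand-ins, not the actual ggml API; `row_size` here assumes an unquantized f32 type):

```cpp
#include <cassert>
#include <cstddef>

// Simplified stand-in for ggml_row_size() with a 4-byte (f32) element type.
static size_t row_size(size_t ne0) { return ne0 * sizeof(float); }

struct ReadSpan { size_t offset, size; };

// Byte offset and size needed to read rows [first_row, first_row + num_rows)
// from one split of a tensor whose full rows hold ne0_full elements,
// distributed evenly across n_gpu devices.
static ReadSpan split_read_span(size_t ne0_full, size_t n_gpu,
                                size_t first_row, size_t num_rows) {
    size_t ne0_split      = ne0_full / n_gpu;    // what split->ne[0] would be
    size_t split_row_size = row_size(ne0_split); // the corrected expression
    return { first_row * split_row_size, num_rows * split_row_size };
}
```

With the pre-fix expression (the full tensor's `ne[0]`), both offset and size are inflated by a factor of `n_gpu`, so the read runs past the end of the split buffer, which matches the crash reported above.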

@ikawrakow
Owner Author

60 t/s for a 397B model is pretty good! Do we qualify to play in the Big Boys league now (vLLM, sglang) ?

@TomTheWise

Hi, just a question:
"Caveat 2: There is an issue when using vision with split mode graph. Hence, I have disabled split mode graph for now when --mmproj is present in the command line arguments."

Do you mean this issue between graph and vision is just for Qwen3.5, or in general?

Because I tested last week, and again just a few minutes ago with the latest build, and graph even without NCCL was giving nice results, but I was completely unable to get graph and mmproj/vision to work at the same time with ANY model (also tried ministral-3 and gemma3 with their mmproj).
The server starts, but then even when sending just text it fails with "/opt/ik_llama.cpp/src/llama-sampling.cpp:733: GGML_ASSERT(iter != probs.end()) failed", and after shutting down llama-server, nvidia-smi still shows the reserved VRAM even though no process is using it anymore.

Is this the same issue you mentioned, or is it something unrelated?
This is an example with an image, but the failure is nearly exactly the same with just text. Without mmproj set, it works flawlessly:
ministral-3_mmprojError.txt

@ikawrakow
Owner Author

@TomTheWise The inability to use vision (or even just to load an mmproj file) with split mode graph is not related to this PR; it existed before. It is just that, after receiving enough reports of it not working, I decided to disable split mode graph whenever an mmproj file is specified on the command line.

I'm looking into it, but so far cannot see the reason it does not work.

@Nexesenex
Contributor

Nexesenex commented Mar 9, 2026

Broadly, on the Qwen3.5 122B MoE in full offload, I get +20% in PP and +50% in TG vs. the previous graph implementation.
This is superb. Bravo, @ikawrakow!

Note: the +50% TG is partly explained by the P-state kicking up to full frequencies for both the GPU core and the VRAM, thanks to a sufficient load. But such side benefits are part of the overall gains.

@magikRUKKOLA

@ikawrakow

60 t/s for a 397B model is pretty good! Do we qualify to play in the Big Boys league now (vLLM, sglang) ?

Power consumption per GPU: 133W-295W. (AVG: ~195W)

prefill-qwen35

decode-qwen35

@magikRUKKOLA

magikRUKKOLA commented Mar 9, 2026

@ikawrakow

Here is 4xRTX 6000 PRO Blackwells and SGLang with Qwen3.5-397B-A17B-NVFP4 :

https://gist.github.com/catid/87cca824963f17fe7479a0ed26221397#benchmark-results-this-machine

Per-request decode: 67.76 tok/s

Well, yeah, it's about 7 t/s better in decode, but the price is [x5] for their setup. :)

@ikawrakow
Owner Author

@magikRUKKOLA

The RTX 6000 PRO Blackwell is quite a bit faster than the 3090. My guess is at least 2X for prefill and 1.5X for generation. Then there is the fact that overhead for 4 GPUs is much less than overhead for 8 GPUs. I think you should expect at least 90 t/s from ik_llama.cpp on that machine.

@hksdpc255
Contributor

hksdpc255 commented Mar 10, 2026

I have no luck running -sm graph on my 4*3090 system:

CUDA_VISIBLE_DEVICES=4,5,6,7 ./ik_llama.cpp-build-qwen35graph/llama-sweep-bench --split-mode layer --cache-type-k bf16 --cache-type-v bf16 --n-gpu-layers 999 --model Qwen3.5-27B-BF16.gguf --ctx-size 65536
PP TG N_KV T_PP s S_PP t/s T_TG s S_TG t/s
512 128 0 0.571 896.38 8.908 14.37
512 128 512 0.554 924.29 8.880 14.41
CUDA_VISIBLE_DEVICES=4,5,6,7 ./ik_llama.cpp-build-qwen35graph/llama-sweep-bench --split-mode graph --cache-type-k bf16 --cache-type-v bf16 --n-gpu-layers 999 --model Qwen3.5-27B-BF16.gguf --ctx-size 65536
Details
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 4 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: no, VRAM: 24126 MiB
  Device 1: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: no, VRAM: 24126 MiB
  Device 2: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: no, VRAM: 24126 MiB
  Device 3: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: no, VRAM: 24126 MiB
=============================== NCCL main communicator initialized
=============================== NCCL pair communicators for 4 GPUs initialized
CUDA0: using device CUDA0 - 23725 MiB free
CUDA1: using device CUDA1 - 23725 MiB free
CUDA2: using device CUDA2 - 23725 MiB free
CUDA3: using device CUDA3 - 23725 MiB free
llama_model_loader: loaded meta data with 40 key-value pairs and 851 tensors from Qwen3.5-27B-BF16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen35
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                     general.sampling.top_k i32              = 20
llama_model_loader: - kv   3:                     general.sampling.top_p f32              = 0.950000
llama_model_loader: - kv   4:                      general.sampling.temp f32              = 0.600000
llama_model_loader: - kv   5:                               general.name str              = Qwen3.5 27B
llama_model_loader: - kv   6:                           general.basename str              = Qwen3.5
llama_model_loader: - kv   7:                         general.size_label str              = 27B
llama_model_loader: - kv   8:                            general.license str              = apache-2.0
llama_model_loader: - kv   9:                       general.license.link str              = https://huggingface.co/Qwen/Qwen3.5-2...
llama_model_loader: - kv  10:                               general.tags arr[str,1]       = ["image-text-to-text"]
llama_model_loader: - kv  11:                         qwen35.block_count u32              = 64
llama_model_loader: - kv  12:                      qwen35.context_length u32              = 262144
llama_model_loader: - kv  13:                    qwen35.embedding_length u32              = 5120
llama_model_loader: - kv  14:                 qwen35.feed_forward_length u32              = 17408
llama_model_loader: - kv  15:                qwen35.attention.head_count u32              = 24
llama_model_loader: - kv  16:             qwen35.attention.head_count_kv u32              = 4
llama_model_loader: - kv  17:             qwen35.rope.dimension_sections arr[i32,4]       = [11, 11, 10, 0]
llama_model_loader: - kv  18:                      qwen35.rope.freq_base f32              = 10000000.000000
llama_model_loader: - kv  19:    qwen35.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  20:                qwen35.attention.key_length u32              = 256
llama_model_loader: - kv  21:              qwen35.attention.value_length u32              = 256
llama_model_loader: - kv  22:                          general.file_type u32              = 32
llama_model_loader: - kv  23:                     qwen35.ssm.conv_kernel u32              = 4
llama_model_loader: - kv  24:                      qwen35.ssm.state_size u32              = 128
llama_model_loader: - kv  25:                     qwen35.ssm.group_count u32              = 16
llama_model_loader: - kv  26:                  qwen35.ssm.time_step_rank u32              = 48
llama_model_loader: - kv  27:                      qwen35.ssm.inner_size u32              = 6144
llama_model_loader: - kv  28:             qwen35.full_attention_interval u32              = 4
llama_model_loader: - kv  29:                qwen35.rope.dimension_count u32              = 64
llama_model_loader: - kv  30:               general.quantization_version u32              = 2
llama_model_loader: - kv  31:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  32:                         tokenizer.ggml.pre str              = qwen35
llama_model_loader: - kv  33:                      tokenizer.ggml.tokens arr[str,248320]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  34:                  tokenizer.ggml.token_type arr[i32,248320]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  35:                      tokenizer.ggml.merges arr[str,247587]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  36:                tokenizer.ggml.eos_token_id u32              = 248046
llama_model_loader: - kv  37:            tokenizer.ggml.padding_token_id u32              = 248044
llama_model_loader: - kv  38:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  39:                    tokenizer.chat_template str              = {%- set image_count = namespace(value...
llama_model_loader: - type  f32:  353 tensors
llama_model_loader: - type bf16:  498 tensors
load: printing all EOG tokens:
load:   - 248044 ('<|endoftext|>')
load:   - 248046 ('<|im_end|>')
load:   - 248063 ('<|fim_pad|>')
load:   - 248064 ('<|repo_name|>')
load:   - 248065 ('<|file_sep|>')
load: special tokens cache size = 33
load: token to piece cache size = 1.7581 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = qwen35
llm_load_print_meta: n_ctx_train      = 262144
llm_load_print_meta: n_embd           = 5120
llm_load_print_meta: n_layer          = 64
llm_load_print_meta: n_head           = 24
llm_load_print_meta: n_head_kv        = 4
llm_load_print_meta: n_rot            = 64
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_swa_pattern    = 1
llm_load_print_meta: n_embd_head_k    = 256
llm_load_print_meta: n_embd_head_v    = 256
llm_load_print_meta: n_gqa            = 6
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 17408
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 40
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 262144
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: mrope sections   = [11, 11, 10, 0]
llm_load_print_meta: ssm_d_conv       = 4
llm_load_print_meta: ssm_d_inner      = 6144
llm_load_print_meta: ssm_d_state      = 128
llm_load_print_meta: ssm_dt_rank      = 48
llm_load_print_meta: ssm_n_group      = 16
llm_load_print_meta: model type       = 27B
llm_load_print_meta: model ftype      = BF16
llm_load_print_meta: model params     = 26.896 B
llm_load_print_meta: model size       = 50.103 GiB (16.002 BPW) 
llm_load_print_meta: repeating layers = 45.366 GiB (16.002 BPW, 24.353 B parameters)
llm_load_print_meta: general.name     = Qwen3.5 27B
print_info: vocab type       = BPE
print_info: n_vocab          = 248320
print_info: n_merges         = 247587
print_info: BOS token        = 11 ','
print_info: EOS token        = 248046 '<|im_end|>'
print_info: EOT token        = 248046 '<|im_end|>'
print_info: PAD token        = 248044 '<|endoftext|>'
print_info: LF token         = 198 'Ċ'
print_info: FIM PRE token    = 248060 '<|fim_prefix|>'
print_info: FIM SUF token    = 248062 '<|fim_suffix|>'
print_info: FIM MID token    = 248061 '<|fim_middle|>'
print_info: FIM PAD token    = 248063 '<|fim_pad|>'
print_info: FIM REP token    = 248064 '<|repo_name|>'
print_info: FIM SEP token    = 248065 '<|file_sep|>'
print_info: EOG token        = 248044 '<|endoftext|>'
print_info: EOG token        = 248046 '<|im_end|>'
print_info: EOG token        = 248063 '<|fim_pad|>'
print_info: EOG token        = 248064 '<|repo_name|>'
print_info: EOG token        = 248065 '<|file_sep|>'
print_info: max token length = 256
llm_load_tensors: ggml ctx size =   13.69 MiB
================================ max_gpu = 0
Estimated model buffer size per device:
    Device 0:  11615.68 MiB
    Device 1:  11615.68 MiB
    Device 2:  11615.68 MiB
    Device 3:  11615.68 MiB
No tensors in buffer type CUDA0
No tensors in buffer type CUDA1
No tensors in buffer type CUDA2
llm_load_tensors: offloading 64 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 65/65 layers to GPU
llm_load_tensors:        CPU buffer size =  2425.00 MiB
llm_load_tensors: CUDA_Split buffer size = 46464.38 MiB
llm_load_tensors:      CUDA3 buffer size =  2425.02 MiB
.............................................................................................
llama_init_from_model: n_ctx         = 65536
llama_init_from_model: n_batch       = 2048
llama_init_from_model: n_ubatch      = 512
llama_init_from_model: flash_attn    = 1
llama_init_from_model: attn_max_b    = 0
llama_init_from_model: fused_moe     = 1
llama_init_from_model: grouped er    = 0
llama_init_from_model: fused_up_gate = 1
llama_init_from_model: fused_mmad    = 1
llama_init_from_model: rope_cache    = 0
llama_init_from_model: graph_reuse   = 1
llama_init_from_model: k_cache_hadam = 0
llama_init_from_model: split_mode_graph_scheduling = 0
llama_init_from_model: reduce_type   = f16
llama_init_from_model: sched_async   = 0
llama_init_from_model: ser           = -1, 0
llama_init_from_model: freq_base     = 10000000.0
llama_init_from_model: freq_scale    = 1
ggml_new_object: not enough space in the context's memory pool (needed 141680, available 141312)
trap      (core dumped) CUDA_VISIBLE_DEVICES=4,5,6,7 ./ik_llama.cpp-build-qwen35graph/llama-sweep-bench --split-mode graph --cache-type-k bf16 --cache-type-v bf16 --n-gpu-layers 999 --model Qwen3.5-27B-BF16.gguf --ctx-size 65536

Loads and starts running, crashes with illegal memory access in
quantize_mmq_q8_1. This almost always indicates NaNs in the input
to the MoE FFN part.
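The debugging heuristic from the commit message can be checked directly: scan the activations for non-finite values before they are handed to the quantized matmul. A minimal debug helper (illustrative sketch only, not code from this PR):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Debug helper: true if any element of an activation buffer is NaN or Inf.
// Per the commit message above, an illegal memory access in
// quantize_mmq_q8_1 almost always traces back to such values in its input.
static bool has_non_finite(const float * data, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        if (!std::isfinite(data[i])) return true;
    }
    return false;
}
```

Calling this on the MoE FFN input right before quantization narrows the search to whichever upstream op first produced the bad values.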
@ikawrakow ikawrakow force-pushed the ik/sm_graph_delta_net branch from 2405855 to f09d421 Compare March 10, 2026 06:04
@ikawrakow
Owner Author

@hksdpc255 The latest commit should solve your issue of not being able to run Qwen3.5 with 4 GPUs.

@hksdpc255
Contributor

@ikawrakow Thank you. I'm building the new branch

@hksdpc255
Contributor

Still crashes:

Details
$ CUDA_VISIBLE_DEVICES=4,5,6,7 ./ik_llama.cpp-build-qwen35graph/llama-sweep-bench --split-mode graph --cache-type-k bf16 --cache-type-v bf16 --n-gpu-layers 999 --model Qwen3.5-27B-BF16.gguf --ctx-size 65536
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 4 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: no, VRAM: 24126 MiB
  Device 1: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: no, VRAM: 24126 MiB
  Device 2: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: no, VRAM: 24126 MiB
  Device 3: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: no, VRAM: 24126 MiB
=============================== NCCL main communicator initialized
=============================== NCCL pair communicators for 4 GPUs initialized
CUDA0: using device CUDA0 - 23725 MiB free
CUDA1: using device CUDA1 - 23725 MiB free
CUDA2: using device CUDA2 - 23725 MiB free
CUDA3: using device CUDA3 - 23725 MiB free
llama_model_loader: loaded meta data with 40 key-value pairs and 851 tensors from Qwen3.5-27B-BF16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen35
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                     general.sampling.top_k i32              = 20
llama_model_loader: - kv   3:                     general.sampling.top_p f32              = 0.950000
llama_model_loader: - kv   4:                      general.sampling.temp f32              = 0.600000
llama_model_loader: - kv   5:                               general.name str              = Qwen3.5 27B
llama_model_loader: - kv   6:                           general.basename str              = Qwen3.5
llama_model_loader: - kv   7:                         general.size_label str              = 27B
llama_model_loader: - kv   8:                            general.license str              = apache-2.0
llama_model_loader: - kv   9:                       general.license.link str              = https://huggingface.co/Qwen/Qwen3.5-2...
llama_model_loader: - kv  10:                               general.tags arr[str,1]       = ["image-text-to-text"]
llama_model_loader: - kv  11:                         qwen35.block_count u32              = 64
llama_model_loader: - kv  12:                      qwen35.context_length u32              = 262144
llama_model_loader: - kv  13:                    qwen35.embedding_length u32              = 5120
llama_model_loader: - kv  14:                 qwen35.feed_forward_length u32              = 17408
llama_model_loader: - kv  15:                qwen35.attention.head_count u32              = 24
llama_model_loader: - kv  16:             qwen35.attention.head_count_kv u32              = 4
llama_model_loader: - kv  17:             qwen35.rope.dimension_sections arr[i32,4]       = [11, 11, 10, 0]
llama_model_loader: - kv  18:                      qwen35.rope.freq_base f32              = 10000000.000000
llama_model_loader: - kv  19:    qwen35.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  20:                qwen35.attention.key_length u32              = 256
llama_model_loader: - kv  21:              qwen35.attention.value_length u32              = 256
llama_model_loader: - kv  22:                          general.file_type u32              = 32
llama_model_loader: - kv  23:                     qwen35.ssm.conv_kernel u32              = 4
llama_model_loader: - kv  24:                      qwen35.ssm.state_size u32              = 128
llama_model_loader: - kv  25:                     qwen35.ssm.group_count u32              = 16
llama_model_loader: - kv  26:                  qwen35.ssm.time_step_rank u32              = 48
llama_model_loader: - kv  27:                      qwen35.ssm.inner_size u32              = 6144
llama_model_loader: - kv  28:             qwen35.full_attention_interval u32              = 4
llama_model_loader: - kv  29:                qwen35.rope.dimension_count u32              = 64
llama_model_loader: - kv  30:               general.quantization_version u32              = 2
llama_model_loader: - kv  31:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  32:                         tokenizer.ggml.pre str              = qwen35
llama_model_loader: - kv  33:                      tokenizer.ggml.tokens arr[str,248320]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  34:                  tokenizer.ggml.token_type arr[i32,248320]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  35:                      tokenizer.ggml.merges arr[str,247587]  = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv  36:                tokenizer.ggml.eos_token_id u32              = 248046
llama_model_loader: - kv  37:            tokenizer.ggml.padding_token_id u32              = 248044
llama_model_loader: - kv  38:               tokenizer.ggml.add_bos_token bool             = false
llama_model_loader: - kv  39:                    tokenizer.chat_template str              = {%- set image_count = namespace(value...
llama_model_loader: - type  f32:  353 tensors
llama_model_loader: - type bf16:  498 tensors
load: printing all EOG tokens:
load:   - 248044 ('<|endoftext|>')
load:   - 248046 ('<|im_end|>')
load:   - 248063 ('<|fim_pad|>')
load:   - 248064 ('<|repo_name|>')
load:   - 248065 ('<|file_sep|>')
load: special tokens cache size = 33
load: token to piece cache size = 1.7581 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = qwen35
llm_load_print_meta: n_ctx_train      = 262144
llm_load_print_meta: n_embd           = 5120
llm_load_print_meta: n_layer          = 64
llm_load_print_meta: n_head           = 24
llm_load_print_meta: n_head_kv        = 4
llm_load_print_meta: n_rot            = 64
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_swa_pattern    = 1
llm_load_print_meta: n_embd_head_k    = 256
llm_load_print_meta: n_embd_head_v    = 256
llm_load_print_meta: n_gqa            = 6
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 17408
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 40
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 262144
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: mrope sections   = [11, 11, 10, 0]
llm_load_print_meta: ssm_d_conv       = 4
llm_load_print_meta: ssm_d_inner      = 6144
llm_load_print_meta: ssm_d_state      = 128
llm_load_print_meta: ssm_dt_rank      = 48
llm_load_print_meta: ssm_n_group      = 16
llm_load_print_meta: model type       = 27B
llm_load_print_meta: model ftype      = BF16
llm_load_print_meta: model params     = 26.896 B
llm_load_print_meta: model size       = 50.103 GiB (16.002 BPW) 
llm_load_print_meta: repeating layers = 45.366 GiB (16.002 BPW, 24.353 B parameters)
llm_load_print_meta: general.name     = Qwen3.5 27B
print_info: vocab type       = BPE
print_info: n_vocab          = 248320
print_info: n_merges         = 247587
print_info: BOS token        = 11 ','
print_info: EOS token        = 248046 '<|im_end|>'
print_info: EOT token        = 248046 '<|im_end|>'
print_info: PAD token        = 248044 '<|endoftext|>'
print_info: LF token         = 198 'Ċ'
print_info: FIM PRE token    = 248060 '<|fim_prefix|>'
print_info: FIM SUF token    = 248062 '<|fim_suffix|>'
print_info: FIM MID token    = 248061 '<|fim_middle|>'
print_info: FIM PAD token    = 248063 '<|fim_pad|>'
print_info: FIM REP token    = 248064 '<|repo_name|>'
print_info: FIM SEP token    = 248065 '<|file_sep|>'
print_info: EOG token        = 248044 '<|endoftext|>'
print_info: EOG token        = 248046 '<|im_end|>'
print_info: EOG token        = 248063 '<|fim_pad|>'
print_info: EOG token        = 248064 '<|repo_name|>'
print_info: EOG token        = 248065 '<|file_sep|>'
print_info: max token length = 256
llm_load_tensors: ggml ctx size =   13.69 MiB
================================ max_gpu = 0
Estimated model buffer size per device:
    Device 0:  11615.68 MiB
    Device 1:  11615.68 MiB
    Device 2:  11615.68 MiB
    Device 3:  11615.68 MiB
No tensors in buffer type CUDA0
No tensors in buffer type CUDA1
No tensors in buffer type CUDA2
llm_load_tensors: offloading 64 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 65/65 layers to GPU
llm_load_tensors:        CPU buffer size =  2425.00 MiB
llm_load_tensors: CUDA_Split buffer size = 46464.38 MiB
llm_load_tensors:      CUDA3 buffer size =  2425.02 MiB
.............................................................................................
llama_init_from_model: n_ctx         = 65536
llama_init_from_model: n_batch       = 2048
llama_init_from_model: n_ubatch      = 512
llama_init_from_model: flash_attn    = 1
llama_init_from_model: attn_max_b    = 0
llama_init_from_model: fused_moe     = 1
llama_init_from_model: grouped er    = 0
llama_init_from_model: fused_up_gate = 1
llama_init_from_model: fused_mmad    = 1
llama_init_from_model: rope_cache    = 0
llama_init_from_model: graph_reuse   = 1
llama_init_from_model: k_cache_hadam = 0
llama_init_from_model: split_mode_graph_scheduling = 0
llama_init_from_model: reduce_type   = f16
llama_init_from_model: sched_async   = 0
llama_init_from_model: ser           = -1, 0
llama_init_from_model: freq_base     = 10000000.0
llama_init_from_model: freq_scale    = 1
ggml_new_object: not enough space in the context's memory pool (needed 141680, available 141312)
trap      (core dumped) CUDA_VISIBLE_DEVICES=4,5,6,7 ./ik_llama.cpp-build-qwen35graph/llama-sweep-bench --split-mode graph --cache-type-k bf16 --cache-type-v bf16 --n-gpu-layers 999 --model Qwen3.5-27B-BF16.gguf --ctx-size 65536

@ikawrakow
Owner Author

@hksdpc255

Well, not sure what your issue is. I don't have the bf16 version of Qwen3.5-27B, but I do have IQ4_NL. That works just fine with 4 GPUs. I then used llama-quantize to convert the IQ4_NL model to bf16. That runs fine as well. Here is my run with the bf16 model (and bf16 KV cache, which lowers TG performance at long context) on 4x3090:

PP TG N_KV T_PP s S_PP t/s T_TG s S_TG t/s
2048 128 0 0.829 2469.92 3.404 37.60
2048 128 2048 0.975 2100.52 3.544 36.11
2048 128 4096 0.777 2636.25 3.490 36.67
2048 128 6144 0.784 2611.62 3.510 36.47
2048 128 8192 0.796 2572.50 3.531 36.25
2048 128 10240 0.802 2552.11 3.550 36.05
2048 128 12288 0.812 2521.08 3.568 35.88
2048 128 14336 0.820 2497.28 3.593 35.62
2048 128 16384 0.829 2471.45 3.603 35.53
2048 128 18432 0.842 2433.06 3.653 35.04
2048 128 20480 0.847 2419.11 3.660 34.97
2048 128 22528 0.893 2292.54 3.684 34.74
2048 128 24576 0.863 2373.96 3.707 34.53
2048 128 26624 0.871 2351.04 3.730 34.32
2048 128 28672 0.879 2329.48 3.756 34.08
2048 128 30720 0.887 2308.32 3.788 33.79
2048 128 32768 0.897 2283.04 3.811 33.59
2048 128 34816 0.911 2248.67 3.841 33.33
2048 128 36864 0.917 2232.33 3.863 33.13
2048 128 38912 0.924 2215.75 3.884 32.95
2048 128 40960 0.938 2182.75 3.917 32.68
2048 128 43008 0.944 2170.47 3.911 32.73
2048 128 45056 0.953 2148.76 3.935 32.53
2048 128 47104 0.962 2129.11 3.945 32.45
2048 128 49152 0.977 2095.49 3.975 32.20
2048 128 51200 0.986 2076.82 4.001 31.99
2048 128 53248 0.991 2065.91 4.012 31.91
2048 128 55296 1.000 2048.86 4.022 31.83
2048 128 57344 1.049 1952.12 4.039 31.69
2048 128 59392 1.020 2008.17 4.051 31.59
2048 128 61440 1.030 1988.18 4.070 31.45
2048 128 63488 1.061 1930.69 4.101 31.22

@hksdpc255
Contributor

hksdpc255 commented Mar 10, 2026

You're right, my hard disk is broken. I moved the model and llama-sweep-bench to another disk and now it works.

Here's the BF16 performance for 4*3090

Details

main: n_kv_max = 65536, n_batch = 2048, n_ubatch = 512, flash_attn = 1, n_gpu_layers = 999, n_threads = 8, n_threads_batch = 8

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 512 | 128 | 0 | 0.578 | 886.37 | 3.518 | 36.39 |
| 512 | 128 | 512 | 0.567 | 902.51 | 3.649 | 35.08 |
| 512 | 128 | 1024 | 0.477 | 1072.80 | 3.551 | 36.05 |
| 512 | 128 | 1536 | 0.479 | 1067.79 | 3.556 | 36.00 |
| 512 | 128 | 2048 | 0.482 | 1063.16 | 3.560 | 35.95 |
| 512 | 128 | 2560 | 0.482 | 1061.57 | 3.566 | 35.90 |
| 512 | 128 | 3072 | 0.484 | 1058.00 | 3.571 | 35.85 |
| 512 | 128 | 3584 | 0.483 | 1059.84 | 3.576 | 35.79 |
| 512 | 128 | 4096 | 0.484 | 1057.88 | 3.605 | 35.51 |
| 512 | 128 | 4608 | 0.485 | 1056.52 | 3.609 | 35.47 |
| 512 | 128 | 5120 | 0.485 | 1056.10 | 3.611 | 35.45 |
| 512 | 128 | 5632 | 0.485 | 1054.75 | 3.617 | 35.39 |
| 512 | 128 | 6144 | 0.487 | 1050.73 | 3.627 | 35.30 |
| 512 | 128 | 6656 | 0.486 | 1053.98 | 3.625 | 35.31 |
| 512 | 128 | 7168 | 0.486 | 1054.22 | 3.630 | 35.26 |
| 512 | 128 | 7680 | 0.486 | 1053.36 | 3.641 | 35.16 |
| 512 | 128 | 8192 | 0.486 | 1052.46 | 3.647 | 35.10 |
| 512 | 128 | 8704 | 0.487 | 1051.50 | 3.646 | 35.10 |
| 512 | 128 | 9216 | 0.487 | 1051.70 | 3.651 | 35.06 |
| 512 | 128 | 9728 | 0.488 | 1049.66 | 3.660 | 34.97 |
| 512 | 128 | 10240 | 0.488 | 1048.59 | 3.660 | 34.97 |
| 512 | 128 | 10752 | 0.488 | 1048.61 | 3.664 | 34.93 |
| 512 | 128 | 11264 | 0.489 | 1046.44 | 3.671 | 34.87 |
| 512 | 128 | 11776 | 0.489 | 1046.19 | 3.677 | 34.81 |
| 512 | 128 | 12288 | 0.490 | 1044.37 | 3.678 | 34.80 |
| 512 | 128 | 12800 | 0.491 | 1042.81 | 3.687 | 34.72 |
| 512 | 128 | 13312 | 0.492 | 1040.65 | 3.691 | 34.68 |
| 512 | 128 | 13824 | 0.493 | 1039.13 | 3.700 | 34.60 |
| 512 | 128 | 14336 | 0.493 | 1037.84 | 3.699 | 34.60 |
| 512 | 128 | 14848 | 0.493 | 1038.21 | 3.709 | 34.51 |
| 512 | 128 | 15360 | 0.494 | 1036.04 | 3.714 | 34.47 |
| 512 | 128 | 15872 | 0.493 | 1038.62 | 3.721 | 34.40 |
| 512 | 128 | 16384 | 0.494 | 1036.18 | 3.731 | 34.31 |
| 512 | 128 | 16896 | 0.494 | 1036.26 | 3.736 | 34.26 |
| 512 | 128 | 17408 | 0.496 | 1031.78 | 3.742 | 34.21 |
| 512 | 128 | 17920 | 0.496 | 1031.75 | 3.748 | 34.15 |
| 512 | 128 | 18432 | 0.497 | 1030.64 | 3.755 | 34.09 |
| 512 | 128 | 18944 | 0.497 | 1031.08 | 3.760 | 34.05 |
| 512 | 128 | 19456 | 0.497 | 1030.17 | 3.766 | 33.99 |
| 512 | 128 | 19968 | 0.498 | 1028.03 | 3.772 | 33.93 |
| 512 | 128 | 20480 | 0.498 | 1027.78 | 3.778 | 33.88 |
| 512 | 128 | 20992 | 0.499 | 1027.02 | 3.785 | 33.82 |
| 512 | 128 | 21504 | 0.499 | 1025.64 | 3.789 | 33.78 |
| 512 | 128 | 22016 | 0.499 | 1025.20 | 3.796 | 33.72 |
| 512 | 128 | 22528 | 0.500 | 1024.15 | 3.799 | 33.69 |
| 512 | 128 | 23040 | 0.500 | 1024.37 | 3.805 | 33.64 |
| 512 | 128 | 23552 | 0.501 | 1021.26 | 3.811 | 33.59 |
| 512 | 128 | 24064 | 0.502 | 1020.67 | 3.817 | 33.53 |
| 512 | 128 | 24576 | 0.501 | 1021.09 | 3.821 | 33.50 |
| 512 | 128 | 25088 | 0.502 | 1019.31 | 3.827 | 33.45 |

update: I mixed up several benchmarks that were running simultaneously earlier. The data I’m posting now should be the correct one.

@ikawrakow ikawrakow merged commit f90b4c2 into main Mar 10, 2026
@hksdpc255
Copy link
Contributor

60 t/s for a 397B model is pretty good! Do we qualify to play in the Big Boys league now (vLLM, sglang)?

sglang 0.5.9 currently seems broken for --tp 4 with the RTX 3090. In my setup the output is just 111111111....... Maybe related to sgl-project/sglang#19220 sgl-project/sglang#19411 sgl-project/sglang#19070 etc...

I don't know if the performance in this situation is still meaningful.

CUDA_VISIBLE_DEVICES=4,5,6,7 python -m sglang.launch_server --model Qwen3.5-27B --tp-size 4 --mem-fraction-static 0.7 --context-length 262144 --reasoning-parser qwen3 --tool-call-parser qwen3_coder --speculative-algo NEXTN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --max-running-requests 1 --disable-radix-cache
[2026-03-10 16:05:15 TP0] Decode batch, #running-req: 1, #full token: 40, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.18, accept rate: 0.29, cuda graph: True, gen throughput (token/s): 7.11, #queue-req: 0
[2026-03-10 16:05:16 TP0] Decode batch, #running-req: 1, #full token: 106, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.65, accept rate: 0.41, cuda graph: True, gen throughput (token/s): 39.30, #queue-req: 0
[2026-03-10 16:05:18 TP0] Decode batch, #running-req: 1, #full token: 169, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.57, accept rate: 0.39, cuda graph: True, gen throughput (token/s): 37.34, #queue-req: 0
[2026-03-10 16:05:20 TP0] Decode batch, #running-req: 1, #full token: 242, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.82, accept rate: 0.46, cuda graph: True, gen throughput (token/s): 43.07, #queue-req: 0
[2026-03-10 16:05:22 TP0] Decode batch, #running-req: 1, #full token: 318, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.90, accept rate: 0.47, cuda graph: True, gen throughput (token/s): 44.36, #queue-req: 0
[2026-03-10 16:05:23 TP0] Decode batch, #running-req: 1, #full token: 400, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.05, accept rate: 0.51, cuda graph: True, gen throughput (token/s): 47.96, #queue-req: 0
[2026-03-10 16:05:25 TP0] Decode batch, #running-req: 1, #full token: 484, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.10, accept rate: 0.53, cuda graph: True, gen throughput (token/s): 48.96, #queue-req: 0
[2026-03-10 16:05:27 TP0] Decode batch, #running-req: 1, #full token: 571, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.17, accept rate: 0.54, cuda graph: True, gen throughput (token/s): 50.78, #queue-req: 0
[2026-03-10 16:05:28 TP0] Decode batch, #running-req: 1, #full token: 658, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.17, accept rate: 0.54, cuda graph: True, gen throughput (token/s): 50.76, #queue-req: 0
[2026-03-10 16:05:30 TP0] Decode batch, #running-req: 1, #full token: 742, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.10, accept rate: 0.53, cuda graph: True, gen throughput (token/s): 49.26, #queue-req: 0
[2026-03-10 16:05:32 TP0] Decode batch, #running-req: 1, #full token: 831, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.23, accept rate: 0.56, cuda graph: True, gen throughput (token/s): 51.95, #queue-req: 0
[2026-03-10 16:05:34 TP0] Decode batch, #running-req: 1, #full token: 927, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.40, accept rate: 0.60, cuda graph: True, gen throughput (token/s): 56.11, #queue-req: 0
[2026-03-10 16:05:35 TP0] Decode batch, #running-req: 1, #full token: 1033, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.65, accept rate: 0.66, cuda graph: True, gen throughput (token/s): 62.55, #queue-req: 0
[2026-03-10 16:05:37 TP0] Decode batch, #running-req: 1, #full token: 1127, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.35, accept rate: 0.59, cuda graph: True, gen throughput (token/s): 55.71, #queue-req: 0
[2026-03-10 16:05:39 TP0] Decode batch, #running-req: 1, #full token: 1234, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.67, accept rate: 0.67, cuda graph: True, gen throughput (token/s): 63.67, #queue-req: 0
[2026-03-10 16:05:40 TP0] Decode batch, #running-req: 1, #full token: 1331, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.42, accept rate: 0.61, cuda graph: True, gen throughput (token/s): 57.62, #queue-req: 0
[2026-03-10 16:05:42 TP0] Decode batch, #running-req: 1, #full token: 1431, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.50, accept rate: 0.62, cuda graph: True, gen throughput (token/s): 59.83, #queue-req: 0
[2026-03-10 16:05:44 TP0] Decode batch, #running-req: 1, #full token: 1525, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.35, accept rate: 0.59, cuda graph: True, gen throughput (token/s): 55.85, #queue-req: 0
[2026-03-10 16:05:45 TP0] Decode batch, #running-req: 1, #full token: 1621, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.40, accept rate: 0.60, cuda graph: True, gen throughput (token/s): 56.35, #queue-req: 0
[2026-03-10 16:05:47 TP0] Decode batch, #running-req: 1, #full token: 1712, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.27, accept rate: 0.57, cuda graph: True, gen throughput (token/s): 53.34, #queue-req: 0
[2026-03-10 16:05:49 TP0] Decode batch, #running-req: 1, #full token: 1818, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.65, accept rate: 0.66, cuda graph: True, gen throughput (token/s): 62.67, #queue-req: 0
[2026-03-10 16:05:50 TP0] Decode batch, #running-req: 1, #full token: 1924, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.65, accept rate: 0.66, cuda graph: True, gen throughput (token/s): 62.75, #queue-req: 0
[2026-03-10 16:05:52 TP0] Decode batch, #running-req: 1, #full token: 2012, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.20, accept rate: 0.55, cuda graph: True, gen throughput (token/s): 51.88, #queue-req: 0
[2026-03-10 16:05:54 TP0] Decode batch, #running-req: 1, #full token: 2103, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.27, accept rate: 0.57, cuda graph: True, gen throughput (token/s): 54.16, #queue-req: 0
[2026-03-10 16:05:56 TP0] Decode batch, #running-req: 1, #full token: 2189, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.15, accept rate: 0.54, cuda graph: True, gen throughput (token/s): 51.24, #queue-req: 0
[2026-03-10 16:05:57 TP0] Decode batch, #running-req: 1, #full token: 2278, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.23, accept rate: 0.56, cuda graph: True, gen throughput (token/s): 52.78, #queue-req: 0
[2026-03-10 16:05:59 TP0] Decode batch, #running-req: 1, #full token: 2365, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.17, accept rate: 0.54, cuda graph: True, gen throughput (token/s): 51.57, #queue-req: 0
[2026-03-10 16:06:01 TP0] Decode batch, #running-req: 1, #full token: 2452, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.17, accept rate: 0.54, cuda graph: True, gen throughput (token/s): 51.51, #queue-req: 0
[2026-03-10 16:06:02 TP0] Decode batch, #running-req: 1, #full token: 2536, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.10, accept rate: 0.53, cuda graph: True, gen throughput (token/s): 50.22, #queue-req: 0
[2026-03-10 16:06:04 TP0] Decode batch, #running-req: 1, #full token: 2619, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.08, accept rate: 0.52, cuda graph: True, gen throughput (token/s): 49.34, #queue-req: 0
[2026-03-10 16:06:06 TP0] Decode batch, #running-req: 1, #full token: 2710, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.27, accept rate: 0.57, cuda graph: True, gen throughput (token/s): 54.28, #queue-req: 0
[2026-03-10 16:06:07 TP0] Decode batch, #running-req: 1, #full token: 2811, full token usage: 0.02, mamba num: 1, mamba usage: 1.00, accept len: 2.52, accept rate: 0.63, cuda graph: True, gen throughput (token/s): 59.98, #queue-req: 0
[2026-03-10 16:06:09 TP0] Decode batch, #running-req: 1, #full token: 2896, full token usage: 0.02, mamba num: 1, mamba usage: 1.00, accept len: 2.12, accept rate: 0.53, cuda graph: True, gen throughput (token/s): 50.43, #queue-req: 0
[2026-03-10 16:06:11 TP0] Decode batch, #running-req: 1, #full token: 2990, full token usage: 0.02, mamba num: 1, mamba usage: 1.00, accept len: 2.35, accept rate: 0.59, cuda graph: True, gen throughput (token/s): 55.83, #queue-req: 0
CUDA_VISIBLE_DEVICES=4,5,6,7 python -m sglang.launch_server --model Qwen3.5-27B --tp-size 4 --mem-fraction-static 0.7 --context-length 262144 --reasoning-parser qwen3 --tool-call-parser qwen3_coder --max-running-requests 1 --disable-radix-cache
[2026-03-10 16:10:57 TP0] Decode batch, #running-req: 1, #full token: 38, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 0.20, #queue-req: 0
[2026-03-10 16:10:58 TP0] Decode batch, #running-req: 1, #full token: 78, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.26, #queue-req: 0
[2026-03-10 16:10:59 TP0] Decode batch, #running-req: 1, #full token: 118, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.22, #queue-req: 0
[2026-03-10 16:10:59 TP0] Decode batch, #running-req: 1, #full token: 158, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.19, #queue-req: 0
[2026-03-10 16:11:00 TP0] Decode batch, #running-req: 1, #full token: 198, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.18, #queue-req: 0
[2026-03-10 16:11:01 TP0] Decode batch, #running-req: 1, #full token: 238, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.18, #queue-req: 0
[2026-03-10 16:11:02 TP0] Decode batch, #running-req: 1, #full token: 278, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.16, #queue-req: 0
[2026-03-10 16:11:03 TP0] Decode batch, #running-req: 1, #full token: 318, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.15, #queue-req: 0
[2026-03-10 16:11:04 TP0] Decode batch, #running-req: 1, #full token: 358, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.13, #queue-req: 0
[2026-03-10 16:11:05 TP0] Decode batch, #running-req: 1, #full token: 398, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.10, #queue-req: 0
[2026-03-10 16:11:06 TP0] Decode batch, #running-req: 1, #full token: 438, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.10, #queue-req: 0
[2026-03-10 16:11:06 TP0] Decode batch, #running-req: 1, #full token: 478, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.10, #queue-req: 0
[2026-03-10 16:11:07 TP0] Decode batch, #running-req: 1, #full token: 518, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.11, #queue-req: 0
[2026-03-10 16:11:08 TP0] Decode batch, #running-req: 1, #full token: 558, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.11, #queue-req: 0
[2026-03-10 16:11:09 TP0] Decode batch, #running-req: 1, #full token: 598, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.10, #queue-req: 0
[2026-03-10 16:11:10 TP0] Decode batch, #running-req: 1, #full token: 638, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.08, #queue-req: 0
[2026-03-10 16:11:11 TP0] Decode batch, #running-req: 1, #full token: 678, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.09, #queue-req: 0
[2026-03-10 16:11:12 TP0] Decode batch, #running-req: 1, #full token: 718, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.05, #queue-req: 0
[2026-03-10 16:11:12 TP0] Decode batch, #running-req: 1, #full token: 758, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.05, #queue-req: 0
[2026-03-10 16:11:13 TP0] Decode batch, #running-req: 1, #full token: 798, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.01, #queue-req: 0
[2026-03-10 16:11:14 TP0] Decode batch, #running-req: 1, #full token: 838, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.01, #queue-req: 0
[2026-03-10 16:11:15 TP0] Decode batch, #running-req: 1, #full token: 878, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.01, #queue-req: 0
[2026-03-10 16:11:16 TP0] Decode batch, #running-req: 1, #full token: 918, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.01, #queue-req: 0
[2026-03-10 16:11:17 TP0] Decode batch, #running-req: 1, #full token: 958, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.01, #queue-req: 0
[2026-03-10 16:11:18 TP0] Decode batch, #running-req: 1, #full token: 998, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.01, #queue-req: 0
[2026-03-10 16:11:19 TP0] Decode batch, #running-req: 1, #full token: 1038, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.01, #queue-req: 0
[2026-03-10 16:11:19 TP0] Decode batch, #running-req: 1, #full token: 1078, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.99, #queue-req: 0
[2026-03-10 16:11:20 TP0] Decode batch, #running-req: 1, #full token: 1118, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 46.00, #queue-req: 0
[2026-03-10 16:11:21 TP0] Decode batch, #running-req: 1, #full token: 1158, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.99, #queue-req: 0
[2026-03-10 16:11:22 TP0] Decode batch, #running-req: 1, #full token: 1198, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.98, #queue-req: 0
[2026-03-10 16:11:23 TP0] Decode batch, #running-req: 1, #full token: 1238, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.97, #queue-req: 0
[2026-03-10 16:11:24 TP0] Decode batch, #running-req: 1, #full token: 1278, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.98, #queue-req: 0
[2026-03-10 16:11:25 TP0] Decode batch, #running-req: 1, #full token: 1318, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.98, #queue-req: 0
[2026-03-10 16:11:26 TP0] Decode batch, #running-req: 1, #full token: 1358, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.97, #queue-req: 0
[2026-03-10 16:11:26 TP0] Decode batch, #running-req: 1, #full token: 1398, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.95, #queue-req: 0
[2026-03-10 16:11:27 TP0] Decode batch, #running-req: 1, #full token: 1438, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.96, #queue-req: 0
[2026-03-10 16:11:28 TP0] Decode batch, #running-req: 1, #full token: 1478, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.96, #queue-req: 0
[2026-03-10 16:11:29 TP0] Decode batch, #running-req: 1, #full token: 1518, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.96, #queue-req: 0
[2026-03-10 16:11:30 TP0] Decode batch, #running-req: 1, #full token: 1558, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.95, #queue-req: 0
[2026-03-10 16:11:31 TP0] Decode batch, #running-req: 1, #full token: 1598, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.95, #queue-req: 0
[2026-03-10 16:11:32 TP0] Decode batch, #running-req: 1, #full token: 1638, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.95, #queue-req: 0
[2026-03-10 16:11:32 TP0] Decode batch, #running-req: 1, #full token: 1678, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.95, #queue-req: 0
[2026-03-10 16:11:33 TP0] Decode batch, #running-req: 1, #full token: 1718, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.94, #queue-req: 0
[2026-03-10 16:11:34 TP0] Decode batch, #running-req: 1, #full token: 1758, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.96, #queue-req: 0
[2026-03-10 16:11:35 TP0] Decode batch, #running-req: 1, #full token: 1798, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.93, #queue-req: 0
[2026-03-10 16:11:36 TP0] Decode batch, #running-req: 1, #full token: 1838, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.96, #queue-req: 0
[2026-03-10 16:11:37 TP0] Decode batch, #running-req: 1, #full token: 1878, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.95, #queue-req: 0
[2026-03-10 16:11:38 TP0] Decode batch, #running-req: 1, #full token: 1918, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.96, #queue-req: 0
[2026-03-10 16:11:39 TP0] Decode batch, #running-req: 1, #full token: 1958, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.97, #queue-req: 0
[2026-03-10 16:11:39 TP0] Decode batch, #running-req: 1, #full token: 1998, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.99, #queue-req: 0
[2026-03-10 16:11:40 TP0] Decode batch, #running-req: 1, #full token: 2038, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, cuda graph: True, gen throughput (token/s): 45.97, #queue-req: 0

ik_llama.cpp beats sglang when MTP (speculative decoding) is disabled! But we are still slower when MTP is enabled. A speculative-decode accept rate around 0.6 seems normal compared to another sglang deployment of mine (using an fp8 quant that fits in a single 48GB GPU), where the accept rate is usually 0.7 when I give it a hard math problem.
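A hedged back-of-the-envelope way to read the two logs above: with MTP, each verification step emits `accept len` tokens on average, so the throughput ratio between the two runs implies a per-step cost for drafting plus verification. This is a toy model with representative figures (~46 t/s without MTP, ~55 t/s with MTP) read off the logs, not sglang's actual accounting:

```python
# Hypothetical toy model: observed_tps = accept_len / step_cost, where
# step_cost is measured in units of one plain single-token decode step.
# Inverting gives the implied cost of an MTP draft+verify step.
def implied_step_cost(accept_len: float, tps_mtp: float, tps_plain: float) -> float:
    """Per-step cost of MTP decoding relative to plain decoding."""
    return accept_len * tps_plain / tps_mtp

# Logs above: ~46 t/s without MTP, ~55 t/s with MTP at accept len ~2.35.
print(round(implied_step_cost(2.35, 55.0, 46.0), 2))
```

Under this toy model each speculative step here costs roughly twice a plain decode step, which would explain why the accept rate has to climb well above 0.6 before MTP pays off clearly.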

@magikRUKKOLA
Copy link

magikRUKKOLA commented Mar 10, 2026

@ubergarm

Looking great on 2xA6000 GPUs:

The 4x3090 (PCIe 4.0 x16) gives results so similar to yours that I am not even going to post them.

[EDIT]:

Actually it runs happily with only three GPUs. I will post the results later on.

Nexesenex added a commit to Nexesenex/ik_llama.cpp.nxs that referenced this pull request Mar 10, 2026
@Ph0rk0z
Copy link

Ph0rk0z commented Mar 10, 2026

Is it working with Q8 cache? It still crashes for me on Qwen 397B.

Also, since 344688c the Devstral prompt processing speed has been cut in half. My branch at that commit gets almost 700 t/s, while head gets 360 t/s PP.

============ Repacked 108 tensors
llama_init_from_model: n_ctx         = 65536
llama_init_from_model: n_batch       = 2048
llama_init_from_model: n_ubatch      = 2048
llama_init_from_model: flash_attn    = 1
llama_init_from_model: attn_max_b    = 0
llama_init_from_model: fused_moe     = 1
llama_init_from_model: grouped er    = 1
llama_init_from_model: fused_up_gate = 1
llama_init_from_model: fused_mmad    = 1
llama_init_from_model: rope_cache    = 0
llama_init_from_model: graph_reuse   = 1
llama_init_from_model: k_cache_hadam = 0
llama_init_from_model: split_mode_graph_scheduling = 0
llama_init_from_model: reduce_type   = bf16
llama_init_from_model: sched_async   = 0
llama_init_from_model: ser           = -1, 0
llama_init_from_model: freq_base     = 10000000.0
llama_init_from_model: freq_scale    = 1
llama_init_from_model: cuda_params   = enable-p2p=1,fusion=1
ggml_backend_cuda_context: a context for device 0 already exists?
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 0->1
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 0->2
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 0->3
ggml_backend_cuda_context: a context for device 1 already exists?
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 1->0
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 1->2
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 1->3
ggml_backend_cuda_context: a context for device 2 already exists?
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 2->0
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 2->1
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 2->3
ggml_backend_cuda_context: a context for device 3 already exists?
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 3->0
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 3->1
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 3->2
llama_kv_cache_init: CUDA_Split KV buffer size =  1206.34 MiB
llama_kv_cache_init: KV cache size per device:
    Device 0:  284.841 MiB
    Device 1:  284.841 MiB
    Device 2:  318.323 MiB
    Device 3:  318.323 MiB
llama_init_from_model: KV self size  = 1020.00 MiB, K (q8_0):  510.00 MiB, V (q8_0):  510.00 MiB
llama_init_from_model:  CUDA_Host  output buffer size =     0.95 MiB
llama_init_from_model:      CUDA0 compute buffer size =   640.01 MiB
llama_init_from_model:      CUDA1 compute buffer size =   720.01 MiB
llama_init_from_model:      CUDA2 compute buffer size =   736.01 MiB
llama_init_from_model:      CUDA3 compute buffer size =  1988.00 MiB
llama_init_from_model:  CUDA_Host compute buffer size =   576.05 MiB
llama_init_from_model: graph nodes  = 13386
llama_init_from_model: graph splits = 681
llama_init_from_model: enabling only_active_experts scheduling

main: n_kv_max = 65536, n_batch = 2048, n_ubatch = 2048, flash_attn = 1, n_gpu_layers = 61, n_threads = 48, n_threads_batch = 48

|    PP |     TG |   N_KV |   T_PP s | S_PP t/s |   T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
CUDA error: invalid resource handle
  current device: 2, in function ggml_cuda_op_reduce at /home/supermicro/ai/ik_llama.cpp/ggml/src/ggml-cuda/reduce.cu:448
  cudaEventRecord(info.all_ctx[i]->copy_event, info.all_ctx[i]->stream())
/home/supermicro/ai/ik_llama.cpp/ggml/src/ggml-cuda.cu:132: CUDA error

@ikawrakow
Copy link
Owner Author

ikawrakow commented Mar 10, 2026

@Ph0rk0z I have put a spell on it. It does not work for ST users most of the time. When it does work, it works at half the performance.

Here is what a non-ST user such as myself gets on the latest main branch for Devstral-123B-IQ4_KSS

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|-------|--------|--------|----------|----------|----------|----------|
| 2048 | 128 | 0 | 2.137 | 958.16 | 4.319 | 29.63 |
| 2048 | 128 | 2048 | 2.074 | 987.44 | 4.376 | 29.25 |
| 2048 | 128 | 4096 | 2.136 | 958.70 | 4.496 | 28.47 |
| 2048 | 128 | 6144 | 2.207 | 927.81 | 4.550 | 28.13 |
| 2048 | 128 | 8192 | 2.266 | 903.98 | 4.567 | 28.03 |
| 2048 | 128 | 10240 | 2.334 | 877.57 | 4.649 | 27.53 |
| 2048 | 128 | 12288 | 2.398 | 853.93 | 4.669 | 27.41 |
| 2048 | 128 | 14336 | 2.459 | 832.97 | 4.715 | 27.15 |
| 2048 | 128 | 16384 | 2.519 | 812.96 | 4.815 | 26.58 |
| 2048 | 128 | 18432 | 2.583 | 792.97 | 4.811 | 26.60 |
| 2048 | 128 | 20480 | 2.647 | 773.83 | 4.858 | 26.35 |
| 2048 | 128 | 22528 | 2.706 | 756.71 | 4.907 | 26.08 |
| 2048 | 128 | 24576 | 2.771 | 739.15 | 4.916 | 26.04 |
| 2048 | 128 | 26624 | 2.836 | 722.04 | 4.920 | 26.02 |
| 2048 | 128 | 28672 | 2.897 | 706.98 | 4.952 | 25.85 |
| 2048 | 128 | 30720 | 2.968 | 689.95 | 4.970 | 25.76 |

@ikawrakow
Copy link
Owner Author

./bin/llama-cli -m /opt/ubergarm/Qwen3.5-397B-A17B-GGUF/IQ2_KL/Qwen3.5-397B-A17B-IQ2_KL-00001-of-00004.gguf \
   -c 32768 -ub 2048 -n 8192 -t 1 -ngl 100 -sm graph -ctk q8_0 -ctv q8_0 -p " " -cnv

> Have you heard of Silly Tavern?
formatted: <|im_start|>user
Have you heard of Silly Tavern?
<|im_end|>
<|im_start|>assistant

<think>

</think>

Yes, I'm familiar with **Silly Tavern**! It's a popular open-source frontend interface designed for interacting with large language models (LLMs), particularly in the context of roleplay, character chat, and creative storytelling. Key features include:

- **Character Cards**: Support for importing/exporting characters via formats like PNG (with embedded data) or JSON, often used in communities like those around KoboldAI, Ooba, or LLM backends.
- **Customization**: Highly configurable UI themes, extensions, and plugins to tailor the experience.
- **Backend Flexibility**: Connects to various LLM backends (e.g., KoboldAI, Ooba, Text Generation WebUI, OpenAI-compatible APIs, etc.).
- **Community-Driven**: Widely adopted in AI roleplay circles for its user-friendly design and active development.

It's often paired with local LLM setups but can also work with cloud-based APIs. Are you using it for a specific project, or would you like help with setup, troubleshooting, or features? 😊formatted: <|im_start|>assistant

<think>

</think>

Yes, I'm familiar with **Silly Tavern**! It's a popular open-source frontend interface designed for interacting with large language models (LLMs), particularly in the context of roleplay, character chat, and creative storytelling. Key features include:

- **Character Cards**: Support for importing/exporting characters via formats like PNG (with embedded data) or JSON, often used in communities like those around KoboldAI, Ooba, or LLM backends.
- **Customization**: Highly configurable UI themes, extensions, and plugins to tailor the experience.
- **Backend Flexibility**: Connects to various LLM backends (e.g., KoboldAI, Ooba, Text Generation WebUI, OpenAI-compatible APIs, etc.).
- **Community-Driven**: Widely adopted in AI roleplay circles for its user-friendly design and active development.

It's often paired with local LLM setups but can also work with cloud-based APIs. Are you using it for a specific project, or would you like help with setup, troubleshooting, or features? 😊<|im_end|>



> Can you tell me why one shouldn't use it?
formatted: 
<|im_start|>user
Can you tell me why one shouldn't use it?
<|im_end|>
<|im_start|>assistant

<think>

</think>

While **Silly Tavern** is a powerful and beloved tool for many, it isn't the right fit for everyone. There are several valid reasons why someone might choose *not* to use it, ranging from technical barriers to specific use-case mismatches:

### 1. **It Requires a Separate Backend**
Silly Tavern is **only a frontend**. It does not run the AI model itself.
- **The Issue**: You must already have a running LLM backend (like KoboldCPP, Ooba/Text Generation WebUI, LM Studio, or an API key from OpenAI/Anthropic).
- **Why avoid it?**: If you are looking for an "all-in-one" solution where you just install one program and start chatting immediately, Silly Tavern will be frustrating. You have to configure CORS settings, manage API endpoints, and ensure your backend is stable before Silly Tavern can even connect.

### 2. **Steep Learning Curve & Complexity**
The interface is packed with features, which can be overwhelming for beginners.
- **The Issue**: Menus for "World Info," "Author's Note," "Jailbreaks," "Instruct Templates," and complex generation parameters are exposed by default.
- **Why avoid it?**: If you want a simple, clean chat experience (like ChatGPT or Claude.ai), Silly Tavern feels like the cockpit of a spaceship. The sheer number of toggles can lead to "parameter paralysis" or accidental misconfiguration that breaks the roleplay.

### 3. **Local Resource Requirements**
While the frontend itself is lightweight (running in a browser), the typical use case involves running large models locally.
- **The Issue**: To get the best experience without paying for APIs, users often run 7B–70B+ parameter models locally.
- **Why avoid it?**: If you do not have a powerful GPU (NVIDIA with significant VRAM) or a Mac with substantial unified memory, your local setup will be slow or impossible. If you rely on paid APIs instead, the costs can add up quickly compared to a flat-rate subscription service.

### 4. **Security Risks (Especially with Extensions)**
Silly Tavern supports a robust extension ecosystem, but this comes with risks.
- **The Issue**: Many extensions involve executing custom JavaScript or connecting to third-party services.
- **Why avoid it?**: If you are not technically savvy, installing unverified extensions could expose your system to security vulnerabilities or data leaks. Additionally, exposing your local Silly Tavern instance to the internet (for remote access) requires careful network configuration to avoid unauthorized access to your chat logs and API keys.

### 5. **Not Suitable for Professional/Productivity Work**
Silly Tavern is heavily optimized for **roleplay, creative writing, and character interaction**.
- **The Issue**: Its features (like character cards, lorebooks, and roleplay-specific prompts) are overkill for tasks like coding assistance, data analysis, or professional email drafting.
- **Why avoid it?**: For serious productivity, dedicated interfaces like the official Claude/OpenAI web UIs, or enterprise-focused tools, offer better context management, citation features, and a more professional environment.

### 6. **Stability and "Bleeding Edge" Issues**
As an open-source project driven by community contributions, it can sometimes be unstable.
- **The Issue**: Updates might introduce bugs, break compatibility with certain backends, or change UI layouts unexpectedly.
- **Why avoid it?**: If you need a rock-solid, guaranteed-stable environment for critical workflows, the rapid iteration cycle of Silly Tavern might be too risky compared to polished commercial products.

### Summary
You should probably **avoid** Silly Tavern if:
- You want a simple, install-and-go application.
- You don't have the hardware or knowledge to set up an LLM backend.
- You need a tool strictly for professional coding or business tasks.
- You are uncomfortable configuring network settings (CORS, ports, APIs).

However, if you are an enthusiast who wants deep control over how the AI behaves, loves customizing characters, and doesn't mind some technical setup, it remains one of the best tools available.

formatted: <|im_start|>assistant

<think>

</think>

While **Silly Tavern** is a powerful and beloved tool for many, it isn't the right fit for everyone. There are several valid reasons why someone might choose *not* to use it, ranging from technical barriers to specific use-case mismatches:

### 1. **It Requires a Separate Backend**
Silly Tavern is **only a frontend**. It does not run the AI model itself.
- **The Issue**: You must already have a running LLM backend (like KoboldCPP, Ooba/Text Generation WebUI, LM Studio, or an API key from OpenAI/Anthropic).
- **Why avoid it?**: If you are looking for an "all-in-one" solution where you just install one program and start chatting immediately, Silly Tavern will be frustrating. You have to configure CORS settings, manage API endpoints, and ensure your backend is stable before Silly Tavern can even connect.

### 2. **Steep Learning Curve & Complexity**
The interface is packed with features, which can be overwhelming for beginners.
- **The Issue**: Menus for "World Info," "Author's Note," "Jailbreaks," "Instruct Templates," and complex generation parameters are exposed by default.
- **Why avoid it?**: If you want a simple, clean chat experience (like ChatGPT or Claude.ai), Silly Tavern feels like the cockpit of a spaceship. The sheer number of toggles can lead to "parameter paralysis" or accidental misconfiguration that breaks the roleplay.

### 3. **Local Resource Requirements**
While the frontend itself is lightweight (running in a browser), the typical use case involves running large models locally.
- **The Issue**: To get the best experience without paying for APIs, users often run 7B–70B+ parameter models locally.
- **Why avoid it?**: If you do not have a powerful GPU (NVIDIA with significant VRAM) or a Mac with substantial unified memory, your local setup will be slow or impossible. If you rely on paid APIs instead, the costs can add up quickly compared to a flat-rate subscription service.

### 4. **Security Risks (Especially with Extensions)**
Silly Tavern supports a robust extension ecosystem, but this comes with risks.
- **The Issue**: Many extensions involve executing custom JavaScript or connecting to third-party services.
- **Why avoid it?**: If you are not technically savvy, installing unverified extensions could expose your system to security vulnerabilities or data leaks. Additionally, exposing your local Silly Tavern instance to the internet (for remote access) requires careful network configuration to avoid unauthorized access to your chat logs and API keys.

### 5. **Not Suitable for Professional/Productivity Work**
Silly Tavern is heavily optimized for **roleplay, creative writing, and character interaction**.
- **The Issue**: Its features (like character cards, lorebooks, and roleplay-specific prompts) are overkill for tasks like coding assistance, data analysis, or professional email drafting.
- **Why avoid it?**: For serious productivity, dedicated interfaces like the official Claude/OpenAI web UIs, or enterprise-focused tools, offer better context management, citation features, and a more professional environment.

### 6. **Stability and "Bleeding Edge" Issues**
As an open-source project driven by community contributions, it can sometimes be unstable.
- **The Issue**: Updates might introduce bugs, break compatibility with certain backends, or change UI layouts unexpectedly.
- **Why avoid it?**: If you need a rock-solid, guaranteed-stable environment for critical workflows, the rapid iteration cycle of Silly Tavern might be too risky compared to polished commercial products.

### Summary
You should probably **avoid** Silly Tavern if:
- You want a simple, install-and-go application.
- You don't have the hardware or knowledge to set up an LLM backend.
- You need a tool strictly for professional coding or business tasks.
- You are uncomfortable configuring network settings (CORS, ports, APIs).

So much for Qwen3.5-397B-A17B not working with Q8_0 KV cache.

@hksdpc255
Contributor

hksdpc255 commented Mar 10, 2026

Qwen3.5-9B BF16 performance on 4*3090

Details

main: n_kv_max = 65536, n_batch = 2048, n_ubatch = 512, flash_attn = 1, n_gpu_layers = 999, n_threads = 8, n_threads_batch = 8

|  PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|----:|---:|-----:|-------:|---------:|-------:|---------:|
| 512 | 128 |     0 | 0.254 | 2013.67 | 1.425 | 89.84 |
| 512 | 128 |   512 | 0.243 | 2108.38 | 1.484 | 86.23 |
| 512 | 128 |  1024 | 0.197 | 2603.49 | 1.445 | 88.56 |
| 512 | 128 |  1536 | 0.197 | 2595.18 | 1.451 | 88.23 |
| 512 | 128 |  2048 | 0.197 | 2598.43 | 1.457 | 87.87 |
| 512 | 128 |  2560 | 0.197 | 2595.27 | 1.459 | 87.73 |
| 512 | 128 |  3072 | 0.198 | 2591.26 | 1.463 | 87.50 |
| 512 | 128 |  3584 | 0.198 | 2589.73 | 1.465 | 87.37 |
| 512 | 128 |  4096 | 0.198 | 2585.28 | 1.468 | 87.19 |
| 512 | 128 |  4608 | 0.198 | 2586.97 | 1.470 | 87.08 |
| 512 | 128 |  5120 | 0.199 | 2577.48 | 1.474 | 86.84 |
| 512 | 128 |  5632 | 0.199 | 2574.57 | 1.478 | 86.63 |
| 512 | 128 |  6144 | 0.200 | 2559.46 | 1.480 | 86.49 |
| 512 | 128 |  6656 | 0.200 | 2565.05 | 1.483 | 86.31 |
| 512 | 128 |  7168 | 0.200 | 2560.13 | 1.486 | 86.13 |
| 512 | 128 |  7680 | 0.200 | 2561.05 | 1.488 | 86.00 |
| 512 | 128 |  8192 | 0.200 | 2556.96 | 1.490 | 85.88 |
| 512 | 128 |  8704 | 0.200 | 2553.86 | 1.494 | 85.67 |
| 512 | 128 |  9216 | 0.201 | 2552.11 | 1.496 | 85.54 |
| 512 | 128 |  9728 | 0.201 | 2548.15 | 1.501 | 85.27 |
| 512 | 128 | 10240 | 0.201 | 2543.19 | 1.503 | 85.14 |
| 512 | 128 | 10752 | 0.202 | 2530.19 | 1.507 | 84.95 |
| 512 | 128 | 11264 | 0.202 | 2538.65 | 1.510 | 84.76 |
| 512 | 128 | 11776 | 0.202 | 2535.23 | 1.514 | 84.55 |
| 512 | 128 | 12288 | 0.202 | 2532.68 | 1.517 | 84.40 |
| 512 | 128 | 12800 | 0.203 | 2526.20 | 1.519 | 84.24 |
| 512 | 128 | 13312 | 0.203 | 2526.39 | 1.523 | 84.05 |
| 512 | 128 | 13824 | 0.203 | 2524.54 | 1.528 | 83.79 |
| 512 | 128 | 14336 | 0.203 | 2519.28 | 1.529 | 83.69 |
| 512 | 128 | 14848 | 0.203 | 2517.06 | 1.533 | 83.51 |
| 512 | 128 | 15360 | 0.204 | 2512.16 | 1.535 | 83.37 |
| 512 | 128 | 15872 | 0.204 | 2510.59 | 1.537 | 83.31 |
| 512 | 128 | 16384 | 0.204 | 2506.47 | 1.540 | 83.11 |
| 512 | 128 | 16896 | 0.205 | 2501.65 | 1.542 | 82.98 |
| 512 | 128 | 17408 | 0.205 | 2502.24 | 1.546 | 82.79 |
| 512 | 128 | 17920 | 0.205 | 2499.38 | 1.549 | 82.64 |
| 512 | 128 | 18432 | 0.205 | 2500.61 | 1.552 | 82.46 |
| 512 | 128 | 18944 | 0.206 | 2490.58 | 1.555 | 82.29 |
| 512 | 128 | 19456 | 0.206 | 2487.49 | 1.558 | 82.14 |
| 512 | 128 | 19968 | 0.206 | 2486.39 | 1.557 | 82.19 |
| 512 | 128 | 20480 | 0.206 | 2483.93 | 1.560 | 82.06 |
| 512 | 128 | 20992 | 0.206 | 2480.60 | 1.572 | 81.45 |
| 512 | 128 | 21504 | 0.207 | 2474.65 | 1.578 | 81.12 |
| 512 | 128 | 22016 | 0.207 | 2470.61 | 1.580 | 81.00 |
| 512 | 128 | 22528 | 0.207 | 2468.84 | 1.582 | 80.89 |
| 512 | 128 | 23040 | 0.207 | 2470.34 | 1.584 | 80.80 |
| 512 | 128 | 23552 | 0.208 | 2465.57 | 1.586 | 80.68 |
| 512 | 128 | 24064 | 0.208 | 2465.11 | 1.589 | 80.58 |
| 512 | 128 | 24576 | 0.208 | 2462.78 | 1.590 | 80.49 |
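
The sweep above can be summarized in one number; a minimal sketch, using the first and last TG values from the table (89.84 t/s at zero context, 80.49 t/s at 24576 tokens), so the decode slowdown over the full sweep is about 10%:

```python
# Hypothetical helper, not part of ik_llama.cpp: relative TG throughput
# loss (%) between the start and end of a sweep-bench run.
def tg_slowdown(tg_start: float, tg_end: float) -> float:
    """Percentage drop in TG throughput from tg_start to tg_end."""
    return 100.0 * (tg_start - tg_end) / tg_start

# Values taken from the table above.
print(round(tg_slowdown(89.84, 80.49), 1))  # -> 10.4
```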

sglang with speculative decoding enabled, 4*3090

Details
[2026-03-10 21:46:12 TP0] Prefill batch, #new-seq: 1, #new-token: 19, #cached-token: 0, full token usage: 0.00, mamba usage: 1.00, #running-req: 0, #queue-req: 0, input throughput (token/s): 10.47, cuda graph: False
[2026-03-10 21:46:12 TP0] Decode batch, #running-req: 1, #full token: 52, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.00, accept rate: 0.25, cuda graph: True, gen throughput (token/s): 1.73, #queue-req: 0
[2026-03-10 21:46:13 TP0] Decode batch, #running-req: 1, #full token: 105, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.32, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 77.79, #queue-req: 0
[2026-03-10 21:46:14 TP0] Decode batch, #running-req: 1, #full token: 155, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.25, accept rate: 0.31, cuda graph: True, gen throughput (token/s): 72.86, #queue-req: 0
[2026-03-10 21:46:15 TP0] Decode batch, #running-req: 1, #full token: 202, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.18, accept rate: 0.29, cuda graph: True, gen throughput (token/s): 68.69, #queue-req: 0
[2026-03-10 21:46:15 TP0] Decode batch, #running-req: 1, #full token: 265, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.57, accept rate: 0.39, cuda graph: True, gen throughput (token/s): 91.83, #queue-req: 0
[2026-03-10 21:46:16 TP0] Decode batch, #running-req: 1, #full token: 316, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.27, accept rate: 0.32, cuda graph: True, gen throughput (token/s): 74.29, #queue-req: 0
[2026-03-10 21:46:17 TP0] Decode batch, #running-req: 1, #full token: 369, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.32, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 77.65, #queue-req: 0
[2026-03-10 21:46:17 TP0] Decode batch, #running-req: 1, #full token: 420, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.27, accept rate: 0.32, cuda graph: True, gen throughput (token/s): 74.96, #queue-req: 0
[2026-03-10 21:46:18 TP0] Decode batch, #running-req: 1, #full token: 476, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.40, accept rate: 0.35, cuda graph: True, gen throughput (token/s): 82.42, #queue-req: 0
[2026-03-10 21:46:19 TP0] Decode batch, #running-req: 1, #full token: 526, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.25, accept rate: 0.31, cuda graph: True, gen throughput (token/s): 73.69, #queue-req: 0
[2026-03-10 21:46:19 TP0] Decode batch, #running-req: 1, #full token: 576, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.25, accept rate: 0.31, cuda graph: True, gen throughput (token/s): 73.28, #queue-req: 0
[2026-03-10 21:46:20 TP0] Decode batch, #running-req: 1, #full token: 629, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.32, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 77.71, #queue-req: 0
[2026-03-10 21:46:21 TP0] Decode batch, #running-req: 1, #full token: 679, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.25, accept rate: 0.31, cuda graph: True, gen throughput (token/s): 73.53, #queue-req: 0
[2026-03-10 21:46:21 TP0] Decode batch, #running-req: 1, #full token: 730, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.27, accept rate: 0.32, cuda graph: True, gen throughput (token/s): 75.00, #queue-req: 0
[2026-03-10 21:46:22 TP0] Decode batch, #running-req: 1, #full token: 781, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.27, accept rate: 0.32, cuda graph: True, gen throughput (token/s): 74.99, #queue-req: 0
[2026-03-10 21:46:23 TP0] Decode batch, #running-req: 1, #full token: 829, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.20, accept rate: 0.30, cuda graph: True, gen throughput (token/s): 70.11, #queue-req: 0
[2026-03-10 21:46:23 TP0] Decode batch, #running-req: 1, #full token: 875, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.15, accept rate: 0.29, cuda graph: True, gen throughput (token/s): 67.23, #queue-req: 0
[2026-03-10 21:46:24 TP0] Decode batch, #running-req: 1, #full token: 920, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.12, accept rate: 0.28, cuda graph: True, gen throughput (token/s): 65.38, #queue-req: 0
[2026-03-10 21:46:25 TP0] Decode batch, #running-req: 1, #full token: 977, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.43, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 83.49, #queue-req: 0
[2026-03-10 21:46:25 TP0] Decode batch, #running-req: 1, #full token: 1027, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.25, accept rate: 0.31, cuda graph: True, gen throughput (token/s): 73.13, #queue-req: 0
[2026-03-10 21:46:26 TP0] Decode batch, #running-req: 1, #full token: 1078, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.27, accept rate: 0.32, cuda graph: True, gen throughput (token/s): 74.60, #queue-req: 0
[2026-03-10 21:46:27 TP0] Decode batch, #running-req: 1, #full token: 1132, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.35, accept rate: 0.34, cuda graph: True, gen throughput (token/s): 79.08, #queue-req: 0
[2026-03-10 21:46:27 TP0] Decode batch, #running-req: 1, #full token: 1185, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.32, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 77.05, #queue-req: 0
[2026-03-10 21:46:28 TP0] Decode batch, #running-req: 1, #full token: 1239, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.35, accept rate: 0.34, cuda graph: True, gen throughput (token/s): 78.67, #queue-req: 0
[2026-03-10 21:46:29 TP0] Decode batch, #running-req: 1, #full token: 1292, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.32, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 77.22, #queue-req: 0
[2026-03-10 21:46:30 TP0] Decode batch, #running-req: 1, #full token: 1351, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.48, accept rate: 0.37, cuda graph: True, gen throughput (token/s): 86.08, #queue-req: 0
[2026-03-10 21:46:30 TP0] Decode batch, #running-req: 1, #full token: 1398, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.18, accept rate: 0.29, cuda graph: True, gen throughput (token/s): 68.25, #queue-req: 0
[2026-03-10 21:46:31 TP0] Decode batch, #running-req: 1, #full token: 1451, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.32, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 77.29, #queue-req: 0
[2026-03-10 21:46:32 TP0] Decode batch, #running-req: 1, #full token: 1505, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.35, accept rate: 0.34, cuda graph: True, gen throughput (token/s): 78.54, #queue-req: 0
[2026-03-10 21:46:32 TP0] Decode batch, #running-req: 1, #full token: 1558, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.32, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 77.06, #queue-req: 0
[2026-03-10 21:46:33 TP0] Decode batch, #running-req: 1, #full token: 1610, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.30, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 75.85, #queue-req: 0
[2026-03-10 21:46:34 TP0] Decode batch, #running-req: 1, #full token: 1665, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.38, accept rate: 0.34, cuda graph: True, gen throughput (token/s): 80.22, #queue-req: 0
[2026-03-10 21:46:34 TP0] Decode batch, #running-req: 1, #full token: 1720, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.38, accept rate: 0.34, cuda graph: True, gen throughput (token/s): 80.19, #queue-req: 0
[2026-03-10 21:46:35 TP0] Decode batch, #running-req: 1, #full token: 1777, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.43, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 82.98, #queue-req: 0
[2026-03-10 21:46:36 TP0] Decode batch, #running-req: 1, #full token: 1835, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.45, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 84.07, #queue-req: 0
[2026-03-10 21:46:36 TP0] Decode batch, #running-req: 1, #full token: 1892, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.43, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 82.42, #queue-req: 0
[2026-03-10 21:46:37 TP0] Decode batch, #running-req: 1, #full token: 1945, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.32, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 76.99, #queue-req: 0
[2026-03-10 21:46:38 TP0] Decode batch, #running-req: 1, #full token: 2000, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.38, accept rate: 0.34, cuda graph: True, gen throughput (token/s): 80.22, #queue-req: 0
[2026-03-10 21:46:38 TP0] Decode batch, #running-req: 1, #full token: 2058, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.45, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 84.34, #queue-req: 0
[2026-03-10 21:46:39 TP0] Decode batch, #running-req: 1, #full token: 2121, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.57, accept rate: 0.39, cuda graph: True, gen throughput (token/s): 91.83, #queue-req: 0
[2026-03-10 21:46:40 TP0] Decode batch, #running-req: 1, #full token: 2195, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.85, accept rate: 0.46, cuda graph: True, gen throughput (token/s): 108.03, #queue-req: 0
[2026-03-10 21:46:41 TP0] Decode batch, #running-req: 1, #full token: 2250, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.38, accept rate: 0.34, cuda graph: True, gen throughput (token/s): 80.39, #queue-req: 0
[2026-03-10 21:46:41 TP0] Decode batch, #running-req: 1, #full token: 2307, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.43, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 83.04, #queue-req: 0
[2026-03-10 21:46:42 TP0] Decode batch, #running-req: 1, #full token: 2364, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.43, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 83.11, #queue-req: 0
[2026-03-10 21:46:43 TP0] Decode batch, #running-req: 1, #full token: 2421, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.43, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 83.12, #queue-req: 0
[2026-03-10 21:46:43 TP0] Decode batch, #running-req: 1, #full token: 2480, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.48, accept rate: 0.37, cuda graph: True, gen throughput (token/s): 85.65, #queue-req: 0
[2026-03-10 21:46:44 TP0] Decode batch, #running-req: 1, #full token: 2536, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.40, accept rate: 0.35, cuda graph: True, gen throughput (token/s): 81.49, #queue-req: 0
[2026-03-10 21:46:45 TP0] Decode batch, #running-req: 1, #full token: 2598, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.55, accept rate: 0.39, cuda graph: True, gen throughput (token/s): 90.31, #queue-req: 0
[2026-03-10 21:46:45 TP0] Decode batch, #running-req: 1, #full token: 2656, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.45, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 84.53, #queue-req: 0
[2026-03-10 21:46:46 TP0] Decode batch, #running-req: 1, #full token: 2703, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.18, accept rate: 0.29, cuda graph: True, gen throughput (token/s): 68.30, #queue-req: 0
[2026-03-10 21:46:47 TP0] Decode batch, #running-req: 1, #full token: 2760, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.43, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 82.77, #queue-req: 0
[2026-03-10 21:46:47 TP0] Decode batch, #running-req: 1, #full token: 2823, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.57, accept rate: 0.39, cuda graph: True, gen throughput (token/s): 91.43, #queue-req: 0
[2026-03-10 21:46:48 TP0] Decode batch, #running-req: 1, #full token: 2876, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.32, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 77.05, #queue-req: 0
[2026-03-10 21:46:49 TP0] Decode batch, #running-req: 1, #full token: 2938, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.55, accept rate: 0.39, cuda graph: True, gen throughput (token/s): 90.25, #queue-req: 0
[2026-03-10 21:46:49 TP0] Decode batch, #running-req: 1, #full token: 3011, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.82, accept rate: 0.46, cuda graph: True, gen throughput (token/s): 106.23, #queue-req: 0
[2026-03-10 21:46:50 TP0] Decode batch, #running-req: 1, #full token: 3065, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.35, accept rate: 0.34, cuda graph: True, gen throughput (token/s): 78.40, #queue-req: 0
[2026-03-10 21:46:51 TP0] Decode batch, #running-req: 1, #full token: 3126, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.52, accept rate: 0.38, cuda graph: True, gen throughput (token/s): 88.57, #queue-req: 0
[2026-03-10 21:46:52 TP0] Decode batch, #running-req: 1, #full token: 3186, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.50, accept rate: 0.38, cuda graph: True, gen throughput (token/s): 87.23, #queue-req: 0
[2026-03-10 21:46:52 TP0] Decode batch, #running-req: 1, #full token: 3242, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.40, accept rate: 0.35, cuda graph: True, gen throughput (token/s): 81.46, #queue-req: 0
[2026-03-10 21:46:53 TP0] Decode batch, #running-req: 1, #full token: 3300, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.45, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 84.34, #queue-req: 0
[2026-03-10 21:46:54 TP0] Decode batch, #running-req: 1, #full token: 3363, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.57, accept rate: 0.39, cuda graph: True, gen throughput (token/s): 91.45, #queue-req: 0
[2026-03-10 21:46:54 TP0] Decode batch, #running-req: 1, #full token: 3427, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.60, accept rate: 0.40, cuda graph: True, gen throughput (token/s): 92.92, #queue-req: 0
[2026-03-10 21:46:55 TP0] Decode batch, #running-req: 1, #full token: 3483, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.40, accept rate: 0.35, cuda graph: True, gen throughput (token/s): 81.29, #queue-req: 0
[2026-03-10 21:46:56 TP0] Decode batch, #running-req: 1, #full token: 3535, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.30, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 75.54, #queue-req: 0
[2026-03-10 21:46:56 TP0] Decode batch, #running-req: 1, #full token: 3595, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.50, accept rate: 0.38, cuda graph: True, gen throughput (token/s): 86.97, #queue-req: 0
[2026-03-10 21:46:57 TP0] Decode batch, #running-req: 1, #full token: 3652, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.43, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 82.59, #queue-req: 0
[2026-03-10 21:46:58 TP0] Decode batch, #running-req: 1, #full token: 3704, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.30, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 75.35, #queue-req: 0
[2026-03-10 21:46:58 TP0] Decode batch, #running-req: 1, #full token: 3755, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.27, accept rate: 0.32, cuda graph: True, gen throughput (token/s): 73.80, #queue-req: 0
[2026-03-10 21:46:59 TP0] Decode batch, #running-req: 1, #full token: 3809, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.35, accept rate: 0.34, cuda graph: True, gen throughput (token/s): 78.12, #queue-req: 0
[2026-03-10 21:47:00 TP0] Decode batch, #running-req: 1, #full token: 3872, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.57, accept rate: 0.39, cuda graph: True, gen throughput (token/s): 91.09, #queue-req: 0
[2026-03-10 21:47:01 TP0] Decode batch, #running-req: 1, #full token: 3923, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.27, accept rate: 0.32, cuda graph: True, gen throughput (token/s): 73.81, #queue-req: 0
[2026-03-10 21:47:01 TP0] Decode batch, #running-req: 1, #full token: 3975, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.30, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 75.39, #queue-req: 0
[2026-03-10 21:47:02 TP0] Decode batch, #running-req: 1, #full token: 4033, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.45, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 84.11, #queue-req: 0
[2026-03-10 21:47:03 TP0] Decode batch, #running-req: 1, #full token: 4079, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.15, accept rate: 0.29, cuda graph: True, gen throughput (token/s): 66.71, #queue-req: 0
[2026-03-10 21:47:03 TP0] Decode batch, #running-req: 1, #full token: 4130, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.27, accept rate: 0.32, cuda graph: True, gen throughput (token/s): 73.98, #queue-req: 0
[2026-03-10 21:47:04 TP0] Decode batch, #running-req: 1, #full token: 4189, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.48, accept rate: 0.37, cuda graph: True, gen throughput (token/s): 85.63, #queue-req: 0
[2026-03-10 21:47:05 TP0] Decode batch, #running-req: 1, #full token: 4242, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.32, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 76.99, #queue-req: 0
[2026-03-10 21:47:05 TP0] Decode batch, #running-req: 1, #full token: 4290, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.20, accept rate: 0.30, cuda graph: True, gen throughput (token/s): 69.65, #queue-req: 0
[2026-03-10 21:47:06 TP0] Decode batch, #running-req: 1, #full token: 4345, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.38, accept rate: 0.34, cuda graph: True, gen throughput (token/s): 79.73, #queue-req: 0
[2026-03-10 21:47:07 TP0] Decode batch, #running-req: 1, #full token: 4404, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.48, accept rate: 0.37, cuda graph: True, gen throughput (token/s): 85.49, #queue-req: 0
[2026-03-10 21:47:07 TP0] Decode batch, #running-req: 1, #full token: 4465, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.52, accept rate: 0.38, cuda graph: True, gen throughput (token/s): 88.38, #queue-req: 0
[2026-03-10 21:47:08 TP0] Decode batch, #running-req: 1, #full token: 4518, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.32, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 76.85, #queue-req: 0
[2026-03-10 21:47:09 TP0] Decode batch, #running-req: 1, #full token: 4577, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.48, accept rate: 0.37, cuda graph: True, gen throughput (token/s): 85.67, #queue-req: 0
[2026-03-10 21:47:09 TP0] Decode batch, #running-req: 1, #full token: 4637, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.50, accept rate: 0.38, cuda graph: True, gen throughput (token/s): 87.04, #queue-req: 0
[2026-03-10 21:47:10 TP0] Decode batch, #running-req: 1, #full token: 4700, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.57, accept rate: 0.39, cuda graph: True, gen throughput (token/s): 91.51, #queue-req: 0
[2026-03-10 21:47:11 TP0] Decode batch, #running-req: 1, #full token: 4752, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.30, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 75.48, #queue-req: 0
[2026-03-10 21:47:12 TP0] Decode batch, #running-req: 1, #full token: 4809, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.43, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 82.76, #queue-req: 0
[2026-03-10 21:47:12 TP0] Decode batch, #running-req: 1, #full token: 4865, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.40, accept rate: 0.35, cuda graph: True, gen throughput (token/s): 81.30, #queue-req: 0
[2026-03-10 21:47:13 TP0] Decode batch, #running-req: 1, #full token: 4923, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.45, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 84.25, #queue-req: 0
[2026-03-10 21:47:14 TP0] Decode batch, #running-req: 1, #full token: 4974, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.27, accept rate: 0.32, cuda graph: True, gen throughput (token/s): 74.10, #queue-req: 0
[2026-03-10 21:47:14 TP0] Decode batch, #running-req: 1, #full token: 5030, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.40, accept rate: 0.35, cuda graph: True, gen throughput (token/s): 81.36, #queue-req: 0
[2026-03-10 21:47:15 TP0] Decode batch, #running-req: 1, #full token: 5090, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.50, accept rate: 0.38, cuda graph: True, gen throughput (token/s): 87.15, #queue-req: 0
[2026-03-10 21:47:16 TP0] Decode batch, #running-req: 1, #full token: 5135, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.12, accept rate: 0.28, cuda graph: True, gen throughput (token/s): 65.39, #queue-req: 0
[2026-03-10 21:47:16 TP0] Decode batch, #running-req: 1, #full token: 5182, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.18, accept rate: 0.29, cuda graph: True, gen throughput (token/s): 68.38, #queue-req: 0
[2026-03-10 21:47:17 TP0] Decode batch, #running-req: 1, #full token: 5233, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.27, accept rate: 0.32, cuda graph: True, gen throughput (token/s): 74.16, #queue-req: 0
[2026-03-10 21:47:18 TP0] Decode batch, #running-req: 1, #full token: 5279, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.15, accept rate: 0.29, cuda graph: True, gen throughput (token/s): 66.94, #queue-req: 0
[2026-03-10 21:47:18 TP0] Decode batch, #running-req: 1, #full token: 5329, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.25, accept rate: 0.31, cuda graph: True, gen throughput (token/s): 72.51, #queue-req: 0
[2026-03-10 21:47:19 TP0] Decode batch, #running-req: 1, #full token: 5386, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.43, accept rate: 0.36, cuda graph: True, gen throughput (token/s): 82.74, #queue-req: 0
[2026-03-10 21:47:20 TP0] Decode batch, #running-req: 1, #full token: 5437, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.27, accept rate: 0.32, cuda graph: True, gen throughput (token/s): 74.13, #queue-req: 0
[2026-03-10 21:47:20 TP0] Decode batch, #running-req: 1, #full token: 5486, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.23, accept rate: 0.31, cuda graph: True, gen throughput (token/s): 71.03, #queue-req: 0
[2026-03-10 21:47:21 TP0] Decode batch, #running-req: 1, #full token: 5531, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.12, accept rate: 0.28, cuda graph: True, gen throughput (token/s): 65.24, #queue-req: 0
[2026-03-10 21:47:22 TP0] Decode batch, #running-req: 1, #full token: 5587, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.40, accept rate: 0.35, cuda graph: True, gen throughput (token/s): 81.22, #queue-req: 0
[2026-03-10 21:47:23 TP0] Decode batch, #running-req: 1, #full token: 5637, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.25, accept rate: 0.31, cuda graph: True, gen throughput (token/s): 72.49, #queue-req: 0
[2026-03-10 21:47:23 TP0] Decode batch, #running-req: 1, #full token: 5690, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.32, accept rate: 0.33, cuda graph: True, gen throughput (token/s): 76.77, #queue-req: 0
[2026-03-10 21:47:24 TP0] Decode batch, #running-req: 1, #full token: 5744, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.35, accept rate: 0.34, cuda graph: True, gen throughput (token/s): 78.26, #queue-req: 0

sglang with speculative decoding disabled, 4x3090

Details
[2026-03-10 21:48:59 TP0] Prefill batch, #new-seq: 1, #new-token: 19, #cached-token: 0, full token usage: 0.00, mamba usage: 0.00, #running-req: 0, #queue-req: 0, input throughput (token/s): 14.96, cuda graph: False
[2026-03-10 21:49:00 TP0] Decode batch, #running-req: 1, #full token: 52, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 2.41, #queue-req: 0
[2026-03-10 21:49:00 TP0] Decode batch, #running-req: 1, #full token: 92, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.43, #queue-req: 0
[2026-03-10 21:49:00 TP0] Decode batch, #running-req: 1, #full token: 132, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.18, #queue-req: 0
[2026-03-10 21:49:01 TP0] Decode batch, #running-req: 1, #full token: 172, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.02, #queue-req: 0
[2026-03-10 21:49:01 TP0] Decode batch, #running-req: 1, #full token: 212, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.33, #queue-req: 0
[2026-03-10 21:49:01 TP0] Decode batch, #running-req: 1, #full token: 252, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 116.86, #queue-req: 0
[2026-03-10 21:49:02 TP0] Decode batch, #running-req: 1, #full token: 292, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.61, #queue-req: 0
[2026-03-10 21:49:02 TP0] Decode batch, #running-req: 1, #full token: 332, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 116.86, #queue-req: 0
[2026-03-10 21:49:02 TP0] Decode batch, #running-req: 1, #full token: 372, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.09, #queue-req: 0
[2026-03-10 21:49:03 TP0] Decode batch, #running-req: 1, #full token: 412, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.63, #queue-req: 0
[2026-03-10 21:49:03 TP0] Decode batch, #running-req: 1, #full token: 452, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.14, #queue-req: 0
[2026-03-10 21:49:03 TP0] Decode batch, #running-req: 1, #full token: 492, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.63, #queue-req: 0
[2026-03-10 21:49:04 TP0] Decode batch, #running-req: 1, #full token: 532, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.64, #queue-req: 0
[2026-03-10 21:49:04 TP0] Decode batch, #running-req: 1, #full token: 572, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.29, #queue-req: 0
[2026-03-10 21:49:04 TP0] Decode batch, #running-req: 1, #full token: 612, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.63, #queue-req: 0
[2026-03-10 21:49:05 TP0] Decode batch, #running-req: 1, #full token: 652, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.47, #queue-req: 0
[2026-03-10 21:49:05 TP0] Decode batch, #running-req: 1, #full token: 692, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.40, #queue-req: 0
[2026-03-10 21:49:05 TP0] Decode batch, #running-req: 1, #full token: 732, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 115.37, #queue-req: 0
[2026-03-10 21:49:06 TP0] Decode batch, #running-req: 1, #full token: 772, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 116.97, #queue-req: 0
[2026-03-10 21:49:06 TP0] Decode batch, #running-req: 1, #full token: 812, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.61, #queue-req: 0
[2026-03-10 21:49:06 TP0] Decode batch, #running-req: 1, #full token: 852, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.55, #queue-req: 0
[2026-03-10 21:49:07 TP0] Decode batch, #running-req: 1, #full token: 892, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.58, #queue-req: 0
[2026-03-10 21:49:07 TP0] Decode batch, #running-req: 1, #full token: 932, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 116.02, #queue-req: 0
[2026-03-10 21:49:07 TP0] Decode batch, #running-req: 1, #full token: 972, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 116.55, #queue-req: 0
[2026-03-10 21:49:08 TP0] Decode batch, #running-req: 1, #full token: 1012, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.56, #queue-req: 0
[2026-03-10 21:49:08 TP0] Decode batch, #running-req: 1, #full token: 1052, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.17, #queue-req: 0
[2026-03-10 21:49:08 TP0] Decode batch, #running-req: 1, #full token: 1092, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.06, #queue-req: 0
[2026-03-10 21:49:09 TP0] Decode batch, #running-req: 1, #full token: 1132, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.20, #queue-req: 0
[2026-03-10 21:49:09 TP0] Decode batch, #running-req: 1, #full token: 1172, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.64, #queue-req: 0
[2026-03-10 21:49:09 TP0] Decode batch, #running-req: 1, #full token: 1212, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.15, #queue-req: 0
[2026-03-10 21:49:10 TP0] Decode batch, #running-req: 1, #full token: 1252, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.80, #queue-req: 0
[2026-03-10 21:49:10 TP0] Decode batch, #running-req: 1, #full token: 1292, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 116.66, #queue-req: 0
[2026-03-10 21:49:10 TP0] Decode batch, #running-req: 1, #full token: 1332, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.63, #queue-req: 0
[2026-03-10 21:49:11 TP0] Decode batch, #running-req: 1, #full token: 1372, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.37, #queue-req: 0
[2026-03-10 21:49:11 TP0] Decode batch, #running-req: 1, #full token: 1412, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.48, #queue-req: 0
[2026-03-10 21:49:11 TP0] Decode batch, #running-req: 1, #full token: 1452, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.20, #queue-req: 0
[2026-03-10 21:49:12 TP0] Decode batch, #running-req: 1, #full token: 1492, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 115.80, #queue-req: 0
[2026-03-10 21:49:12 TP0] Decode batch, #running-req: 1, #full token: 1532, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.10, #queue-req: 0
[2026-03-10 21:49:13 TP0] Decode batch, #running-req: 1, #full token: 1572, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.22, #queue-req: 0
[2026-03-10 21:49:13 TP0] Decode batch, #running-req: 1, #full token: 1612, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.48, #queue-req: 0
[2026-03-10 21:49:13 TP0] Decode batch, #running-req: 1, #full token: 1652, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 116.79, #queue-req: 0
[2026-03-10 21:49:14 TP0] Decode batch, #running-req: 1, #full token: 1692, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.43, #queue-req: 0
[2026-03-10 21:49:14 TP0] Decode batch, #running-req: 1, #full token: 1732, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.64, #queue-req: 0
[2026-03-10 21:49:14 TP0] Decode batch, #running-req: 1, #full token: 1772, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.21, #queue-req: 0
[2026-03-10 21:49:15 TP0] Decode batch, #running-req: 1, #full token: 1812, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.42, #queue-req: 0
[2026-03-10 21:49:15 TP0] Decode batch, #running-req: 1, #full token: 1852, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.67, #queue-req: 0
[2026-03-10 21:49:15 TP0] Decode batch, #running-req: 1, #full token: 1892, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.53, #queue-req: 0
[2026-03-10 21:49:16 TP0] Decode batch, #running-req: 1, #full token: 1932, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.80, #queue-req: 0
[2026-03-10 21:49:16 TP0] Decode batch, #running-req: 1, #full token: 1972, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.45, #queue-req: 0
[2026-03-10 21:49:16 TP0] Decode batch, #running-req: 1, #full token: 2012, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.80, #queue-req: 0
[2026-03-10 21:49:17 TP0] Decode batch, #running-req: 1, #full token: 2052, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.70, #queue-req: 0
[2026-03-10 21:49:17 TP0] Decode batch, #running-req: 1, #full token: 2092, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.75, #queue-req: 0
[2026-03-10 21:49:17 TP0] Decode batch, #running-req: 1, #full token: 2132, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.09, #queue-req: 0
[2026-03-10 21:49:18 TP0] Decode batch, #running-req: 1, #full token: 2172, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.88, #queue-req: 0
[2026-03-10 21:49:18 TP0] Decode batch, #running-req: 1, #full token: 2212, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.80, #queue-req: 0
[2026-03-10 21:49:18 TP0] Decode batch, #running-req: 1, #full token: 2252, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.71, #queue-req: 0
[2026-03-10 21:49:19 TP0] Decode batch, #running-req: 1, #full token: 2292, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.22, #queue-req: 0
[2026-03-10 21:49:19 TP0] Decode batch, #running-req: 1, #full token: 2332, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.69, #queue-req: 0
[2026-03-10 21:49:19 TP0] Decode batch, #running-req: 1, #full token: 2372, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.76, #queue-req: 0
[2026-03-10 21:49:20 TP0] Decode batch, #running-req: 1, #full token: 2412, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.14, #queue-req: 0
[2026-03-10 21:49:20 TP0] Decode batch, #running-req: 1, #full token: 2452, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.76, #queue-req: 0
[2026-03-10 21:49:20 TP0] Decode batch, #running-req: 1, #full token: 2492, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.64, #queue-req: 0
[2026-03-10 21:49:21 TP0] Decode batch, #running-req: 1, #full token: 2532, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.32, #queue-req: 0
[2026-03-10 21:49:21 TP0] Decode batch, #running-req: 1, #full token: 2572, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.99, #queue-req: 0
[2026-03-10 21:49:21 TP0] Decode batch, #running-req: 1, #full token: 2612, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.18, #queue-req: 0
[2026-03-10 21:49:22 TP0] Decode batch, #running-req: 1, #full token: 2652, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.98, #queue-req: 0
[2026-03-10 21:49:22 TP0] Decode batch, #running-req: 1, #full token: 2692, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.33, #queue-req: 0
[2026-03-10 21:49:22 TP0] Decode batch, #running-req: 1, #full token: 2732, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.45, #queue-req: 0
[2026-03-10 21:49:23 TP0] Decode batch, #running-req: 1, #full token: 2772, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.52, #queue-req: 0
[2026-03-10 21:49:23 TP0] Decode batch, #running-req: 1, #full token: 2812, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.64, #queue-req: 0
[2026-03-10 21:49:23 TP0] Decode batch, #running-req: 1, #full token: 2852, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.97, #queue-req: 0
[2026-03-10 21:49:24 TP0] Decode batch, #running-req: 1, #full token: 2892, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.55, #queue-req: 0
[2026-03-10 21:49:24 TP0] Decode batch, #running-req: 1, #full token: 2932, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.57, #queue-req: 0
[2026-03-10 21:49:24 TP0] Decode batch, #running-req: 1, #full token: 2972, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.73, #queue-req: 0
[2026-03-10 21:49:25 TP0] Decode batch, #running-req: 1, #full token: 3012, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.61, #queue-req: 0
[2026-03-10 21:49:25 TP0] Decode batch, #running-req: 1, #full token: 3052, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.55, #queue-req: 0
[2026-03-10 21:49:25 TP0] Decode batch, #running-req: 1, #full token: 3092, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.46, #queue-req: 0
[2026-03-10 21:49:26 TP0] Decode batch, #running-req: 1, #full token: 3132, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.94, #queue-req: 0
[2026-03-10 21:49:26 TP0] Decode batch, #running-req: 1, #full token: 3172, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.03, #queue-req: 0
[2026-03-10 21:49:26 TP0] Decode batch, #running-req: 1, #full token: 3212, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.67, #queue-req: 0
[2026-03-10 21:49:27 TP0] Decode batch, #running-req: 1, #full token: 3252, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.75, #queue-req: 0
[2026-03-10 21:49:27 TP0] Decode batch, #running-req: 1, #full token: 3292, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.68, #queue-req: 0
[2026-03-10 21:49:27 TP0] Decode batch, #running-req: 1, #full token: 3332, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.19, #queue-req: 0
[2026-03-10 21:49:28 TP0] Decode batch, #running-req: 1, #full token: 3372, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.86, #queue-req: 0
[2026-03-10 21:49:28 TP0] Decode batch, #running-req: 1, #full token: 3412, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.89, #queue-req: 0
[2026-03-10 21:49:28 TP0] Decode batch, #running-req: 1, #full token: 3452, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.65, #queue-req: 0
[2026-03-10 21:49:29 TP0] Decode batch, #running-req: 1, #full token: 3492, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.59, #queue-req: 0
[2026-03-10 21:49:29 TP0] Decode batch, #running-req: 1, #full token: 3532, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.62, #queue-req: 0
[2026-03-10 21:49:29 TP0] Decode batch, #running-req: 1, #full token: 3572, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.99, #queue-req: 0
[2026-03-10 21:49:30 TP0] Decode batch, #running-req: 1, #full token: 3612, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.65, #queue-req: 0
[2026-03-10 21:49:30 TP0] Decode batch, #running-req: 1, #full token: 3652, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.14, #queue-req: 0
[2026-03-10 21:49:30 TP0] Decode batch, #running-req: 1, #full token: 3692, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.90, #queue-req: 0
[2026-03-10 21:49:31 TP0] Decode batch, #running-req: 1, #full token: 3732, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.94, #queue-req: 0
[2026-03-10 21:49:31 TP0] Decode batch, #running-req: 1, #full token: 3772, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.73, #queue-req: 0
[2026-03-10 21:49:32 TP0] Decode batch, #running-req: 1, #full token: 3812, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.70, #queue-req: 0
[2026-03-10 21:49:32 TP0] Decode batch, #running-req: 1, #full token: 3852, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.90, #queue-req: 0
[2026-03-10 21:49:32 TP0] Decode batch, #running-req: 1, #full token: 3892, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.99, #queue-req: 0
[2026-03-10 21:49:33 TP0] Decode batch, #running-req: 1, #full token: 3932, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.97, #queue-req: 0
[2026-03-10 21:49:33 TP0] Decode batch, #running-req: 1, #full token: 3972, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.97, #queue-req: 0
[2026-03-10 21:49:33 TP0] Decode batch, #running-req: 1, #full token: 4012, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.90, #queue-req: 0
[2026-03-10 21:49:34 TP0] Decode batch, #running-req: 1, #full token: 4052, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.98, #queue-req: 0
[2026-03-10 21:49:34 TP0] Decode batch, #running-req: 1, #full token: 4092, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.69, #queue-req: 0
[2026-03-10 21:49:34 TP0] Decode batch, #running-req: 1, #full token: 4132, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.93, #queue-req: 0
[2026-03-10 21:49:35 TP0] Decode batch, #running-req: 1, #full token: 4172, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.81, #queue-req: 0
[2026-03-10 21:49:35 TP0] Decode batch, #running-req: 1, #full token: 4212, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 117.65, #queue-req: 0
[2026-03-10 21:49:35 TP0] Decode batch, #running-req: 1, #full token: 4252, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.07, #queue-req: 0
[2026-03-10 21:49:36 TP0] Decode batch, #running-req: 1, #full token: 4292, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.03, #queue-req: 0
[2026-03-10 21:49:36 TP0] Decode batch, #running-req: 1, #full token: 4332, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, cuda graph: True, gen throughput (token/s): 118.16, #queue-req: 0
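To compare runs like the ones above without eyeballing hundreds of lines, a small helper can pull the `gen throughput (token/s)` field out of sglang decode log lines and average it. This is a sketch that only assumes the line format shown above (the sample lines are abbreviated copies of real log lines):

```python
import re
import statistics

# Matches the "gen throughput (token/s): X" field in sglang decode log lines.
THROUGHPUT_RE = re.compile(r"gen throughput \(token/s\): ([0-9.]+)")

def mean_throughput(log_lines):
    """Average the generation throughput over a list of sglang log lines."""
    values = [float(m.group(1)) for line in log_lines
              if (m := THROUGHPUT_RE.search(line))]
    return statistics.mean(values) if values else 0.0

sample = [
    "[2026-03-10 21:49:00 TP0] Decode batch, gen throughput (token/s): 118.43, #queue-req: 0",
    "[2026-03-10 21:49:00 TP0] Decode batch, gen throughput (token/s): 118.18, #queue-req: 0",
]
print(mean_throughput(sample))  # averages the two throughput values
```

Lines without the field (e.g. prefill lines) are simply skipped by the regex filter.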

Qwen3.5-9B BF16 performance on 2x3090

Details

main: n_kv_max = 65536, n_batch = 2048, n_ubatch = 512, flash_attn = 1, n_gpu_layers = 999, n_threads = 8, n_threads_batch = 8

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---:|---:|-----:|-------:|---------:|-------:|---------:|
| 512 | 128 | 0 | 0.167 | 3074.28 | 1.883 | 67.97 |
| 512 | 128 | 512 | 0.159 | 3218.02 | 1.913 | 66.92 |
| 512 | 128 | 1024 | 0.129 | 3957.73 | 1.893 | 67.60 |
| 512 | 128 | 1536 | 0.130 | 3942.28 | 1.898 | 67.44 |
| 512 | 128 | 2048 | 0.131 | 3898.25 | 1.904 | 67.23 |
| 512 | 128 | 2560 | 0.131 | 3900.00 | 1.908 | 67.07 |
| 512 | 128 | 3072 | 0.132 | 3886.59 | 1.912 | 66.93 |
| 512 | 128 | 3584 | 0.132 | 3877.64 | 1.916 | 66.79 |
| 512 | 128 | 4096 | 0.132 | 3873.15 | 1.920 | 66.67 |
| 512 | 128 | 4608 | 0.133 | 3859.64 | 1.925 | 66.50 |
| 512 | 128 | 5120 | 0.133 | 3842.78 | 1.927 | 66.42 |
| 512 | 128 | 5632 | 0.133 | 3841.77 | 1.930 | 66.31 |
| 512 | 128 | 6144 | 0.134 | 3809.38 | 1.936 | 66.11 |
| 512 | 128 | 6656 | 0.134 | 3815.29 | 1.941 | 65.96 |
| 512 | 128 | 7168 | 0.135 | 3803.13 | 1.946 | 65.77 |
| 512 | 128 | 7680 | 0.135 | 3797.83 | 1.952 | 65.58 |
| 512 | 128 | 8192 | 0.135 | 3783.60 | 1.956 | 65.43 |
| 512 | 128 | 8704 | 0.135 | 3778.90 | 1.964 | 65.17 |
| 512 | 128 | 9216 | 0.136 | 3770.72 | 1.967 | 65.07 |
| 512 | 128 | 9728 | 0.136 | 3761.80 | 1.972 | 64.90 |
| 512 | 128 | 10240 | 0.137 | 3741.54 | 1.979 | 64.68 |
| 512 | 128 | 10752 | 0.137 | 3731.53 | 1.995 | 64.16 |
| 512 | 128 | 11264 | 0.137 | 3733.08 | 1.999 | 64.05 |
| 512 | 128 | 11776 | 0.137 | 3747.59 | 2.002 | 63.94 |
| 512 | 128 | 12288 | 0.137 | 3725.94 | 2.006 | 63.80 |
| 512 | 128 | 12800 | 0.137 | 3725.05 | 2.008 | 63.74 |
| 512 | 128 | 13312 | 0.138 | 3716.53 | 2.013 | 63.59 |
| 512 | 128 | 13824 | 0.138 | 3711.22 | 2.016 | 63.49 |
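For reading the sweep table: the speed columns are just token counts divided by elapsed time, S_PP = PP / T_PP and S_TG = TG / T_TG. A quick check against the first row (any mismatch beyond the last digit comes from the timing columns being rounded to milliseconds):

```python
# First row of the sweep above: PP=512, TG=128, T_PP=0.167 s, T_TG=1.883 s.
pp, tg = 512, 128
t_pp, t_tg = 0.167, 1.883

s_pp = pp / t_pp   # prompt-processing speed, tokens/s
s_tg = tg / t_tg   # token-generation speed, tokens/s

print(round(s_pp, 2), round(s_tg, 2))
```

This reproduces the tabulated S_PP and S_TG up to the rounding of T_PP and T_TG.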

sglang with speculative decoding enabled, 2x3090

Details
[2026-03-10 21:38:54 TP0] Prefill batch, #new-seq: 1, #new-token: 19, #cached-token: 0, full token usage: 0.00, mamba usage: 1.00, #running-req: 0, #queue-req: 0, input throughput (token/s): 5.25, cuda graph: False
[2026-03-10 21:38:55 TP0] Decode batch, #running-req: 1, #full token: 76, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.05, accept rate: 0.51, cuda graph: True, gen throughput (token/s): 30.10, #queue-req: 0
[2026-03-10 21:38:56 TP0] Decode batch, #running-req: 1, #full token: 164, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.20, accept rate: 0.55, cuda graph: True, gen throughput (token/s): 98.23, #queue-req: 0
[2026-03-10 21:38:57 TP0] Decode batch, #running-req: 1, #full token: 247, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.08, accept rate: 0.52, cuda graph: True, gen throughput (token/s): 92.57, #queue-req: 0
[2026-03-10 21:38:57 TP0] Decode batch, #running-req: 1, #full token: 340, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.33, accept rate: 0.58, cuda graph: True, gen throughput (token/s): 103.04, #queue-req: 0
[2026-03-10 21:38:58 TP0] Decode batch, #running-req: 1, #full token: 410, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.75, accept rate: 0.44, cuda graph: True, gen throughput (token/s): 78.03, #queue-req: 0
[2026-03-10 21:38:59 TP0] Decode batch, #running-req: 1, #full token: 486, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.90, accept rate: 0.47, cuda graph: True, gen throughput (token/s): 84.22, #queue-req: 0
[2026-03-10 21:39:00 TP0] Decode batch, #running-req: 1, #full token: 566, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.00, accept rate: 0.50, cuda graph: True, gen throughput (token/s): 89.18, #queue-req: 0
[2026-03-10 21:39:01 TP0] Decode batch, #running-req: 1, #full token: 655, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.23, accept rate: 0.56, cuda graph: True, gen throughput (token/s): 99.08, #queue-req: 0
[2026-03-10 21:39:02 TP0] Decode batch, #running-req: 1, #full token: 740, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.12, accept rate: 0.53, cuda graph: True, gen throughput (token/s): 94.34, #queue-req: 0
[2026-03-10 21:39:03 TP0] Decode batch, #running-req: 1, #full token: 825, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.12, accept rate: 0.53, cuda graph: True, gen throughput (token/s): 94.51, #queue-req: 0
[2026-03-10 21:39:04 TP0] Decode batch, #running-req: 1, #full token: 943, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.95, accept rate: 0.74, cuda graph: True, gen throughput (token/s): 130.69, #queue-req: 0
[2026-03-10 21:39:05 TP0] Decode batch, #running-req: 1, #full token: 1045, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.55, accept rate: 0.64, cuda graph: True, gen throughput (token/s): 113.25, #queue-req: 0
[2026-03-10 21:39:06 TP0] Decode batch, #running-req: 1, #full token: 1137, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.30, accept rate: 0.57, cuda graph: True, gen throughput (token/s): 102.13, #queue-req: 0
[2026-03-10 21:39:06 TP0] Decode batch, #running-req: 1, #full token: 1217, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.00, accept rate: 0.50, cuda graph: True, gen throughput (token/s): 89.00, #queue-req: 0
[2026-03-10 21:39:07 TP0] Decode batch, #running-req: 1, #full token: 1290, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.82, accept rate: 0.46, cuda graph: True, gen throughput (token/s): 81.08, #queue-req: 0
[2026-03-10 21:39:08 TP0] Decode batch, #running-req: 1, #full token: 1364, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.85, accept rate: 0.46, cuda graph: True, gen throughput (token/s): 81.98, #queue-req: 0
[2026-03-10 21:39:09 TP0] Decode batch, #running-req: 1, #full token: 1435, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.77, accept rate: 0.44, cuda graph: True, gen throughput (token/s): 78.79, #queue-req: 0
[2026-03-10 21:39:10 TP0] Decode batch, #running-req: 1, #full token: 1522, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.17, accept rate: 0.54, cuda graph: True, gen throughput (token/s): 96.62, #queue-req: 0
[2026-03-10 21:39:11 TP0] Decode batch, #running-req: 1, #full token: 1601, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 1.98, accept rate: 0.49, cuda graph: True, gen throughput (token/s): 87.62, #queue-req: 0
[2026-03-10 21:39:12 TP0] Decode batch, #running-req: 1, #full token: 1683, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.05, accept rate: 0.51, cuda graph: True, gen throughput (token/s): 90.99, #queue-req: 0
[2026-03-10 21:39:13 TP0] Decode batch, #running-req: 1, #full token: 1779, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.40, accept rate: 0.60, cuda graph: True, gen throughput (token/s): 106.50, #queue-req: 0
[2026-03-10 21:39:14 TP0] Decode batch, #running-req: 1, #full token: 1873, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.35, accept rate: 0.59, cuda graph: True, gen throughput (token/s): 104.29, #queue-req: 0
[2026-03-10 21:39:15 TP0] Decode batch, #running-req: 1, #full token: 1964, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.27, accept rate: 0.57, cuda graph: True, gen throughput (token/s): 100.98, #queue-req: 0
[2026-03-10 21:39:15 TP0] Decode batch, #running-req: 1, #full token: 2064, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.50, accept rate: 0.62, cuda graph: True, gen throughput (token/s): 110.88, #queue-req: 0
[2026-03-10 21:39:16 TP0] Decode batch, #running-req: 1, #full token: 2160, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.40, accept rate: 0.60, cuda graph: True, gen throughput (token/s): 106.64, #queue-req: 0
[2026-03-10 21:39:17 TP0] Decode batch, #running-req: 1, #full token: 2264, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.60, accept rate: 0.65, cuda graph: True, gen throughput (token/s): 115.45, #queue-req: 0
[2026-03-10 21:39:18 TP0] Decode batch, #running-req: 1, #full token: 2372, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.70, accept rate: 0.68, cuda graph: True, gen throughput (token/s): 119.94, #queue-req: 0
[2026-03-10 21:39:19 TP0] Decode batch, #running-req: 1, #full token: 2474, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.55, accept rate: 0.64, cuda graph: True, gen throughput (token/s): 113.19, #queue-req: 0
[2026-03-10 21:39:20 TP0] Decode batch, #running-req: 1, #full token: 2580, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.65, accept rate: 0.66, cuda graph: True, gen throughput (token/s): 117.61, #queue-req: 0
[2026-03-10 21:39:21 TP0] Decode batch, #running-req: 1, #full token: 2689, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.73, accept rate: 0.68, cuda graph: True, gen throughput (token/s): 120.93, #queue-req: 0
[2026-03-10 21:39:22 TP0] Decode batch, #running-req: 1, #full token: 2788, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.48, accept rate: 0.62, cuda graph: True, gen throughput (token/s): 109.75, #queue-req: 0
[2026-03-10 21:39:23 TP0] Decode batch, #running-req: 1, #full token: 2886, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.45, accept rate: 0.61, cuda graph: True, gen throughput (token/s): 108.67, #queue-req: 0
[2026-03-10 21:39:24 TP0] Decode batch, #running-req: 1, #full token: 2996, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.75, accept rate: 0.69, cuda graph: True, gen throughput (token/s): 121.87, #queue-req: 0
[2026-03-10 21:39:24 TP0] Decode batch, #running-req: 1, #full token: 3097, full token usage: 0.00, mamba num: 1, mamba usage: 1.00, accept len: 2.52, accept rate: 0.63, cuda graph: True, gen throughput (token/s): 111.92, #queue-req: 0
[2026-03-10 21:39:25 TP0] Decode batch, #running-req: 1, #full token: 3206, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.73, accept rate: 0.68, cuda graph: True, gen throughput (token/s): 120.94, #queue-req: 0
[2026-03-10 21:39:26 TP0] Decode batch, #running-req: 1, #full token: 3310, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.60, accept rate: 0.65, cuda graph: True, gen throughput (token/s): 115.18, #queue-req: 0
[2026-03-10 21:39:27 TP0] Decode batch, #running-req: 1, #full token: 3421, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.77, accept rate: 0.69, cuda graph: True, gen throughput (token/s): 122.63, #queue-req: 0
[2026-03-10 21:39:28 TP0] Decode batch, #running-req: 1, #full token: 3531, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.75, accept rate: 0.69, cuda graph: True, gen throughput (token/s): 121.68, #queue-req: 0
[2026-03-10 21:39:29 TP0] Decode batch, #running-req: 1, #full token: 3635, full token usage: 0.01, mamba num: 1, mamba usage: 1.00, accept len: 2.60, accept rate: 0.65, cuda graph: True, gen throughput (token/s): 115.01, #queue-req: 0

sglang with speculative decoding disabled, 2*3090

Details
[2026-03-10 21:51:48 TP0] Prefill batch, #new-seq: 1, #new-token: 19, #cached-token: 0, full token usage: 0.00, mamba usage: 0.01, #running-req: 0, #queue-req: 0, input throughput (token/s): 4.12, cuda graph: False
[2026-03-10 21:51:48 TP0] Decode batch, #running-req: 1, #full token: 41, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 1.77, #queue-req: 0
[2026-03-10 21:51:49 TP0] Decode batch, #running-req: 1, #full token: 81, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.17, #queue-req: 0
[2026-03-10 21:51:49 TP0] Decode batch, #running-req: 1, #full token: 121, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 74.99, #queue-req: 0
[2026-03-10 21:51:50 TP0] Decode batch, #running-req: 1, #full token: 161, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.30, #queue-req: 0
[2026-03-10 21:51:50 TP0] Decode batch, #running-req: 1, #full token: 201, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.31, #queue-req: 0
[2026-03-10 21:51:51 TP0] Decode batch, #running-req: 1, #full token: 241, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.26, #queue-req: 0
[2026-03-10 21:51:51 TP0] Decode batch, #running-req: 1, #full token: 281, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.00, #queue-req: 0
[2026-03-10 21:51:52 TP0] Decode batch, #running-req: 1, #full token: 321, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.08, #queue-req: 0
[2026-03-10 21:51:52 TP0] Decode batch, #running-req: 1, #full token: 361, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 74.76, #queue-req: 0
[2026-03-10 21:51:53 TP0] Decode batch, #running-req: 1, #full token: 401, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.32, #queue-req: 0
[2026-03-10 21:51:53 TP0] Decode batch, #running-req: 1, #full token: 441, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 74.66, #queue-req: 0
[2026-03-10 21:51:54 TP0] Decode batch, #running-req: 1, #full token: 481, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 74.94, #queue-req: 0
[2026-03-10 21:51:54 TP0] Decode batch, #running-req: 1, #full token: 521, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 74.68, #queue-req: 0
[2026-03-10 21:51:55 TP0] Decode batch, #running-req: 1, #full token: 561, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 74.92, #queue-req: 0
[2026-03-10 21:51:55 TP0] Decode batch, #running-req: 1, #full token: 601, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.15, #queue-req: 0
[2026-03-10 21:51:56 TP0] Decode batch, #running-req: 1, #full token: 641, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.07, #queue-req: 0
[2026-03-10 21:51:57 TP0] Decode batch, #running-req: 1, #full token: 681, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.11, #queue-req: 0
[2026-03-10 21:51:57 TP0] Decode batch, #running-req: 1, #full token: 721, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.26, #queue-req: 0
[2026-03-10 21:51:58 TP0] Decode batch, #running-req: 1, #full token: 761, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.25, #queue-req: 0
[2026-03-10 21:51:58 TP0] Decode batch, #running-req: 1, #full token: 801, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.16, #queue-req: 0
[2026-03-10 21:51:59 TP0] Decode batch, #running-req: 1, #full token: 841, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.48, #queue-req: 0
[2026-03-10 21:51:59 TP0] Decode batch, #running-req: 1, #full token: 881, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 74.99, #queue-req: 0
[2026-03-10 21:52:00 TP0] Decode batch, #running-req: 1, #full token: 921, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.02, #queue-req: 0
[2026-03-10 21:52:00 TP0] Decode batch, #running-req: 1, #full token: 961, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 74.97, #queue-req: 0
[2026-03-10 21:52:01 TP0] Decode batch, #running-req: 1, #full token: 1001, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 74.76, #queue-req: 0
[2026-03-10 21:52:01 TP0] Decode batch, #running-req: 1, #full token: 1041, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.09, #queue-req: 0
[2026-03-10 21:52:02 TP0] Decode batch, #running-req: 1, #full token: 1081, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.37, #queue-req: 0
[2026-03-10 21:52:02 TP0] Decode batch, #running-req: 1, #full token: 1121, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.53, #queue-req: 0
[2026-03-10 21:52:03 TP0] Decode batch, #running-req: 1, #full token: 1161, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.51, #queue-req: 0
[2026-03-10 21:52:03 TP0] Decode batch, #running-req: 1, #full token: 1201, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.37, #queue-req: 0
[2026-03-10 21:52:04 TP0] Decode batch, #running-req: 1, #full token: 1241, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.32, #queue-req: 0
[2026-03-10 21:52:05 TP0] Decode batch, #running-req: 1, #full token: 1281, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.15, #queue-req: 0
[2026-03-10 21:52:05 TP0] Decode batch, #running-req: 1, #full token: 1321, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.27, #queue-req: 0
[2026-03-10 21:52:06 TP0] Decode batch, #running-req: 1, #full token: 1361, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.54, #queue-req: 0
[2026-03-10 21:52:06 TP0] Decode batch, #running-req: 1, #full token: 1401, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.60, #queue-req: 0
[2026-03-10 21:52:07 TP0] Decode batch, #running-req: 1, #full token: 1441, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.46, #queue-req: 0
[2026-03-10 21:52:07 TP0] Decode batch, #running-req: 1, #full token: 1481, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.33, #queue-req: 0
[2026-03-10 21:52:08 TP0] Decode batch, #running-req: 1, #full token: 1521, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.66, #queue-req: 0
[2026-03-10 21:52:08 TP0] Decode batch, #running-req: 1, #full token: 1561, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.47, #queue-req: 0
[2026-03-10 21:52:09 TP0] Decode batch, #running-req: 1, #full token: 1601, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.70, #queue-req: 0
[2026-03-10 21:52:09 TP0] Decode batch, #running-req: 1, #full token: 1641, full token usage: 0.00, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.39, #queue-req: 0
[2026-03-10 21:52:10 TP0] Decode batch, #running-req: 1, #full token: 1681, full token usage: 0.01, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.34, #queue-req: 0
[2026-03-10 21:52:10 TP0] Decode batch, #running-req: 1, #full token: 1721, full token usage: 0.01, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.47, #queue-req: 0
[2026-03-10 21:52:11 TP0] Decode batch, #running-req: 1, #full token: 1761, full token usage: 0.01, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.18, #queue-req: 0
[2026-03-10 21:52:11 TP0] Decode batch, #running-req: 1, #full token: 1801, full token usage: 0.01, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.41, #queue-req: 0
[2026-03-10 21:52:12 TP0] Decode batch, #running-req: 1, #full token: 1841, full token usage: 0.01, mamba num: 2, mamba usage: 0.01, cuda graph: True, gen throughput (token/s): 75.16, #queue-req: 0
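
A rough way to compare the two sglang log blocks above (throughput values copied from the logs by hand; the non-speculative figure is approximate since that block sits nearly flat around 75.2 t/s — an averaging sketch, not an exact benchmark):

```python
# Throughputs from the speculative-decoding log block above.
spec = [117.61, 120.93, 109.75, 108.67, 121.87, 111.92,
        120.94, 115.18, 122.63, 121.68, 115.01]
no_spec = 75.2  # roughly constant in the non-speculative block

spec_avg = sum(spec) / len(spec)
print(f"speculative avg: {spec_avg:.1f} t/s")
print(f"speedup from speculation: {spec_avg / no_spec:.2f}x")
```

Note the average accept length in those logs is around 2.6, but the end-to-end speedup is smaller than that, since each speculative step also pays for draft generation and rejected tokens.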

Qwen3.5-9B BF16 performance for 4*3090 -sm layer

Details

main: n_kv_max = 65536, n_batch = 2048, n_ubatch = 512, flash_attn = 1, n_gpu_layers = 999, n_threads = 8, n_threads_batch = 8

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---:|---:|-----:|-------:|---------:|-------:|---------:|
| 512 | 128 | 0 | 0.193 | 2658.58 | 3.022 | 42.36 |
| 512 | 128 | 512 | 0.183 | 2800.63 | 3.036 | 42.15 |
| 512 | 128 | 1024 | 0.181 | 2826.22 | 3.041 | 42.09 |
| 512 | 128 | 1536 | 0.181 | 2834.45 | 3.047 | 42.01 |
| 512 | 128 | 2048 | 0.177 | 2887.78 | 3.054 | 41.91 |
| 512 | 128 | 2560 | 0.178 | 2878.31 | 3.062 | 41.80 |
| 512 | 128 | 3072 | 0.179 | 2867.45 | 3.071 | 41.68 |
| 512 | 128 | 3584 | 0.179 | 2856.33 | 3.080 | 41.56 |
| 512 | 128 | 4096 | 0.180 | 2844.52 | 3.089 | 41.44 |
| 512 | 128 | 4608 | 0.181 | 2834.41 | 3.098 | 41.31 |
| 512 | 128 | 5120 | 0.181 | 2824.14 | 3.118 | 41.05 |
| 512 | 128 | 5632 | 0.182 | 2814.02 | 3.129 | 40.91 |
| 512 | 128 | 6144 | 0.183 | 2801.27 | 3.135 | 40.82 |
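
As a sanity check on the sweep-bench columns: the throughput values are just the batch sizes divided by the timings, S_PP = PP / T_PP and S_TG = TG / T_TG (a quick sketch using the first row of the 4*3090 table above; the small S_PP deviation comes from the rounding of the printed T_PP):

```python
# First row of the 4*3090 table: PP = 512, TG = 128.
pp, tg = 512, 128
t_pp, t_tg = 0.193, 3.022

print(f"S_PP ~ {pp / t_pp:.0f} t/s")   # table says 2658.58
print(f"S_TG ~ {tg / t_tg:.2f} t/s")   # table says 42.36
```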

Qwen3.5-9B BF16 performance for 2*3090 -sm layer

Details

main: n_kv_max = 65536, n_batch = 2048, n_ubatch = 512, flash_attn = 1, n_gpu_layers = 999, n_threads = 8, n_threads_batch = 8

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---:|---:|-----:|-------:|---------:|-------:|---------:|
| 512 | 128 | 0 | 0.173 | 2956.05 | 2.923 | 43.79 |
| 512 | 128 | 512 | 0.165 | 3105.12 | 2.931 | 43.67 |
| 512 | 128 | 1024 | 0.164 | 3116.23 | 2.935 | 43.61 |
| 512 | 128 | 1536 | 0.162 | 3156.50 | 2.941 | 43.52 |
| 512 | 128 | 2048 | 0.160 | 3194.97 | 2.949 | 43.41 |
| 512 | 128 | 2560 | 0.161 | 3179.53 | 2.956 | 43.30 |
| 512 | 128 | 3072 | 0.162 | 3169.00 | 2.964 | 43.18 |
| 512 | 128 | 3584 | 0.162 | 3160.65 | 2.974 | 43.04 |
| 512 | 128 | 4096 | 0.163 | 3141.22 | 2.982 | 42.92 |
| 512 | 128 | 4608 | 0.163 | 3146.16 | 2.990 | 42.81 |
| 512 | 128 | 5120 | 0.164 | 3121.23 | 3.008 | 42.56 |
| 512 | 128 | 5632 | 0.165 | 3102.69 | 3.017 | 42.43 |
| 512 | 128 | 6144 | 0.165 | 3103.35 | 3.023 | 42.34 |
| 512 | 128 | 6656 | 0.166 | 3087.28 | 3.031 | 42.23 |
| 512 | 128 | 7168 | 0.165 | 3096.59 | 3.036 | 42.16 |
| 512 | 128 | 7680 | 0.166 | 3077.72 | 3.041 | 42.09 |
| 512 | 128 | 8192 | 0.167 | 3064.53 | 3.049 | 41.98 |
| 512 | 128 | 8704 | 0.168 | 3046.28 | 3.057 | 41.88 |
| 512 | 128 | 9216 | 0.168 | 3042.71 | 3.065 | 41.76 |
| 512 | 128 | 9728 | 0.169 | 3028.49 | 3.074 | 41.64 |
| 512 | 128 | 10240 | 0.169 | 3026.31 | 3.081 | 41.55 |
| 512 | 128 | 10752 | 0.170 | 3004.71 | 3.101 | 41.28 |

Qwen3.5-9B BF16 performance for 1*3090

Details

main: n_kv_max = 65536, n_batch = 2048, n_ubatch = 512, flash_attn = 1, n_gpu_layers = 999, n_threads = 8, n_threads_batch = 8

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---:|---:|-----:|-------:|---------:|-------:|---------:|
| 512 | 128 | 0 | 0.172 | 2982.45 | 2.936 | 43.59 |
| 512 | 128 | 512 | 0.165 | 3104.93 | 2.944 | 43.48 |
| 512 | 128 | 1024 | 0.166 | 3090.78 | 2.949 | 43.41 |
| 512 | 128 | 1536 | 0.161 | 3189.97 | 2.960 | 43.24 |
| 512 | 128 | 2048 | 0.161 | 3174.13 | 2.966 | 43.15 |
| 512 | 128 | 2560 | 0.162 | 3161.35 | 2.975 | 43.03 |
| 512 | 128 | 3072 | 0.162 | 3151.47 | 2.983 | 42.91 |
| 512 | 128 | 3584 | 0.163 | 3138.58 | 2.993 | 42.77 |

Note: sglang is still outputting 11111111..., so the performance numbers may not be accurate.

Nexesenex added a commit to Nexesenex/ik_llama.cpp.nxs that referenced this pull request Mar 10, 2026
@Ph0rk0z

Ph0rk0z commented Mar 10, 2026

What does it have to do with sillytavern? I ran sweep benches before/after with the same command line as previously. I am still getting the same t/s for Devstral Q4_K, and somehow it's higher than yours at 37 t/s. If it's not the Q8 cache for Qwen then I don't know, but that's how it crashes on me, and it seems to have plenty of memory.

@Ph0rk0z

Ph0rk0z commented Mar 10, 2026

Main branch:

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---:|---:|-----:|-------:|---------:|-------:|---------:|
| 1024 | 256 | 0 | 2.854 | 358.85 | 6.772 | 37.81 |
| 1024 | 256 | 1024 | 2.709 | 377.95 | 6.934 | 36.92 |
| 1024 | 256 | 2048 | 2.726 | 375.60 | 7.007 | 36.54 |
| 1024 | 256 | 3072 | 2.745 | 373.10 | 7.061 | 36.26 |
| 1024 | 256 | 4096 | 2.759 | 371.21 | 7.308 | 35.03 |
| 1024 | 256 | 5120 | 2.775 | 369.05 | 7.395 | 34.62 |
| 1024 | 256 | 6144 | 2.791 | 366.83 | 7.461 | 34.31 |
| 1024 | 256 | 7168 | 2.807 | 364.78 | 7.516 | 34.06 |
| 1024 | 256 | 8192 | 2.823 | 362.79 | 7.584 | 33.76 |

Prior version:

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---:|---:|-----:|-------:|---------:|-------:|---------:|
| 1024 | 256 | 0 | 1.670 | 613.09 | 8.269 | 30.96 |
| 1024 | 256 | 1024 | 1.520 | 673.85 | 8.424 | 30.39 |
| 1024 | 256 | 2048 | 1.537 | 666.18 | 8.510 | 30.08 |
| 1024 | 256 | 3072 | 1.553 | 659.55 | 8.565 | 29.89 |
| 1024 | 256 | 4096 | 1.570 | 652.09 | 8.817 | 29.03 |
| 1024 | 256 | 5120 | 1.586 | 645.84 | 8.908 | 28.74 |
| 1024 | 256 | 6144 | 1.602 | 639.37 | 8.974 | 28.53 |
| 1024 | 256 | 7168 | 1.617 | 633.22 | 9.029 | 28.35 |
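
A quick way to quantify the main-vs-prior tradeoff in the two tables above (values copied from the N_KV = 0 rows; rough arithmetic, not a rigorous comparison):

```python
# N_KV = 0 rows from the "Main branch" and "Prior version" tables above.
main_pp, prior_pp = 358.85, 613.09   # S_PP t/s
main_tg, prior_tg = 37.81, 30.96     # S_TG t/s

print(f"PP: {100 * (main_pp / prior_pp - 1):+.1f}%")  # main much slower at prefill
print(f"TG: {100 * (main_tg / prior_tg - 1):+.1f}%")  # main faster at generation
```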

I can check the other commits in between and make sure it's this one.

It actually started with this one: 14492bf

@ikawrakow
Owner Author

@Ph0rk0z

Do you see this in your output

ggml_backend_cuda_context: a context for device 0 already exists?

This shouldn't be there. What exactly did you do to make it appear?

@Ph0rk0z

Ph0rk0z commented Mar 10, 2026

I didn't do anything:

CUDA_VISIBLE_DEVICES=0,1,2,3 numactl --interleave=all ./bin/llama-sweep-bench \
    -m Devstral-2-123B-Instruct-2512-GGUF-UD-Q4_K_XL/Devstral-2-123B-Instruct-2512-UD-Q4_K_XL-00001-of-00002.gguf \
    -t 48 \
    -c 81920 \
    -ts 24,24,24,21 \
    --numa distribute \
    --host 192.168.1.211 \
    -ngl 99 \
    -ctk q8_0 \
    -ctv q8_0 \
    -fa 1 \
    -ub 1024 \
    -sm graph \
    -gr \
    -grt bf16 \
    --no-mmap \
    -cuda enable-p2p=1,fusion=1

Loading the qwen as well to see if it still crashes in graph.

@ikawrakow
Owner Author

What happens if you remove the --no-mmap?

@ubergarm
Contributor

ubergarm commented Mar 10, 2026

@ikawrakow

Yes, omitting --no-mmap fixes the issue I observed over here: #1392 (comment)

I was confused and thought this was the other PR discussion, hah. When using --no-mmap I now also see:

 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 0->1
ggml_backend_cuda_context: a context for device 1 already exists?
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 1->0

@Ph0rk0z

Ph0rk0z commented Mar 10, 2026

Finally loaded the 397b with the vision commit reverted. I use --no-mmap because, although loading is slower, output is faster.

It's kinda crazy that this gives this much extra t/s but cuts the prompt processing in half. Maybe there's a way to have both :P
I didn't check the coherence of the models though.

With the commit reverted I have:


 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 0->3
.........~ggml_backend_cuda_context: have 0 graphs
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 1->0
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 1->2
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 1->3
..........~ggml_backend_cuda_context: have 0 graphs
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 2->0
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 2->1
 =========================== ggml_cuda_set_peer_access: Enabling Peer Access between Devices 2->3
.........~ggml_backend_cuda_context: have 0 graphs

And briefly

 current device: 3, in function alloc at /home/supermicro/ai/ik_llama.cpp/ggml/src/ggml-cuda.cu:442
  cuMemCreate(&handle, reserve_size, &prop, 0)

graph mode doesn't crash:

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---:|---:|-----:|-------:|---------:|-------:|---------:|
| 2048 | 512 | 0 | 11.381 | 179.94 | 25.017 | 20.47 |

With the new commit, Mistral is back:

| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s |
|---:|---:|-----:|-------:|---------:|-------:|---------:|
| 1024 | 256 | 0 | 1.646 | 621.95 | 8.264 | 30.98 |
