Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit. Hold shift + click to select a range.
a7df116
qwen3next: add architecture support and recurrent-state fixes
Feb 6, 2026
9fbb504
qwen3next: optimize broadcast sub and single-seq ssm conv
Feb 6, 2026
89e9ecf
cuda: build MoE row mapping on device in mul_mat_id
Feb 6, 2026
236633a
cuda: add guarded multi-seq fast path for ssm_conv
Feb 6, 2026
c767cfa
docs: update qwen3next perf report for cuda MoE/SSM tuning
Feb 6, 2026
e64b433
cuda: reduce qwen3next moe/ssm sync overhead and refresh eval
Feb 6, 2026
6db8dc8
qwen3next: split cpu/cuda eval builds and tune PP scheduling
Feb 7, 2026
fffd27e
qwen3next: harden seq-state flow and support optional dense FFN layers
Feb 7, 2026
a1163d0
qwen3next: trim delta-net graph overhead in chunking path
Feb 7, 2026
0e3891b
qwen3next: remove redundant v_conv cont in delta path
Feb 7, 2026
43edfa2
qwen3next: avoid extra cont on linear attention output
Feb 7, 2026
de5bf44
qwen3next: drop redundant cont before recurrent state flatten
Feb 7, 2026
5a6c4e8
qwen3next: keep recurrent state in 4d layout through delta path
Feb 7, 2026
6dd990d
qwen3next: add fused delta-net op and wire model path
Feb 7, 2026
ed0565f
tests: add backend-op coverage for ggml_delta_net
Feb 7, 2026
b33cef6
qwen3next: add runtime switch for fused delta-net path
Feb 8, 2026
81e788e
docs: refresh qwen3next perf review and benchmark matrix
Feb 8, 2026
9930f4d
qwen3next: default fused delta-net off and document quality checks
Feb 8, 2026
143e88a
qwen3next: add decode-only fused delta mode
Feb 8, 2026
64099e7
qwen3next: make fused delta safe by default and fix fused tensor layout
Feb 8, 2026
343e335
qwen3next: warn when forcing fused decode mode
Feb 8, 2026
44db394
qwen3next: add fused-delta regression runner script
Feb 8, 2026
55270b0
qwen3next: integrate fused regression into eval harness
Feb 8, 2026
670434e
qwen3next: clean up chunked delta-net shape handling
Feb 8, 2026
691df60
qwen3next: add absolute sanity guards to fused regression
Feb 8, 2026
a822db6
qwen3next: add unified regression runner script
Feb 8, 2026
627d469
qwen3next: disable flash-attn for cpu-only contexts
Feb 8, 2026
bd0dd78
docs: reconcile qwen3next status and remaining upstream gaps
Feb 8, 2026
b5c9554
common: add qwen3next fused-delta runtime flag
Feb 8, 2026
eef360a
cuda: add qwen3next delta-net kernel dispatch override
Feb 8, 2026
69529d3
docs: update qwen3next quality and serving baseline findings
Feb 8, 2026
48e0e35
qwen3next: keep fused delta on safe path and remove PR artifacts
Feb 9, 2026
9241164
qwen3next: align autoregressive delta-net decode layout
Feb 9, 2026
6009557
Revert "qwen3next: align autoregressive delta-net decode layout"
Feb 9, 2026
113ad6c
cuda: port solve-tri fast-paths for qwen3next delta-net
Feb 9, 2026
6f21f24
qwen3next: add fused-delta runtime flag and drop env toggle
Feb 9, 2026
f1f6da7
qwen3next: make fused delta single-flag and default on
Feb 9, 2026
4ab02c9
Account for GPU arch differences
Feb 10, 2026
117ff5d
Revert "cuda: build MoE row mapping on device in mul_mat_id"
Feb 10, 2026
6d8fb70
qwen3next: drop non-essential MoE scheduling and split heuristics
Feb 10, 2026
ed10c94
qwen3next: avoid generic ggml_sub broadcast changes
Feb 10, 2026
4e55ac7
llama: restore only_active_experts log message
Feb 10, 2026
71035bf
Merge branch 'ikawrakow:main' into main
YurkoHoshko Feb 10, 2026
012377b
Remove unnecessary hacks, disable fusion for now.
Feb 11, 2026
b7781f2
qwen3next: port hybrid recurrent state memory semantics
Feb 11, 2026
d7b6358
qwen3next: clean up recurrent state slot plumbing
Feb 11, 2026
aaa1b12
qwen3next: fix hybrid V-cache layout plumbing
Feb 11, 2026
cac3c5f
qwen3next: guard recurrent state slots against kv capacity
Feb 11, 2026
c771416
qwen3next: persist recurrent state in session data
Feb 11, 2026
dd690cb
qwen3next: drop unused fused-delta builder path
Feb 11, 2026
3470e8a
qwen3next: remove unused fused-delta CLI/context plumbing
Feb 11, 2026
cb99ab7
ggml: remove unused DELTA_NET operator stack
Feb 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <climits>
#include <cmath>
#include <codecvt>
#include <cstdlib>
#include <cstdarg>
#include <cstring>
#include <ctime>
Expand Down Expand Up @@ -487,6 +488,7 @@ void gpt_params_parse_from_env(gpt_params & params) {
get_env("LLAMA_ARG_CONT_BATCHING", params.cont_batching);
get_env("LLAMA_ARG_HOST", params.hostname);
get_env("LLAMA_ARG_PORT", params.port);

}

bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
Expand Down
80 changes: 80 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ extern "C" {
GGML_OP_LOG,
GGML_OP_SUM,
GGML_OP_SUM_ROWS,
GGML_OP_CUMSUM,
GGML_OP_MEAN,
GGML_OP_ARGMAX,
GGML_OP_REPEAT,
Expand All @@ -611,6 +612,7 @@ extern "C" {
GGML_OP_RMS_NORM,
GGML_OP_RMS_NORM_BACK,
GGML_OP_GROUP_NORM,
GGML_OP_L2_NORM,
GGML_OP_FUSED_RMS_NORM,
GGML_OP_FUSED_MUL_UNARY,
GGML_OP_MULTI_ADD,
Expand Down Expand Up @@ -653,6 +655,8 @@ extern "C" {
GGML_OP_PAD,
GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_TRI,
GGML_OP_FILL,
GGML_OP_ARGSORT,
GGML_OP_ARGSORT_THRESH,
GGML_OP_GROUPED_TOPK,
Expand All @@ -671,6 +675,7 @@ extern "C" {
GGML_OP_WIN_UNPART,
GGML_OP_GET_REL_POS,
GGML_OP_ADD_REL_POS,
GGML_OP_SOLVE_TRI,
GGML_OP_UNARY,

GGML_OP_MAP_UNARY,
Expand Down Expand Up @@ -710,6 +715,8 @@ extern "C" {
GGML_UNARY_OP_SILU,
GGML_UNARY_OP_HARDSWISH,
GGML_UNARY_OP_HARDSIGMOID,
GGML_UNARY_OP_EXP,
GGML_UNARY_OP_SOFTPLUS,
GGML_UNARY_OP_SWIGLU,
GGML_UNARY_OP_SWIGLU_OAI,
GGML_UNARY_OP_GELU,
Expand Down Expand Up @@ -739,6 +746,13 @@ extern "C" {
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
};

// selects which triangular half ggml_tri() keeps
// (values outside the selected half are zeroed)
enum ggml_tri_type {
    GGML_TRI_TYPE_LOWER,      // keep the lower triangle
    GGML_TRI_TYPE_UPPER,      // keep the upper triangle
    GGML_TRI_TYPE_LOWER_DIAG, // NOTE(review): presumably lower triangle including the main diagonal — confirm vs GGML_TRI_TYPE_LOWER
    GGML_TRI_TYPE_UPPER_DIAG, // NOTE(review): presumably upper triangle including the main diagonal — confirm vs GGML_TRI_TYPE_UPPER
};

// ggml object
struct ggml_object {
size_t offs;
Expand Down Expand Up @@ -1189,6 +1203,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

// element-wise softplus activation
// NOTE(review): presumably computes log(1 + exp(a)) per element — confirm against the backend kernel
GGML_API struct ggml_tensor * ggml_softplus(
        struct ggml_context * ctx,
        struct ggml_tensor * a);

// in-place variant of ggml_softplus (follows the other *_inplace builders)
GGML_API struct ggml_tensor * ggml_softplus_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor * a);

// return scalar
GGML_API struct ggml_tensor * ggml_sum(
struct ggml_context * ctx,
Expand All @@ -1199,6 +1221,10 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

// cumulative sum
// NOTE(review): the axis is not visible in this header — presumably along
// rows (ne[0]), by analogy with ggml_sum_rows; confirm against the CPU kernel
GGML_API struct ggml_tensor * ggml_cumsum(
        struct ggml_context * ctx,
        struct ggml_tensor * a);

// mean along rows
GGML_API struct ggml_tensor * ggml_mean(
struct ggml_context * ctx,
Expand All @@ -1217,6 +1243,15 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);

// repeat a to specified shape
// ne0..ne3 are the target dimensions of the result tensor
// NOTE(review): presumably each target dim must be a multiple of the
// corresponding dim of a (as with ggml_repeat) — confirm
GGML_API struct ggml_tensor * ggml_repeat_4d(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        int64_t ne0,
        int64_t ne1,
        int64_t ne2,
        int64_t ne3);

// sums repetitions in a into shape of b
GGML_API struct ggml_tensor * ggml_repeat_back(
struct ggml_context * ctx,
Expand Down Expand Up @@ -1455,6 +1490,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

// element-wise exponential
// NOTE(review): presumably lowers to GGML_UNARY_OP_EXP (the CUDA backend
// dispatches that op to ggml_cuda_op_exp) — confirm in ggml.c
GGML_API struct ggml_tensor * ggml_exp(
        struct ggml_context * ctx,
        struct ggml_tensor * a);

// in-place variant of ggml_exp (follows the other *_inplace builders)
GGML_API struct ggml_tensor * ggml_exp_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor * a);

// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
Expand Down Expand Up @@ -1514,6 +1557,17 @@ extern "C" {
int n_groups,
float eps);

// l2 normalize along rows
// eps: small constant, presumably guarding against division by zero on
//      all-zero rows — NOTE(review): exact use of eps not visible here, confirm in kernel
GGML_API struct ggml_tensor * ggml_l2_norm(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        float eps);

// in-place variant of ggml_l2_norm
GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        float eps);

// a - x
// b - dy
GGML_API struct ggml_tensor * ggml_rms_norm_back(
Expand Down Expand Up @@ -2283,6 +2337,23 @@ extern "C" {
int dim,
int max_period);

// convert matrix to triangular form by zeroing values outside selected half
// type: which half to keep (see enum ggml_tri_type)
GGML_API struct ggml_tensor * ggml_tri(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        enum ggml_tri_type type);

// fill tensor with constant c
// (returns a tensor of the same shape as a with every element set to c)
GGML_API struct ggml_tensor * ggml_fill(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        float c);

// in-place variant of ggml_fill (overwrites a)
GGML_API struct ggml_tensor * ggml_fill_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        float c);

// sort rows
enum ggml_sort_order {
GGML_SORT_ORDER_ASC,
Expand Down Expand Up @@ -2426,6 +2497,15 @@ extern "C" {
struct ggml_tensor * pw,
struct ggml_tensor * ph);

// Solve Ax = B where A is triangular
// a must be square per batch (the CUDA backend requires ne[0] == ne[1] and
// matching batch dims between a and b)
// left:  NOTE(review): presumably selects A·x = B (true) vs x·A = B (false) — confirm
// lower: NOTE(review): presumably a is lower-triangular when true, upper when false — confirm
// uni:   NOTE(review): presumably a has an implicit unit diagonal — confirm
GGML_API struct ggml_tensor * ggml_solve_tri(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        bool left,
        bool lower,
        bool uni);

// custom operators

typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
Expand Down
75 changes: 69 additions & 6 deletions ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/convert.cuh"
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cumsum.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/dmmv.cuh"
#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/fill.cuh"
#include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmq.cuh"
Expand All @@ -46,10 +48,13 @@
#include "ggml-cuda/conv2d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/solve_tri.cuh"
#include "ggml-cuda/ssm-conv.cuh"
#include "ggml-cuda/argmax.cuh"
#include "ggml-cuda/multiadd.cuh"
#include "ggml-cuda/hadamard.cuh"
#include "ggml-cuda/reduce.cuh"
#include "ggml-cuda/tri.cuh"

#include <algorithm>
#include <array>
Expand Down Expand Up @@ -2822,11 +2827,6 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
return i;
}

std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));

ggml_tensor src0_1_row = *src0_1;
ggml_tensor src0_2_row; if (src0_2) src0_2_row = *src0_2;
ggml_tensor src1_row = *src1;
Expand Down Expand Up @@ -2917,7 +2917,10 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = moe_counts[i02];

if (num_src1_rows == 0) continue;
if (num_src1_rows == 0) {
continue;
}

size_t mapping_offset = cum_moe_counts[i02];

if (use_quantized_src1) {
Expand Down Expand Up @@ -3305,6 +3308,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_HARDSWISH:
ggml_cuda_op_hardswish(ctx, dst);
break;
case GGML_UNARY_OP_EXP:
ggml_cuda_op_exp(ctx, dst);
break;
case GGML_UNARY_OP_SOFTPLUS:
ggml_cuda_op_softplus(ctx, dst);
break;
default:
return -1;
}
Expand Down Expand Up @@ -3339,6 +3348,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GROUP_NORM:
ggml_cuda_op_group_norm(ctx, dst);
break;
case GGML_OP_L2_NORM:
ggml_cuda_op_l2_norm(ctx, dst);
break;
case GGML_OP_CONCAT:
if (fusion && i + 2 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
Expand Down Expand Up @@ -3554,6 +3566,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_sum_rows(ctx, dst);
}
break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_ARGSORT:
if (fusion && i + 5 < cgraph->n_nodes &&
cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
Expand All @@ -3573,6 +3588,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GROUPED_TOPK:
ggml_cuda_op_grouped_topk(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
case GGML_OP_TRI:
ggml_cuda_op_tri(ctx, dst);
break;
case GGML_OP_FILL:
ggml_cuda_op_fill(ctx, dst);
break;
case GGML_OP_SOLVE_TRI:
ggml_cuda_op_solve_tri(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_cuda_flash_attn_ext(ctx, dst);
break;
Expand Down Expand Up @@ -4149,6 +4176,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
return ggml_is_contiguous(op->src[0]);
default:
return false;
Expand Down Expand Up @@ -4342,6 +4371,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
return true;
case GGML_OP_L2_NORM:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
break;
Expand Down Expand Up @@ -4389,6 +4420,38 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
return true;
case GGML_OP_CUMSUM:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_TRI:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->src[0]->type == op->type;
case GGML_OP_FILL:
return ggml_is_contiguous(op) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
case GGML_OP_SOLVE_TRI:
return ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op) &&
op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
op->src[0]->ne[0] == op->src[0]->ne[1] &&
op->src[0]->ne[1] == op->src[1]->ne[1] &&
op->src[0]->ne[2] == op->src[1]->ne[2] &&
op->src[0]->ne[3] == op->src[1]->ne[3];
case GGML_OP_SSM_CONV:
return op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_F32 &&
op->src[2]->type == GGML_TYPE_F32 &&
op->src[3]->type == GGML_TYPE_I32 &&
op->type == GGML_TYPE_F32 &&
op->src[0]->nb[0] == sizeof(float) &&
op->src[1]->nb[0] == sizeof(float) &&
op->src[2]->nb[0] == sizeof(float) &&
op->src[3]->nb[0] == sizeof(int32_t) &&
op->src[2]->ne[0] == op->src[0]->ne[0] + 1 &&
op->src[2]->ne[1] == op->src[0]->ne[1] &&
op->src[1]->ne[0] == op->src[0]->ne[1] &&
op->src[3]->ne[0] == op->src[0]->ne[2];
case GGML_OP_FLASH_ATTN_EXT:
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
Expand Down
Loading