Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 231 additions & 0 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2217,5 +2217,236 @@
}
}));

constexpr const char* CausalConvWithState_ver1_doc = R"DOC(
Stateful causal depthwise convolution, generalized to N spatial dimensions.

Used by Gated DeltaNet (Qwen3.5) and Mamba (Jamba, FalconMamba) as a preprocessing step.
Replaces the 3-op pattern (Concat + Conv + Slice) with a single fused operation.

The convolution is causal (looks only at current and past positions along the last
spatial dimension) and depthwise (each channel is convolved independently with its own kernel).

Input layout is channels-first: (batch_size, channels, ...).
Weight layout: (channels, 1, k_1, ...) for depthwise convolution.
The carry state stores the last (k-1) positions along the causal axis for incremental decode.

The ndim attribute generalizes the op to 1D, 2D, or 3D spatial dimensions. Causality is
enforced on the last spatial dimension only.
Comment thread
guschmue marked this conversation as resolved.

The optional activation attribute supports fused SiLU/Swish activation.
)DOC";

ONNX_MS_OPERATOR_SET_SCHEMA(
CausalConvWithState, 1,
OpSchema()
.SetDoc(CausalConvWithState_ver1_doc)
.Attr("activation",
"Fused activation function. One of: 'silu', 'swish', 'none'. "
"Default is 'none'.",
AttributeProto::STRING,
std::string("none"))
.Attr("ndim",
"Spatial dimensionality: 1, 2, or 3. Default is 1.",
AttributeProto::INT,
static_cast<int64_t>(1))
.Input(0,
"input",
"Input tensor with shape (batch_size, channels, ...). Channels-first layout. "
"Spatial dims: 1D: (L,); 2D: (H, W); 3D: (D, H, W).",
"T")
.Input(1,
"weight",
"Depthwise convolution kernel with shape (channels, 1, k_1, ...). "
"Spatial kernel sizes: (k_1, ..., k_ndim).",
"T")
.Input(2,
"bias",
"Optional per-channel bias with shape (channels).",
"T",
OpSchema::Optional)
.Input(3,
"past_state",
"Carry state from previous step. For ndim=1: (batch_size, channels, k_1 - 1). "
"If not provided, padding is zero.",
"T",
OpSchema::Optional)
.Output(0,
"output",
"Convolution output with same shape as input.",
"T")
.Output(1,
"present_state",
"Updated carry state. For ndim=1: (batch_size, channels, k_1 - 1). "
"Contains the last (k-1) values from the virtual input along the causal axis.",
"T")
.TypeConstraint("T",
{"tensor(float)", "tensor(float16)"},
Comment thread
guschmue marked this conversation as resolved.
Outdated
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
propagateElemTypeFromInputToOutput(ctx, 0, 1);

// Output 0: same shape as input (batch_size, channels, ...)
propagateShapeFromInputToOutput(ctx, 0, 0);

// Output 1: (batch_size, channels, kernel_size - 1) for ndim=1
if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) {
auto& input_shape = getInputShape(ctx, 0);
auto& weight_shape = getInputShape(ctx, 1);
TensorShapeProto state_shape;
*state_shape.add_dim() = input_shape.dim(0); // batch_size
*state_shape.add_dim() = input_shape.dim(1); // channels
// kernel_size - 1 (last kernel dimension for ndim=1)
Comment thread
guschmue marked this conversation as resolved.
Outdated
int last_kernel_dim = weight_shape.dim_size() - 1;
if (weight_shape.dim(last_kernel_dim).has_dim_value()) {
state_shape.add_dim()->set_dim_value(weight_shape.dim(last_kernel_dim).dim_value() - 1);
} else {
state_shape.add_dim(); // unknown
}
updateOutputShape(ctx, 1, state_shape);
}
}));

constexpr const char* LinearAttention_ver1_doc = R"DOC(
Unified linear attention operator for autoregressive decoding (T=1) and prefill (T>1).

All inputs use 3D packed format [B, T, H*D]; q_num_heads and kv_num_heads are always
required. The op internally unpacks to 4D for computation.

The update_rule attribute selects the recurrence type:
- "linear": S_t = S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t
- "gated": S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t
- "delta": S_t = S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}^T k_t); o_t = scale * q_t^T S_t
- "gated_delta": S_t = exp(g_t) * S_{t-1} + β_t * k_t ⊗ (v_t - exp(g_t) * S_{t-1}^T k_t); o_t = scale * q_t^T S_t

where g_t is the decay (in log-space), β_t is the update rate, and ⊗ denotes outer product.

Semantics: Equivalent to running the recurrent update sequentially for each token,
but may be implemented using chunk-parallel algorithms for GPU efficiency.
)DOC";

ONNX_MS_OPERATOR_SET_SCHEMA(
LinearAttention, 1,
OpSchema()
.SetDoc(LinearAttention_ver1_doc)
.Attr("update_rule",
"The update rule for the linear attention recurrence. "
"One of: 'linear', 'gated', 'delta', 'gated_delta'. Default is 'gated_delta'.",
AttributeProto::STRING,
std::string("gated_delta"))

Check warning on line 2336 in onnxruntime/core/graph/contrib_ops/bert_defs.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/graph/contrib_ops/bert_defs.cc:2336: Add #include <string> for string [build/include_what_you_use] [4]
Comment thread
guschmue marked this conversation as resolved.
.Attr("scale",
"Output scaling factor. When 0.0 (default), derives d_k = query.shape[-1] / q_num_heads "
"and uses 1/sqrt(d_k). Set explicitly to override.",
AttributeProto::FLOAT,
0.0f)
.Attr("q_num_heads",
"Number of query heads. Always required.",
AttributeProto::INT)
.Attr("kv_num_heads",
"Number of key/value heads. Always required.",
AttributeProto::INT)
.Attr("chunk_size",
"Chunk size for the chunk-parallel WY decomposition during prefill (T>1). "
"Tuning hint; does not affect output correctness.",
AttributeProto::INT,
static_cast<int64_t>(64))
.Input(0,
"query",
"Query vectors with 3D packed shape (B, T, H_q * d_k). "
"Heads are packed into the last dimension.",
"T")
.Input(1,
"key",
"Key vectors with 3D packed shape (B, T, H_kv * d_k). "
"Should be L2-normalized for delta/gated_delta modes.",
"T")
.Input(2,
"value",
"Value vectors with 3D packed shape (B, T, H_kv * d_v).",
"T")
.Input(3,
"past_state",
"Recurrent state from previous step with shape (B, H_kv, d_k, d_v). "
"Always 4D. If not provided, defaults to zeros.",
"T",
OpSchema::Optional)
.Input(4,
"decay",
"Exponential decay gate in log-space. 3D packed shape: "
"(B, T, H_kv * d_k) for per-key-dimension decay (GLA/RWKV-6), or "
"(B, T, H_kv) for per-head scalar decay (DeltaNet/RetNet). "
"Required for 'gated' and 'gated_delta' modes.",
"T",
OpSchema::Optional)
.Input(5,
"beta",
"Update rate (sigmoid output). 3D packed shape: "
"(B, T, H_kv) or (B, T, 1). "
"Required for 'delta' and 'gated_delta' modes.",
"T",
OpSchema::Optional)
.Output(0,
"output",
"Attention output with 3D packed shape (B, T, H_q * d_v).",
"T")
.Output(1,
"present_state",
"Updated recurrent state with shape (B, H_kv, d_k, d_v). Always 4D.",
"T")
.TypeConstraint("T",
{"tensor(float)", "tensor(float16)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
propagateElemTypeFromInputToOutput(ctx, 0, 1);

// Read required attributes
auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
int64_t q_num_heads = (q_num_heads_attr && q_num_heads_attr->has_i()) ? q_num_heads_attr->i() : 0;
int64_t kv_num_heads = (kv_num_heads_attr && kv_num_heads_attr->has_i()) ? kv_num_heads_attr->i() : 0;

// Output 0: (B, T, H_q * d_v) — 3D packed
if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2) && q_num_heads > 0 && kv_num_heads > 0) {
auto& query_shape = getInputShape(ctx, 0);
auto& value_shape = getInputShape(ctx, 2);
TensorShapeProto output_shape;
*output_shape.add_dim() = query_shape.dim(0); // B
*output_shape.add_dim() = query_shape.dim(1); // T
// H_q * d_v: d_v = value.dim(2) / kv_num_heads, then H_q * d_v
if (value_shape.dim(2).has_dim_value()) {
int64_t d_v = value_shape.dim(2).dim_value() / kv_num_heads;
output_shape.add_dim()->set_dim_value(q_num_heads * d_v);
Comment thread
guschmue marked this conversation as resolved.
Outdated
} else {
output_shape.add_dim(); // unknown
}
updateOutputShape(ctx, 0, output_shape);
}

// Output 1: present_state shape (B, H_kv, d_k, d_v) — 4D
if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2) && q_num_heads > 0 && kv_num_heads > 0) {
auto& query_shape = getInputShape(ctx, 0);
auto& value_shape = getInputShape(ctx, 2);
TensorShapeProto state_shape;
*state_shape.add_dim() = query_shape.dim(0); // B
state_shape.add_dim()->set_dim_value(kv_num_heads); // H_kv
// d_k = query.dim(2) / q_num_heads
if (query_shape.dim(2).has_dim_value()) {
state_shape.add_dim()->set_dim_value(query_shape.dim(2).dim_value() / q_num_heads);
} else {
state_shape.add_dim();
}
// d_v = value.dim(2) / kv_num_heads
if (value_shape.dim(2).has_dim_value()) {
state_shape.add_dim()->set_dim_value(value_shape.dim(2).dim_value() / kv_num_heads);
} else {
state_shape.add_dim();
}
updateOutputShape(ctx, 1, state_shape);
} else if (hasInputShape(ctx, 3)) {
propagateShapeFromInputToOutput(ctx, 3, 1);
}
}));

} // namespace contrib
} // namespace onnxruntime
4 changes: 4 additions & 0 deletions onnxruntime/core/graph/contrib_ops/ms_opset.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QMoE);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, PagedAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinearAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CausalConvWithState);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad);
Expand Down Expand Up @@ -199,6 +201,8 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, PagedAttention)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinearAttention)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CausalConvWithState)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad)>());
Expand Down
Loading
Loading