Add optional router_weights input to QMoE for separate selection/aggregation routing#27687

Merged
tianleiwu merged 6 commits into main from copilot/feature-support-noaux-tc-routing
Mar 19, 2026
Conversation

Contributor

Copilot AI commented Mar 16, 2026

Description

Adds optional input router_weights (index 14) to com.microsoft.QMoE to decouple Top-K expert selection from output aggregation weighting.

When router_weights is provided:

  • router_probs → Top-K expert selection only
  • router_weights → values gathered at selected expert indices used as mixing weights

When omitted, existing softmax-of-router_probs behavior is preserved (backward compatible).
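The two modes can be sketched as a small NumPy reference. This is an illustration only, not ORT code: the helper name `qmoe_routing`, its signature, and the argsort-based Top-K are assumptions made for the sketch.

```python
import numpy as np

def qmoe_routing(router_probs, top_k, router_weights=None, normalize=False):
    # Hypothetical reference helper, not an ONNX Runtime API.
    # Top-K expert selection always comes from router_probs.
    topk_idx = np.argsort(-router_probs, axis=-1)[:, :top_k]
    if router_weights is None:
        # Legacy path: softmax over router_probs, gathered at the Top-K indices.
        e = np.exp(router_probs - router_probs.max(axis=-1, keepdims=True))
        probs = e / e.sum(axis=-1, keepdims=True)
        weights = np.take_along_axis(probs, topk_idx, axis=-1)
    else:
        # New path: mixing weights gathered from the separate router_weights tensor.
        weights = np.take_along_axis(router_weights, topk_idx, axis=-1)
    if normalize:
        # Mirrors normalize_routing_weights semantics in spirit.
        weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-20)
    return topk_idx, weights
```

Note that in both paths the same indices are produced; only the source of the mixing weights changes.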

Changes:

  • Schema (contrib_defs.cc): New optional input 14 router_weights, type T, shape (num_tokens, num_experts)
  • CPU provider (moe_quantization_cpu.cc): Implements the separate routing path with MLFloat16/float support and optional normalize_routing_weights normalization
  • CUDA provider (moe_quantization.cc): Reads the new input and returns a not-implemented error if it is provided
  • WebGPU provider (qmoe.cc): Same not-implemented guard
  • Tests (moe_test.cc): QMoETest_CPU_RouterWeights covering both normalized and unnormalized paths with non-zero expected outputs via FC2 bias to validate correct aggregation weights
  • Docs (OperatorKernels.md): Updated CPU and CUDA entries

This pattern matches DeepSeek-V2/V3/R1 routing where sigmoid(logits) is used for aggregation while logits + bias with group masking drives selection:

# DeepSeek-style: different tensors for selection vs aggregation
topk_indices = torch.topk(scores_for_choice, k=top_k)[1]  # selection from modified logits
topk_weights = router_logits.gather(1, topk_indices)        # aggregation from original sigmoid

Motivation and Context

QMoE previously required the same tensor for both routing and weighting, blocking DeepSeek-style noaux_tc MoE models where these are intentionally separate. This unblocks ONNX Runtime export/serving of DeepSeek-V2/V3/R1 MoE architectures.

Original prompt

This section details the original issue you should resolve

<issue_title>[Feature Request] Support noaux_tc MoE routing in com.microsoft.QMoE via separate router_weights</issue_title>
<issue_description>### Describe the feature request

com.microsoft.QMoE currently accepts a single routing tensor (commonly router_probs) that is used both for:

  • Top‑K expert selection (routing / dispatch), and
  • weighting the outputs of selected experts (aggregation).

This design makes it impossible to represent DeepSeek‑style noaux_tc MoE routing, where different tensors are intentionally used for:

  • expert selection (Top‑K routing), and
  • expert output weighting (mixing).

This issue proposes adding an optional input router_weights to com.microsoft.QMoE so that:

  • router_probs is used only for Top‑K selection, and
  • router_weights is used only for multiplying / aggregating expert outputs.

The change is backward compatible. It also allows any future methodology where different tensors are used for selection and aggregation.

Describe scenario use case

Enables exporting and serving DeepSeek‑V2/V3/R1‑style MoE models in ONNX Runtime</issue_description>

<agent_instructions>Please update operator spec and implement it in CPU provider. For CUDA provider, it is fine to throw not implemented exception for now.

Example Deepseek MoE script can be found in https://github.com/huggingface/transformers/blob/75c836b7853cb65f48ab2ce13cddfb12d14ecf5a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py like the following:

class DeepseekV3MoE(nn.Module):
    """
    A mixed expert module containing shared experts.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.experts = DeepseekV3NaiveMoe(config)
        self.gate = DeepseekV3TopkRouter(config)
        self.shared_experts = DeepseekV3MLP(
            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
        )
        self.n_routed_experts = config.n_routed_experts
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling_factor = config.routed_scaling_factor
        self.top_k = config.num_experts_per_tok

    def route_tokens_to_experts(self, router_logits):
        router_logits = router_logits.sigmoid()
        router_logits_for_choice = router_logits + self.gate.e_score_correction_bias
        group_scores = (
            router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(-1, self.n_routed_experts)
        )
        scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0)
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
        topk_weights = router_logits.gather(1, topk_indices)
        if self.norm_topk_prob:
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights

    def forward(self, hidden_states):
        residuals = hidden_states
        orig_shape = hidden_states.shape
        router_logits = self.gate(hidden_states)
        topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states

</agent_instructions>

Comments on the Issue (you are @copilot in this section)



Copilot AI and others added 2 commits March 16, 2026 22:11
…tion/aggregation

- Update QMoE schema with new optional input 14 (router_weights)
- Implement CPU provider support: when router_weights is provided, router_probs
  is used only for Top-K expert selection, and router_weights values at selected
  expert indices are used for output aggregation
- Add not-implemented guards for CUDA and WebGPU providers
- Update OperatorKernels.md documentation
- Add test case for QMoE with separate router_weights

Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com>
…on test case

Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com>
Copilot AI changed the title [WIP] [Feature Request] Support noaux_tc MoE routing in com.microsoft.QMoE Add optional router_weights input to QMoE for separate selection/aggregation routing Mar 16, 2026
Copilot AI requested a review from tianleiwu March 16, 2026 22:14
Contributor

@github-actions github-actions Bot left a comment

You can commit the suggested changes from lintrunner.

@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@tianleiwu tianleiwu marked this pull request as ready for review March 17, 2026 07:32
@tianleiwu
Contributor

@copilot, please investigate the test errors in the CI pipelines and fix them.
For example, the NaN output in MoETest.QMoETest_CPU_RouterWeights in the Windows CPU CI Pipeline / build_x86_release.

Without swiglu_fusion=1, the default (0) causes fc1_out_features to be
inter_size instead of 2*inter_size. The SwiGLU activation reads
2*inter_size values from the FC1 output buffer, causing out-of-bounds
reads and NaN output on some platforms (e.g., Windows x86).

Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com>
Contributor Author

Copilot AI commented Mar 18, 2026

@copilot, please investigate test errors in CI pipelines and fix it.
For example, like nan output in MoETest.QMoETest_CPU_RouterWeights in [Windows CPU CI Pipeline / build_x86_releas...

The NaN was caused by a missing swiglu_fusion=1 attribute in the test. With the default swiglu_fusion=0, fc1_out_features is computed as inter_size instead of 2*inter_size, but the SwiGLU activation reads 2*inter_size values from the FC1 output buffer — an out-of-bounds read that produces NaN on Windows x86.

Fixed in 200c402 by adding cpu_tester.AddAttribute<int64_t>("swiglu_fusion", static_cast<int64_t>(1)); to the test.
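For context on why the buffer size matters, a fused SwiGLU consumes both halves of the FC1 output. The sketch below assumes a simple concatenated gate/linear layout; the actual ORT kernel's memory layout may differ, and `swiglu` here is an illustrative helper, not ORT code.

```python
import numpy as np

def swiglu(fc1_out, inter_size):
    # Illustrative fused-SwiGLU read pattern (assumed split layout).
    # If fc1_out were only inter_size wide, the second half would be
    # an out-of-bounds read, which is the NaN failure mode described above.
    assert fc1_out.shape[-1] == 2 * inter_size, "fused SwiGLU needs 2*inter_size"
    gate = fc1_out[..., :inter_size]
    linear = fc1_out[..., inter_size:]
    silu = gate * (1.0 / (1.0 + np.exp(-gate)))  # SiLU(x) = x * sigmoid(x)
    return silu * linear
```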

Copilot AI requested a review from tianleiwu March 18, 2026 15:20
@tianleiwu tianleiwu requested a review from apsonawane March 18, 2026 21:29
@tianleiwu tianleiwu enabled auto-merge (squash) March 19, 2026 06:18
@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@tianleiwu tianleiwu merged commit 5bba357 into main Mar 19, 2026
89 checks passed
@tianleiwu tianleiwu deleted the copilot/feature-support-noaux-tc-routing branch March 19, 2026 08:01

Development

Successfully merging this pull request may close these issues.

[Feature Request] Support noaux_tc MoE routing in com.microsoft.QMoE via separate router_weights
