Skip to content

[Feature] Add token mask for DispatchGmmCombineDecode operator#5171

Merged
wangxiyuan merged 1 commit intovllm-project:mainfrom
wangqiankun13:add_mc2_mask
Dec 19, 2025
Merged

[Feature] Add token mask for DispatchGmmCombineDecode operator#5171
wangxiyuan merged 1 commit intovllm-project:mainfrom
wangqiankun13:add_mc2_mask

Conversation

@wangqiankun13
Copy link
Copy Markdown
Contributor

@wangqiankun13 wangqiankun13 commented Dec 18, 2025

What this PR does / why we need it?

In this PR, DispatchGmmCombineDecode add an optional input x_active_mask, with which
only token masked True will be dispatched and handle.

Does this PR introduce any user-facing change?

How was this patch tested?

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an x_active_mask to the dispatch_gmm_combine_decode operation, which is a valuable optimization to skip computations for inactive (e.g., padding) tokens. The changes are comprehensive, touching the operator definition, tiling logic, kernel implementation, and Python bindings. My review has identified a couple of critical issues in the kernel logic for calculating the number of active tokens that could lead to incorrect behavior, as well as a minor issue with a misleading error message in the tiling logic. Addressing these points will ensure the correctness and robustness of this new feature.

Comment on lines +422 to +423
SumParams params{1, axisBsAlignSize_, axisBS_};
Sum(sumOutTensor, tempTensor, params);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The srcStride parameter of SumParams is expected to be in units of elements, but axisBsAlignSize_ is a byte size. This is likely to cause an incorrect sum calculation for the active tokens. Using the simpler Sum overload that only takes the length should be correct here, as the data is contiguous.

        Sum(sumOutTensor, tempTensor, axisBS_);

Comment on lines +413 to +414
SumParams params{1, axisBsAlignSize_, axisBS_};
Sum(sumOutTensor, maskTmpTensor, params);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The srcStride parameter for SumParams is expected to be in units of elements, but axisBsAlignSize_ is a byte size. This will likely lead to incorrect results when calculating the sum of active tokens. A simpler and more correct approach would be to use the Sum overload that takes the length directly, as the data is contiguous.

    Sum(sumOutTensor, maskTmpTensor, axisBS_);

Comment on lines +136 to +138
OPS_ERR_IF(xActiveMaskDim0 != batchSize, OPS_LOG_E(nodeName,
"gmm2WeightScale Dim0 must be batchSize(%u), but current dim is %lu.", batchSize, xActiveMaskDim0),
return ge::GRAPH_FAILED);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The error message in this log appears to be a copy-paste error from another check. It refers to gmm2WeightScale Dim0 when it should be referring to xActiveMask Dim0. This could be misleading during debugging.

        OPS_ERR_IF(xActiveMaskDim0 != batchSize, OPS_LOG_E(nodeName,
                    "xActiveMask Dim0 must be batchSize(%u), but current dim is %lu.", batchSize, xActiveMaskDim0),
                    return ge::GRAPH_FAILED);

@github-actions
Copy link
Copy Markdown
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:‌‌

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewer and future developers understand.

If CI fails, you can run linting and testing checks locally according Contributing and Testing.

DispatchGmmCombineDecode supports one-dim x_active_mask, with which
only token masked True will be dispatched and handle.

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
@wangqiankun13 wangqiankun13 changed the title Add mc2 mask [Feature] Add token mask for DispatchGmmCombineDecode operator Dec 19, 2025
@wangxiyuan wangxiyuan merged commit 118b0ed into vllm-project:main Dec 19, 2025
25 checks passed
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Dec 19, 2025
…to eplb_refactor

* 'main' of https://github.com/vllm-project/vllm-ascend: (52 commits)
  [Doc]Add the user_guide doc file regarding fine-grained TP. (vllm-project#5084)
  [pref] qwen3_next add triton ops : fused_sigmoid_gating_delta_rule_update (vllm-project#4818)
  [Feature] Add token mask for DispatchGmmCombineDecode operator (vllm-project#5171)
  [CI] Improve CI (vllm-project#5078)
  [Refactor] remove some metadata variables in attention_v1. (vllm-project#5160)
  Add Qwen3-VL-235B-A22B-Instruct tutorials (vllm-project#5167)
  [Doc] Add a perf tune section (vllm-project#5127)
  [Image] Refactor image build (vllm-project#5175)
  [refactor] refactor weight trans nz and transpose (vllm-project#4878)
  [BugFix]Fix precision issue for LoRA feature (vllm-project#4141)
  【Doc】Deepseekv3.1/R1 doc enhancement (vllm-project#4827)
  support basic long_seq feature st (vllm-project#5140)
  [Bugfix] install trition for test_custom_op (vllm-project#5112)
  [2/N][Pangu][MoE] Remove Pangu Related Code (vllm-project#5130)
  [bugfix] Use FUSED_MC2 MoE comm path for the op `dispatch_ffn_combine` (vllm-project#5156)
  [BugFix] Fix top_p,top_k issue with EAGLE and add top_p,top_k in EAGLE e2e (vllm-project#5131)
  [Doc][P/D] Fix MooncakeConnector's name (vllm-project#5172)
  [Bugfix] Fix in_profile_run in mtp_proposer dummy_run (vllm-project#5165)
  [Doc] Refact benchmark doc (vllm-project#5173)
  [Nightly]  Avoid max_model_len being smaller than the decoder prompt to prevent single-node-accuray-tests from failing (vllm-project#5174)
  ...

Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
chenaoxuan pushed a commit to chenaoxuan/vllm-ascend that referenced this pull request Dec 20, 2025
…project#5171)

### What this PR does / why we need it?
In this PR, DispatchGmmCombineDecode add an optional input
x_active_mask, with which
only token masked True will be dispatched and handle.


- vLLM version: v0.12.0
- vLLM main:
vllm-project/vllm@ad32e3e

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
…project#5171)

### What this PR does / why we need it?
In this PR, DispatchGmmCombineDecode add an optional input
x_active_mask, with which
only token masked True will be dispatched and handle.

- vLLM version: v0.12.0
- vLLM main:
vllm-project/vllm@ad32e3e

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
…project#5171)

### What this PR does / why we need it?
In this PR, DispatchGmmCombineDecode add an optional input
x_active_mask, with which
only token masked True will be dispatched and handle.

- vLLM version: v0.12.0
- vLLM main:
vllm-project/vllm@ad32e3e

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants