
[Refactor][Triton] Move reject sample triton kernels into ops/triton #5324

Merged
wangxiyuan merged 2 commits into vllm-project:main from whx-sjtu:move_reject_triton
Dec 29, 2025

Conversation

@whx-sjtu
Collaborator

@whx-sjtu whx-sjtu commented Dec 24, 2025

What this PR does / why we need it?

This PR moves reject sample related triton kernels into ops/triton.

Does this PR introduce any user-facing change?

No

How was this patch tested?

CI passed with existing test.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the rejection sampling Triton kernels by moving them into a new, dedicated file vllm_ascend/ops/triton/reject_sample.py. This is a good change for code organization and modularity. The calling code in vllm_ascend/sample/rejection_sampler.py has been updated accordingly.

During the review, I identified two critical bugs within the newly added (but previously existing) Triton kernels. Both rejection_greedy_sample_triton and rejection_random_sample_kernel incorrectly use Python for loops with runtime-dependent bounds. This is not supported by Triton and will cause errors. I have provided detailed comments and code suggestions to resolve these issues by using tl.range with compile-time constant bounds and appropriate masking.

Comment on lines +109 to +119
    for i in range(num_tokens1):
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 + i)
            target_argmax_id = tl.load(target_argmax_ptr + start_idx1 + i)
            tl.store(
                output_token_ids_ptr + position * (max_spec_len + 1) + i,
                target_argmax_id,
            )
            if draft_token_id != target_argmax_id:
                # Reject.
                rejected = True
Contributor


critical

The loop for i in range(num_tokens1): uses a runtime variable num_tokens1 as its bound. In Triton, Python for loops are unrolled at compile time and require their bounds to be compile-time constants. Using a runtime variable here is incorrect and can lead to errors or unexpected behavior.

To fix this, you should use tl.range with a compile-time constant bound, such as max_spec_len, and mask the operations inside the loop. For max_spec_len to be a compile-time constant, you also need to remove it from the do_not_specialize list in the @triton.jit decorator for this function (line 75).

Suggested change
-    for i in range(num_tokens1):
-        if not rejected:
-            draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 + i)
-            target_argmax_id = tl.load(target_argmax_ptr + start_idx1 + i)
-            tl.store(
-                output_token_ids_ptr + position * (max_spec_len + 1) + i,
-                target_argmax_id,
-            )
-            if draft_token_id != target_argmax_id:
-                # Reject.
-                rejected = True
+    for i in tl.range(0, max_spec_len):
+        if i < num_tokens1 and not rejected:
+            draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 + i)
+            target_argmax_id = tl.load(target_argmax_ptr + start_idx1 + i)
+            tl.store(
+                output_token_ids_ptr + position * (max_spec_len + 1) + i,
+                target_argmax_id,
+            )
+            if draft_token_id != target_argmax_id:
+                # Reject.
+                rejected = True
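The masked fixed-bound loop in the suggestion can be sanity-checked outside Triton. Below is a plain-Python sketch (the function name and list-based output are illustrative, not part of the PR) that mirrors the greedy path: iterate up to a fixed `max_spec_len` bound, mask out positions past the runtime token count, and stop writing after the first mismatch between a draft token and the target argmax.

```python
def greedy_reject_emulate(draft_token_ids, target_argmax, max_spec_len):
    """Plain-Python analogue of the masked fixed-bound greedy loop.

    Every surviving position is written with the target-argmax token; once a
    draft token disagrees with the target argmax, all later positions are
    rejected (left unwritten, shown here as None).
    """
    num_tokens = len(draft_token_ids)
    output = [None] * max_spec_len
    rejected = False
    for i in range(max_spec_len):            # fixed, compile-time-style bound
        if i < num_tokens and not rejected:  # runtime mask, as in the kernel
            output[i] = target_argmax[i]
            if draft_token_ids[i] != target_argmax[i]:
                rejected = True              # reject all subsequent positions
    return output
```

Note that the mismatching position itself is still written with the target-argmax token before rejection kicks in, matching the store-then-reject order in the kernel.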

Comment on lines +158 to +179
    for pos in range(num_draft_tokens):
        if not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            if NO_DRAFT_PROBS:
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr +
                                     (start_idx + pos) * vocab_size +
                                     draft_token_id)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  draft_token_id)
            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
                # Accept
                token_id = draft_token_id
            else:
                # Reject. Use recovered token
                rejected = True
                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     token_id)
Contributor


critical

The loop for pos in range(num_draft_tokens): uses a runtime variable num_draft_tokens as its bound. In Triton, Python for loops are unrolled at compile time and require their bounds to be compile-time constants. This is a critical issue that can lead to errors.

To fix this, you should use tl.range with a compile-time constant bound, like max_spec_len, and add a mask pos < num_draft_tokens inside the loop. You will also need to make max_spec_len a compile-time constant by removing it from the do_not_specialize list in the @triton.jit decorator for this function (line 131).

    for pos in tl.range(0, max_spec_len):
        if pos < num_draft_tokens and not rejected:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            if NO_DRAFT_PROBS:
                draft_prob = 1
            else:
                draft_prob = tl.load(draft_probs_ptr +
                                     (start_idx + pos) * vocab_size +
                                     draft_token_id)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  draft_token_id)
            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
                # Accept
                token_id = draft_token_id
            else:
                # Reject. Use recovered token
                rejected = True
                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     token_id)
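The accept/reject rule this kernel applies can likewise be emulated in plain Python. The sketch below (function and parameter names are illustrative, not from the PR) shows the core semantics: a draft token is accepted when `target_prob / draft_prob >= uniform_prob`, and on the first rejection the recovered token is emitted and no further positions are written.

```python
def random_reject_emulate(draft_ids, draft_probs, target_probs,
                          uniform_probs, recovered_ids):
    """Plain-Python analogue of the rejection-sampling accept/reject loop.

    Returns the emitted token sequence: accepted draft tokens up to and
    including the first rejected position, where the recovered token is
    substituted and the sequence ends.
    """
    out = []
    rejected = False
    for pos in range(len(draft_ids)):
        if rejected:
            break  # kernel masks later positions; break is equivalent here
        if (draft_probs[pos] > 0
                and target_probs[pos] / draft_probs[pos] >= uniform_probs[pos]):
            out.append(draft_ids[pos])       # accept the draft token
        else:
            out.append(recovered_ids[pos])   # reject: use recovered token
            rejected = True
    return out
```

For example, with target/draft ratios of 1.2 and 0.2 against uniform draws of 0.9 and 0.5, the first draft token is accepted and the second is replaced by its recovered token.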

@whx-sjtu whx-sjtu added the ready (read for review) and ready-for-test (start test by label for PR) labels Dec 24, 2025
@github-actions
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@github-actions
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@whx-sjtu whx-sjtu force-pushed the move_reject_triton branch 5 times, most recently from 7d6eb19 to 6340897, December 26, 2025 01:42
@github-actions
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@whx-sjtu
Collaborator Author

@github-actions
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
@wangxiyuan wangxiyuan merged commit 28b7614 into vllm-project:main Dec 29, 2025
17 checks passed
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Dec 31, 2025
…to FIA_rebase

* 'main' of https://github.com/vllm-project/vllm-ascend: (88 commits)
  [1/N] Refactor nightly test structure (vllm-project#5479)
  Docs: Remove deprecated --task parameter for embedding models (vllm-project#5257)
  Revert "moe_gating_top_k" (vllm-project#5512)
  [Doc] Fix issue link for 0.12.0 (vllm-project#5500)
  [CI]update triton ascend version (vllm-project#5392)
  moe_gating_top_k (vllm-project#5271)
  [refactor] refactor model runner capture model (vllm-project#5230)
  Update corresponding vllm commit ID to 12 29 (vllm-project#5475)
  [Kernel]update csrc cmakelist for open-source cann (vllm-project#5458)
  [OP] add custom op aclnnMoeInitRoutingCustom (vllm-project#5251)
  [Refactor][EAGLE] 1/N delete __init__ in mtp_proposer (vllm-project#5176)
  [Refactor][Triton] Move reject sample triton kernels into ops/triton (vllm-project#5324)
  [Feature] support eager mode in model runner v2 (vllm-project#5210)
  [feature] fia support sliding windows (vllm-project#5239)
  Optimize some rejectsampler functions to make npu op launch non-blocking (vllm-project#4587)
  [Feature] Support to use fullgraph with eagle (vllm-project#5118)
  [EPLB][refactor] Modification of the initialization logic for expert_map and log2phy(depend on pr5285) (vllm-project#5311)
  [Refactor]6/N Extract common code of class AscendMLAImpl (vllm-project#5314)
  [Refactor] cache cos/sin in mla & remove parameter model in builder. (vllm-project#5277)
  update vllm pin to 12.27 (vllm-project#5412)
  ...
shenchuxiaofugui pushed a commit to shenchuxiaofugui/vllm-ascend that referenced this pull request Dec 31, 2025
…llm-project#5324)

### What this PR does / why we need it?
This PR moves reject sample related triton kernels into `ops/triton`.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed with existing test.


- vLLM version: release/v0.13.0
- vLLM main:
vllm-project/vllm@5fbfa8d

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026

Labels

module:ops · ready (read for review) · ready-for-test (start test by label for PR)

3 participants