[Refactor][Triton] Move reject sample triton kernels into ops/triton #5324

wangxiyuan merged 2 commits into vllm-project:main from
Conversation
Code Review
This pull request refactors the rejection sampling Triton kernels by moving them into a new, dedicated file vllm_ascend/ops/triton/reject_sample.py. This is a good change for code organization and modularity. The calling code in vllm_ascend/sample/rejection_sampler.py has been updated accordingly.
During the review, I identified two critical bugs within the newly added (but previously existing) Triton kernels. Both rejection_greedy_sample_triton and rejection_random_sample_kernel incorrectly use Python for loops with runtime-dependent bounds. This is not supported by Triton and will cause errors. I have provided detailed comments and code suggestions to resolve these issues by using tl.range with compile-time constant bounds and appropriate masking.
for i in range(num_tokens1):
    if not rejected:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 + i)
        target_argmax_id = tl.load(target_argmax_ptr + start_idx1 + i)
        tl.store(
            output_token_ids_ptr + position * (max_spec_len + 1) + i,
            target_argmax_id,
        )
        if draft_token_id != target_argmax_id:
            # Reject.
            rejected = True
The loop for i in range(num_tokens1): uses a runtime variable num_tokens1 as its bound. In Triton, Python for loops are unrolled at compile time and require their bounds to be compile-time constants. Using a runtime variable here is incorrect and can lead to errors or unexpected behavior.
To fix this, you should use tl.range with a compile-time constant bound, such as max_spec_len, and mask the operations inside the loop. For max_spec_len to be a compile-time constant, you also need to remove it from the do_not_specialize list in the @triton.jit decorator for this function (line 75).
Suggested change (before):

for i in range(num_tokens1):
    if not rejected:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 + i)
        target_argmax_id = tl.load(target_argmax_ptr + start_idx1 + i)
        tl.store(
            output_token_ids_ptr + position * (max_spec_len + 1) + i,
            target_argmax_id,
        )
        if draft_token_id != target_argmax_id:
            # Reject.
            rejected = True

Suggested change (after):

for i in tl.range(0, max_spec_len):
    if i < num_tokens1 and not rejected:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 + i)
        target_argmax_id = tl.load(target_argmax_ptr + start_idx1 + i)
        tl.store(
            output_token_ids_ptr + position * (max_spec_len + 1) + i,
            target_argmax_id,
        )
        if draft_token_id != target_argmax_id:
            # Reject.
            rejected = True
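For intuition, the semantics of the suggested masked loop can be mirrored in plain Python. The helper below is a hypothetical reference sketch, not part of the PR: it iterates to a fixed `max_spec_len` bound (standing in for the compile-time constant) and guards each step with the runtime length and the `rejected` flag, exactly as the `i < num_tokens1 and not rejected` mask does in the kernel.

```python
def greedy_rejection_row(draft_token_ids, target_argmax_ids, max_spec_len):
    """Plain-Python mirror of the masked greedy loop: store the target argmax
    token at each position until the first draft/target mismatch, then stop."""
    output = [-1] * (max_spec_len + 1)  # placeholder row, -1 = untouched
    num_tokens = len(draft_token_ids)
    rejected = False
    for i in range(max_spec_len):  # fixed bound, like tl.range(0, max_spec_len)
        if i < num_tokens and not rejected:  # mask mirrors `i < num_tokens1`
            output[i] = target_argmax_ids[i]
            if draft_token_ids[i] != target_argmax_ids[i]:
                rejected = True  # reject: later positions stay untouched
    return output

# All draft tokens match the target argmax: every position is accepted.
print(greedy_rejection_row([5, 7], [5, 7], 3))  # → [5, 7, -1, -1]
# Mismatch at position 0: the target token is stored, then the row stops.
print(greedy_rejection_row([5, 7], [9, 7], 3))  # → [9, -1, -1, -1]
```

Because the bound is fixed and only the mask depends on runtime data, every program instance executes the same loop shape, which is the property the review asks for.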
for pos in range(num_draft_tokens):
    if not rejected:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        if NO_DRAFT_PROBS:
            draft_prob = 1
        else:
            draft_prob = tl.load(draft_probs_ptr +
                                 (start_idx + pos) * vocab_size +
                                 draft_token_id)
        target_prob = tl.load(target_probs_ptr +
                              (start_idx + pos) * vocab_size +
                              draft_token_id)
        uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
        if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
            # Accept.
            token_id = draft_token_id
        else:
            # Reject. Use the recovered token.
            rejected = True
            token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
        tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                 token_id)
The loop for pos in range(num_draft_tokens): uses a runtime variable num_draft_tokens as its bound. In Triton, Python for loops are unrolled at compile time and require their bounds to be compile-time constants. This is a critical issue that can lead to errors.
To fix this, you should use tl.range with a compile-time constant bound, like max_spec_len, and add a mask pos < num_draft_tokens inside the loop. You will also need to make max_spec_len a compile-time constant by removing it from the do_not_specialize list in the @triton.jit decorator for this function (line 131).
for pos in tl.range(0, max_spec_len):
    if pos < num_draft_tokens and not rejected:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        if NO_DRAFT_PROBS:
            draft_prob = 1
        else:
            draft_prob = tl.load(draft_probs_ptr +
                                 (start_idx + pos) * vocab_size +
                                 draft_token_id)
        target_prob = tl.load(target_probs_ptr +
                              (start_idx + pos) * vocab_size +
                              draft_token_id)
        uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
        if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
            # Accept.
            token_id = draft_token_id
        else:
            # Reject. Use the recovered token.
            rejected = True
            token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
        tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                 token_id)
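The accept/reject rule used by this kernel can likewise be mirrored in plain Python for intuition. The helper below is a hypothetical reference sketch, not part of the PR: it accepts a draft token when `target_prob / draft_prob >= u` for the precomputed uniform draw `u`, otherwise emits the recovered token and masks out all later positions, matching the `pos < num_draft_tokens and not rejected` guard in the suggested loop.

```python
def random_rejection_row(draft_token_ids, draft_probs, target_probs,
                         uniform_probs, recovered_token_ids, max_spec_len):
    """Plain-Python mirror of the masked random-sampling loop: accept a draft
    token when target_prob / draft_prob >= u, otherwise emit the recovered
    token and reject everything after it."""
    output = [-1] * (max_spec_len + 1)  # placeholder row, -1 = untouched
    num_draft_tokens = len(draft_token_ids)
    rejected = False
    for pos in range(max_spec_len):  # fixed bound, runtime length as a mask
        if pos < num_draft_tokens and not rejected:
            d = draft_probs[pos]   # q(token) under the draft model
            t = target_probs[pos]  # p(token) under the target model
            if d > 0 and t / d >= uniform_probs[pos]:
                token_id = draft_token_ids[pos]      # accept the draft token
            else:
                rejected = True
                token_id = recovered_token_ids[pos]  # reject, use recovered
            output[pos] = token_id
    return output

# First token accepted (0.9 / 0.8 >= 0.5), second rejected (0.1 / 0.6 < 0.7),
# so the recovered token 9 is emitted and position 2 is never written.
row = random_rejection_row([3, 4], [0.8, 0.6], [0.9, 0.1],
                           [0.5, 0.7], [8, 9], 3)
print(row)  # → [3, 9, -1, -1]
```

As with the greedy kernel, the loop trip count no longer depends on runtime data; only the mask does.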
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
This pull request has conflicts, please resolve those before we can evaluate the pull request.
…llm-project#5324)

### What this PR does / why we need it?
This PR moves reject sample related triton kernels into `ops/triton`.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed with existing test.

- vLLM version: release/v0.13.0
- vLLM main: vllm-project/vllm@5fbfa8d

Signed-off-by: whx-sjtu <2952154980@qq.com>
What this PR does / why we need it?
This PR moves reject sample related triton kernels into `ops/triton`.

Does this PR introduce any user-facing change?
No
How was this patch tested?
CI passed with existing test.