Fix private memory size too large in sample_recovered_tokens_kernel #115
Changes from all commits
```diff
@@ -6,3 +6,4 @@
 from . import device_allocator
 from . import model_executor
 from . import oot
+from . import sample
```
```diff
@@ -0,0 +1,3 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from . import rejection_sampler
```
New file (`@@ -0,0 +1,86 @@`):

```python
# SPDX-License-Identifier: Apache-2.0
from vllm.triton_utils import tl, triton

import vllm.v1.sample.rejection_sampler

# SPDX-License-Identifier: Apache-2.0


@triton.jit
def sample_recovered_tokens_kernel(
    output_token_ids_ptr,  # [num_tokens]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    q_ptr,  # [batch_size, vocab_size]
    vocab_size,
    PADDED_VOCAB_SIZE: tl.constexpr,
    NO_DRAFT_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr = 1024,
):
    req_idx = tl.program_id(0)
    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    # Early exit for out-of-range positions.
    pos = tl.program_id(1)
    if pos >= num_draft_tokens:
        return

    max_prob = -float('inf')
    best_token_id = 0

    for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE):
        block_end = min(block_start + BLOCK_SIZE, vocab_size)

        vocab_offset = tl.arange(0, BLOCK_SIZE)
        mask = vocab_offset < block_end - block_start

        if NO_DRAFT_PROBS:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            prob = tl.load(
                target_probs_ptr + (start_idx + pos) * vocab_size +
                block_start + vocab_offset,
                mask=(mask & (vocab_offset + block_start != draft_token_id)),
                other=0)
        else:
            draft_prob = tl.load(draft_probs_ptr +
                                 (start_idx + pos) * vocab_size + block_start +
                                 vocab_offset,
                                 mask=mask,
                                 other=0)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  block_start + vocab_offset,
                                  mask=mask,
                                  other=0)
            prob = tl.maximum(target_prob - draft_prob, 0)

        # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
        # `tl.argmax` will select the maximum value.
        q = tl.load(q_ptr + req_idx * vocab_size + block_start + vocab_offset,
                    mask=mask,
                    other=float("-inf"))
```
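For context on what the tiled loop changes: the upstream kernel loads the entire padded vocabulary row with a single `tl.load`, so each program's private working set grows with `PADDED_VOCAB_SIZE`, whereas the loop above walks the row in `BLOCK_SIZE` tiles and keeps only one tile of each probability vector live at a time. A rough back-of-the-envelope comparison, where the vocabulary size and fp32 element width are assumptions used purely for illustration:

```python
# Illustrative only: private memory held per program by one loaded vector.
padded_vocab = 262_144      # assumed: next power of two above a ~152k vocabulary
block_size = 1_024          # BLOCK_SIZE default in the kernel above
bytes_per_elem = 4          # assumed fp32 probabilities

full_row_bytes = padded_vocab * bytes_per_elem   # untiled load: ~1 MiB per vector
tile_bytes = block_size * bytes_per_elem         # tiled load: 4 KiB per vector
print(f"{full_row_bytes / 2**20:.1f} MiB vs {tile_bytes / 2**10:.1f} KiB")
```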
Comment on lines +68 to +70:

For consistency and to prevent potential 32-bit integer overflows, the offset calculation for `q_ptr` should also use a 64-bit base offset.
Comment on lines +34 to +70:

The PR title and description mention fixing the `private memory size too large` issue, but the pointer offset calculations like `(start_idx + pos) * vocab_size` can still overflow 32-bit integer arithmetic for large `vocab_size`. This suggestion ensures 64-bit arithmetic for offset calculations by explicitly casting to `tl.int64`:

```python
    token_idx = start_idx + pos
    # Ensure 64-bit arithmetic for large offsets to prevent overflow.
    probs_base_offset = token_idx.to(tl.int64) * vocab_size
    q_base_offset = req_idx.to(tl.int64) * vocab_size

    max_prob = -float('inf')
    best_token_id = 0

    for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE):
        block_end = min(block_start + BLOCK_SIZE, vocab_size)

        vocab_offset = tl.arange(0, BLOCK_SIZE)
        mask = vocab_offset < block_end - block_start

        if NO_DRAFT_PROBS:
            draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
            prob = tl.load(
                target_probs_ptr + probs_base_offset + block_start +
                vocab_offset,
                mask=(mask & (vocab_offset + block_start != draft_token_id)),
                other=0)
        else:
            draft_prob = tl.load(draft_probs_ptr + probs_base_offset +
                                 block_start + vocab_offset,
                                 mask=mask,
                                 other=0)
            target_prob = tl.load(target_probs_ptr + probs_base_offset +
                                  block_start + vocab_offset,
                                  mask=mask,
                                  other=0)
            prob = tl.maximum(target_prob - draft_prob, 0)

        # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
        # `tl.argmax` will select the maximum value.
        q = tl.load(q_ptr + q_base_offset + block_start + vocab_offset,
                    mask=mask,
                    other=float("-inf"))
```
The new file continues:

```python
        # recovered_id = tl.argmax(prob / q, axis=-1)
        # calc block prob and token ID
        block_prob = prob / q
        block_max_prob = tl.max(block_prob, axis=-1)
        block_best_token_id = tl.argmax(block_prob, axis=-1) + block_start

        # update token ID
        max_prob = tl.maximum(max_prob, block_max_prob)
        best_token_id = tl.where(block_max_prob >= max_prob,
                                 block_best_token_id, best_token_id)

    tl.store(output_token_ids_ptr + start_idx + pos, best_token_id)


vllm.v1.sample.rejection_sampler.sample_recovered_tokens_kernel = sample_recovered_tokens_kernel
```
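The final assignment only rebinds the module attribute, so upstream code that looks up `sample_recovered_tokens_kernel` through `vllm.v1.sample.rejection_sampler` picks up the tiled kernel with an unchanged signature. A minimal sketch of a launch consistent with the pointer comments in the kernel; the grid shape, tensor names, and power-of-two padding below are assumptions, not part of this diff:

```python
# Hypothetical call site, derived from the kernel's pointer comments above.
grid = (batch_size, max_num_draft_tokens)  # one program per (request, draft position)
sample_recovered_tokens_kernel[grid](
    output_token_ids,      # [num_tokens]
    cu_num_draft_tokens,   # [batch_size], cumulative draft-token counts per request
    draft_token_ids,       # [num_tokens]
    draft_probs,           # [num_tokens, vocab_size] or None
    target_probs,          # [num_tokens, vocab_size]
    q,                     # [batch_size, vocab_size] random draws for the argmax trick
    vocab_size,
    PADDED_VOCAB_SIZE=triton.next_power_of_2(vocab_size),
    NO_DRAFT_PROBS=draft_probs is None,
)
```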
The pointer offset calculation `(start_idx + pos) * vocab_size` is repeated and may suffer from 32-bit integer overflow if `start_idx` is loaded as a 32-bit integer and multiplied by a large `vocab_size`. This can lead to incorrect memory access and is a likely cause for the `pointer value too large to fit in 32 bit` error. To ensure correctness and improve readability, it's better to calculate the base offset once outside the loop, explicitly casting to `tl.int64` to prevent any potential overflow.
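To put a number on that concern, a quick check of where a 32-bit row offset would first overflow; the vocabulary size here is an assumption for illustration:

```python
# Smallest token index whose row offset (start_idx + pos) * vocab_size no longer fits in int32.
vocab_size = 152_064                      # assumed vocabulary size
int32_max = 2**31 - 1
first_overflowing_idx = int32_max // vocab_size + 1
print(first_overflowing_idx)              # 14123: row offsets at or past this index exceed int32
```

Past roughly 14k scheduled draft tokens in a batch, the 32-bit product would wrap, which is why the suggestion casts the hoisted base offset to `tl.int64`.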