2 changes: 1 addition & 1 deletion vllm_metax/models/deepseek_v2.py
@@ -632,7 +632,7 @@ def sparse_attn_indexer(
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            decode_metadata.schedule_metadata,
-            max_context_len=max_model_len,
+            max_model_len,
        )
        # padded query len
        current_device = padded_q_bf16_decode_tokens.device
1 change: 1 addition & 0 deletions vllm_metax/patch/__init__.py
@@ -6,3 +6,4 @@
from . import device_allocator
from . import model_executor
from . import oot
+from . import sample
3 changes: 3 additions & 0 deletions vllm_metax/patch/sample/__init__.py
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0

from . import rejection_sampler
86 changes: 86 additions & 0 deletions vllm_metax/patch/sample/rejection_sampler.py
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.triton_utils import tl, triton

import vllm.v1.sample.rejection_sampler

# SPDX-License-Identifier: Apache-2.0


@triton.jit
def sample_recovered_tokens_kernel(
    output_token_ids_ptr,  # [num_tokens]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    q_ptr,  # [batch_size, vocab_size]
    vocab_size,
    PADDED_VOCAB_SIZE: tl.constexpr,
    NO_DRAFT_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr = 1024,
):
    req_idx = tl.program_id(0)
    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    # Early exit for out-of-range positions.
    pos = tl.program_id(1)
    if pos >= num_draft_tokens:
        return

    max_prob = -float('inf')
    best_token_id = 0

    for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE):
        block_end = min(block_start + BLOCK_SIZE, vocab_size)

        vocab_offset = tl.arange(0, BLOCK_SIZE)
        mask = vocab_offset < block_end - block_start

        if NO_DRAFT_PROBS:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            prob = tl.load(
                target_probs_ptr + (start_idx + pos) * vocab_size +
                block_start + vocab_offset,
                mask=(mask & (vocab_offset + block_start != draft_token_id)),
                other=0)

        else:
            draft_prob = tl.load(draft_probs_ptr +
                                 (start_idx + pos) * vocab_size + block_start +
                                 vocab_offset,
                                 mask=mask,
                                 other=0)
            target_prob = tl.load(target_probs_ptr +
                                  (start_idx + pos) * vocab_size +
                                  block_start + vocab_offset,
                                  mask=mask,
                                  other=0)
            prob = tl.maximum(target_prob - draft_prob, 0)
Comment on lines +35 to +63

critical

The pointer offset calculation `(start_idx + pos) * vocab_size` is repeated and may suffer from 32-bit integer overflow if `start_idx` is loaded as a 32-bit integer and multiplied by a large `vocab_size`. This can lead to incorrect memory access and is a likely cause of the `pointer value too large to fit in 32 bit` error; a quick numeric check follows the suggested code below.

To ensure correctness and improve readability, it's better to calculate the base offset once outside the loop, explicitly casting to `tl.int64` to prevent any potential overflow.

    token_idx = start_idx + pos
    # Cast to int64 to prevent overflow when calculating pointer offsets.
    base_offset = token_idx.to(tl.int64) * vocab_size

    max_prob = -float('inf')
    best_token_id = 0

    for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE):
        block_end = min(block_start + BLOCK_SIZE, vocab_size)

        vocab_offset = tl.arange(0, BLOCK_SIZE)
        mask = vocab_offset < block_end - block_start

        if NO_DRAFT_PROBS:
            draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
            prob = tl.load(
                target_probs_ptr + base_offset + block_start + vocab_offset,
                mask=(mask & (vocab_offset + block_start != draft_token_id)),
                other=0)

        else:
            draft_prob = tl.load(draft_probs_ptr + base_offset + block_start +
                                 vocab_offset,
                                 mask=mask,
                                 other=0)
            target_prob = tl.load(target_probs_ptr + base_offset +
                                  block_start + vocab_offset,
                                  mask=mask,
                                  other=0)
            prob = tl.maximum(target_prob - draft_prob, 0)
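As a rough sense of scale (an editorial back-of-the-envelope check with assumed sizes, since the PR does not state concrete numbers): the per-token element offset crosses the int32 limit after a few tens of thousands of draft tokens, while the per-request `q_ptr` offset stays far below it, which is why the two findings carry different severities.

# Back-of-the-envelope check with assumed sizes (not taken from the PR).
INT32_MAX = 2**31 - 1
vocab_size = 129_280                  # assumed DeepSeek-style vocab size
batch_size = 1_024                    # assumed (generously large) batch size

# token_idx * vocab_size overflows int32 once token_idx exceeds ~16,611 ...
print(INT32_MAX // vocab_size)        # 16611
# ... whereas req_idx * vocab_size stays well inside int32 even for big batches.
print(batch_size * vocab_size < INT32_MAX)   # True (~1.3e8 elements)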


            # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
            # `tl.argmax` will select the maximum value.

        q = tl.load(q_ptr + req_idx * vocab_size + block_start + vocab_offset,
                    mask=mask,
                    other=float("-inf"))
Comment on lines +68 to +70

high

For consistency and to prevent potential 32-bit integer overflows, the offset calculation for `q_ptr` should also use 64-bit integers. While `req_idx` is likely small, multiplying by a large `vocab_size` could still pose a risk on some platforms or with very large batches. Using `tl.int64` ensures the calculation is safe.

Suggested change
-        q = tl.load(q_ptr + req_idx * vocab_size + block_start + vocab_offset,
-                    mask=mask,
-                    other=float("-inf"))
+        q = tl.load(q_ptr + req_idx.to(tl.int64) * vocab_size + block_start + vocab_offset,
+                    mask=mask,
+                    other=float("-inf"))

Comment on lines +34 to +70

critical

The PR title and description mention fixing `RuntimeError: Triton Error [MACA]: memory size or pointer value too large to fit in 32 bit`. While iterating over the vocabulary in blocks fixes the "private memory size too large" issue, the "pointer value too large" issue can still occur.

The pointer offset calculations like `(start_idx + pos) * vocab_size` and `req_idx * vocab_size` can overflow a 32-bit integer if the total number of tokens or the vocabulary size is large, leading to the error.

This suggestion ensures 64-bit arithmetic for offset calculations by explicitly casting to `tl.int64` before multiplication. This prevents potential overflows and fully resolves the reported error, making the kernel more robust.

    token_idx = start_idx + pos
    # Ensure 64-bit arithmetic for large offsets to prevent overflow.
    probs_base_offset = token_idx.to(tl.int64) * vocab_size
    q_base_offset = req_idx.to(tl.int64) * vocab_size

    max_prob = -float('inf')
    best_token_id = 0

    for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE):
        block_end = min(block_start + BLOCK_SIZE, vocab_size)

        vocab_offset = tl.arange(0, BLOCK_SIZE)
        mask = vocab_offset < block_end - block_start

        if NO_DRAFT_PROBS:
            draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
            prob = tl.load(
                target_probs_ptr + probs_base_offset + block_start +
                vocab_offset,
                mask=(mask & (vocab_offset + block_start != draft_token_id)),
                other=0)

        else:
            draft_prob = tl.load(draft_probs_ptr + probs_base_offset +
                                 block_start + vocab_offset,
                                 mask=mask,
                                 other=0)
            target_prob = tl.load(target_probs_ptr + probs_base_offset +
                                  block_start + vocab_offset,
                                  mask=mask,
                                  other=0)
            prob = tl.maximum(target_prob - draft_prob, 0)

            # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
            # `tl.argmax` will select the maximum value.

        q = tl.load(q_ptr + q_base_offset + block_start + vocab_offset,
                    mask=mask,
                    other=float("-inf"))


        # recovered_id = tl.argmax(prob / q, axis=-1)
        # calc block prob and token ID
        block_prob = prob / q
        block_max_prob = tl.max(block_prob, axis=-1)
        block_best_token_id = tl.argmax(block_prob, axis=-1) + block_start

        # update token ID
        max_prob = tl.maximum(max_prob, block_max_prob)
        best_token_id = tl.where(block_max_prob >= max_prob,
                                 block_best_token_id, best_token_id)

    tl.store(output_token_ids_ptr + start_idx + pos, best_token_id)


vllm.v1.sample.rejection_sampler.sample_recovered_tokens_kernel = sample_recovered_tokens_kernel
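For context on how this override takes effect: the three files chain together through import side effects. `vllm_metax/patch/__init__.py` now imports `sample`, whose `__init__.py` imports `rejection_sampler`, and the assignment above rebinds the kernel on the upstream `vllm.v1.sample.rejection_sampler` module. A minimal sketch of a smoke test one could use to confirm the patch is active (hypothetical check, not part of the PR):

# Hypothetical smoke test: importing the patch package should rebind the
# upstream kernel to the block-wise version defined above.
import vllm.v1.sample.rejection_sampler as upstream

import vllm_metax.patch  # noqa: F401 -- import side effects apply the patches
from vllm_metax.patch.sample.rejection_sampler import (
    sample_recovered_tokens_kernel as patched_kernel,
)

# Launch sites inside the upstream module resolve the kernel by name at call
# time, so after the rebinding they pick up the patched implementation.
assert upstream.sample_recovered_tokens_kernel is patched_kernel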