@@ -4,8 +4,11 @@
    rejection_random_sample_kernel as original_rejection_random_sample_kernel

from vllm_ascend.ops.triton.reject_sample import (
-    cal_grid_and_block_size, rejection_random_sample_kernel)
+    cal_grid_and_block_size, rejection_random_sample_block_verify_kernel,
+    rejection_random_sample_kernel)
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
+from vllm_ascend.sample.rejection_sampler import \
+    rejection_random_sample_block_verify_pytorch


@pytest.fixture(scope="function", autouse=True)
@@ -93,3 +96,134 @@ def test_rejection_random_sample(max_spec_len, vocab_size, batch_size):
        BLOCK_SIZE=block_size)
    torch.npu.synchronize()
    assert torch.equal(original_output_token_ids, output_token_ids)


DEVICE = "npu"
BATCH_SIZE = 7
MAX_SPEC_LEN = 3
VOCAB_SIZE = 5
CU_NUM_DRAFT_TOKENS = torch.tensor([2, 2, 5, 8, 11, 14, 15],
dtype=torch.int32,
device=DEVICE)
DRAFT_TOKEN_IDS = torch.tensor([0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0],
dtype=torch.int64,
device=DEVICE)
NUM_TOKENS = DRAFT_TOKEN_IDS.shape[0]
DRAFT_PROBS = None
TARGET_PROBS = torch.tensor(
[
[0.4, 0.3, 0.1, 0.1, 0.1], # 0
[0.1, 0.9, 0.0, 0.0, 0.0], # 1
[0.2, 0.1, 0.2, 0.4, 0.1], # 0
[0.1, 0.4, 0.1, 0.1, 0.3], # 0
[0.2, 0.1, 0.4, 0.1, 0.2], # 0
[0.4, 0.2, 0.1, 0.2, 0.1], # 0
[0.1, 0.6, 0.1, 0.1, 0.1], # 1
[0.2, 0.2, 0.2, 0.3, 0.1], # 0
[0.4, 0.2, 0.1, 0.2, 0.1], # 0
[0.1, 0.6, 0.1, 0.1, 0.1], # 1
[0.2, 0.2, 0.2, 0.3, 0.1], # 0
[0.4, 0.4, 0.1, 0.0, 0.1], # 1
[0.4, 0.3, 0.1, 0.1, 0.1], # 0
[0.4, 0.0, 0.5, 0.0, 0.1], # 1
[0.4, 0.1, 0.3, 0.1, 0.1], # 1
],
dtype=torch.float32,
device=DEVICE)
UNIFORM_PROBS = torch.tensor([
0.9,
0.0,
0.9,
0.7,
0.8,
0.5,
0.45,
1.0,
0.5,
0.45,
1.0,
0.39,
0.4,
0.1,
0.3,
],
dtype=torch.float32,
device=DEVICE)
BONUS_TOKEN_IDS = torch.full((BATCH_SIZE, ),
MAX_SPEC_LEN + 1,
dtype=torch.int64,
device=DEVICE)
RECOVERED_TOKEN_IDS = torch.full((NUM_TOKENS, ),
MAX_SPEC_LEN,
dtype=torch.int64,
device=DEVICE)
IS_GREEDY = torch.zeros(BATCH_SIZE, dtype=torch.bool, device=DEVICE)
IS_GREEDY[4] = True
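A note on the fixtures: CU_NUM_DRAFT_TOKENS is cumulative, so request i's draft tokens occupy the half-open range [cu[i-1], cu[i]) of the flattened arrays; request 1 has zero draft tokens (start == end == 2), and request 4 exercises the greedy path. The trailing # 0 / # 1 comments on the TARGET_PROBS rows appear to mark whether that position ends up rejected (0) or accepted (1) under the cumulative block-verify rule tested below. A minimal sketch of the range bookkeeping, using plain Python lists (the helper name is illustrative, not part of this PR):

def draft_token_ranges(cu_num_draft_tokens):
    # Request i spans [cu[i-1], cu[i]); request 0 starts at 0.
    starts = [0] + list(cu_num_draft_tokens[:-1])
    return list(zip(starts, cu_num_draft_tokens))

assert draft_token_ranges([2, 2, 5, 8, 11, 14, 15]) == [
    (0, 2), (2, 2), (2, 5), (5, 8), (8, 11), (11, 14), (14, 15)
]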


@pytest.mark.parametrize("cu_num_draft_tokens", [CU_NUM_DRAFT_TOKENS])
@pytest.mark.parametrize("draft_token_ids", [DRAFT_TOKEN_IDS])
@pytest.mark.parametrize("draft_probs", [DRAFT_PROBS])
@pytest.mark.parametrize("target_probs", [TARGET_PROBS])
@pytest.mark.parametrize("bonus_token_ids", [BONUS_TOKEN_IDS])
@pytest.mark.parametrize("recovered_token_ids", [RECOVERED_TOKEN_IDS])
@pytest.mark.parametrize("uniform_probs", [UNIFORM_PROBS])
@pytest.mark.parametrize("is_greedy", [IS_GREEDY])
@pytest.mark.parametrize("batch_size", [BATCH_SIZE])
@pytest.mark.parametrize("max_spec_len", [MAX_SPEC_LEN])
@pytest.mark.parametrize("vocab_size", [VOCAB_SIZE])
@torch.inference_mode()
def test_rejection_sampler_block_verify_triton_kernel(
        cu_num_draft_tokens,  # [batch_size]
        draft_token_ids,  # [num_tokens]
        draft_probs,  # [num_tokens, vocab_size] or None
        target_probs,  # [num_tokens, vocab_size]
        bonus_token_ids,  # [batch_size]
        recovered_token_ids,  # [num_tokens]
        uniform_probs,  # [num_tokens]
        is_greedy,  # [batch_size]
        batch_size,  # int
        max_spec_len,  # int
        vocab_size,  # int
) -> None:

    grid, block_size = cal_grid_and_block_size(batch_size)

    output_token_ids_ref = torch.full((batch_size, max_spec_len + 1),
                                      -1,
                                      dtype=torch.int64,
                                      device=DEVICE)

    output_token_ids_triton = output_token_ids_ref.clone()

    rejection_random_sample_block_verify_pytorch(
        output_token_ids=output_token_ids_ref,
        cu_num_draft_tokens=cu_num_draft_tokens,
        draft_token_ids=draft_token_ids,
        draft_probs=draft_probs,
        target_probs=target_probs,
        bonus_token_ids=bonus_token_ids,
        recovered_token_ids=recovered_token_ids,
        uniform_probs=uniform_probs,
        is_greedy=is_greedy,
        max_spec_len=max_spec_len,
        vocab_size=vocab_size,
        IS_NGRAM=draft_probs is None)

    rejection_random_sample_block_verify_kernel[(grid, )](
        output_token_ids_ptr=output_token_ids_triton,
        cu_num_draft_tokens_ptr=cu_num_draft_tokens,
        draft_token_ids_ptr=draft_token_ids,
        draft_probs_ptr=draft_probs,
        target_probs_ptr=target_probs,
        bonus_token_ids_ptr=bonus_token_ids,
        recovered_token_ids_ptr=recovered_token_ids,
        uniform_probs_ptr=uniform_probs,
        is_greedy_ptr=is_greedy,
        max_spec_len=max_spec_len,
        vocab_size=vocab_size,
        vec_len=batch_size,
        NO_DRAFT_PROBS=draft_probs is None,
        BLOCK_SIZE=block_size)
    torch.npu.synchronize()
    assert torch.equal(output_token_ids_ref, output_token_ids_triton)
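For intuition, here is a minimal scalar sketch (illustrative, not part of this PR) of the block-verify acceptance rule that the Triton kernel below implements for one non-greedy request, assuming the NGRAM case where draft_prob == 1:

def block_verify_one_request(draft_ids, target_probs_at_draft, uniforms,
                             recovered_id, bonus_id, max_spec_len):
    # pi and u are running products over positions; a position is accepted
    # while pi >= u, and a later acceptance retroactively accepts all
    # earlier positions.
    out = [-1] * (max_spec_len + 1)
    pi, u, last_accepted, rejected = 1.0, 1.0, -1, False
    for pos, (tp, up) in enumerate(zip(target_probs_at_draft, uniforms)):
        u *= up
        pi = min(pi * tp, 1.0)
        if pi >= u:
            last_accepted, rejected = pos, False
        else:
            rejected = True
    for pos in range(last_accepted + 1):
        out[pos] = draft_ids[pos]
    if rejected:
        out[last_accepted + 1] = recovered_id
    else:
        out[len(draft_ids)] = bonus_id
    return out

# Request 0 from the fixtures: position 0 alone would be rejected
# (pi = 0.4 < u = 0.9), but position 1 is accepted (pi = 0.36 >= u = 0.0),
# which retroactively accepts position 0, so the bonus token is appended.
assert block_verify_one_request([0, 1], [0.4, 0.9], [0.9, 0.0],
                                recovered_id=3, bonus_id=4,
                                max_spec_len=3) == [0, 1, 4, -1]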
81 changes: 81 additions & 0 deletions vllm_ascend/ops/triton/reject_sample.py
@@ -378,3 +378,84 @@ def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from,
        MAX_NUM_TOKENS=max_num_tokens,  # To avoid recompilation.
        BLOCK_SIZE=block_size,
    )


@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_block_verify_kernel(
        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
        cu_num_draft_tokens_ptr,  # [batch_size]
        draft_token_ids_ptr,  # [num_tokens]
        draft_probs_ptr,  # [num_tokens, vocab_size] or None
        target_probs_ptr,  # [num_tokens, vocab_size]
        bonus_token_ids_ptr,  # [batch_size]
        recovered_token_ids_ptr,  # [num_tokens]
        uniform_probs_ptr,  # [num_tokens]
        is_greedy_ptr,  # [batch_size]
        max_spec_len,
        vocab_size,
        vec_len,
        NO_DRAFT_PROBS: tl.constexpr,
        BLOCK_SIZE: tl.constexpr):
    block_idx = tl.program_id(0)
    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < vec_len
    is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
    not_greedy_mask = is_greedy == 0
    # Request i's draft tokens occupy [start_idxs[i], end_idxs[i]) in the
    # flattened token arrays; request 0 starts at 0.
    start_idxs = tl.where(
        offsets == 0, 0,
        tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
    end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
    n_num_draft_tokens = end_idxs - start_idxs
    for req_i in range(BLOCK_SIZE):
        not_greedy = tl.get_element(not_greedy_mask, (req_i, ))
        if not_greedy:
            # Running products for block verification: pi accumulates the
            # target/draft probability ratios, uniform_prob the uniform
            # samples drawn for each position.
            rejected = False
            pi = 1.0
            uniform_prob = 1.0
            last_accepted_token_pos = -1
            start_idx = tl.get_element(start_idxs, (req_i, ))
            req_idx = block_idx * BLOCK_SIZE + req_i
            num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, ))

            for pos in range(num_draft_tokens):
                draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
                target_prob = tl.load(target_probs_ptr +
                                      (start_idx + pos) * vocab_size +
                                      draft_token_id)
                tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
                uniform_prob = uniform_prob * tmp_uniform_prob

                if NO_DRAFT_PROBS:
                    draft_prob = 1
                else:
                    draft_prob = tl.load(draft_probs_ptr +
                                         (start_idx + pos) * vocab_size +
                                         draft_token_id)

                # Accept while the cumulative ratio dominates the cumulative
                # uniform sample; a later acceptance retroactively accepts
                # every earlier position.
                pi = min(pi * target_prob / draft_prob, 1.0)
                if draft_prob > 0 and pi >= uniform_prob:
                    last_accepted_token_pos = pos
                    rejected = False
                else:
                    rejected = True

            if last_accepted_token_pos > -1:
                for pos in range(last_accepted_token_pos + 1):
                    token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
                    tl.store(
                        output_token_ids_ptr + req_idx * (max_spec_len + 1) +
                        pos, token_id)

            if rejected:
                # The position after the last accepted token receives the
                # recovered token.
                recovered_token_id = tl.load(recovered_token_ids_ptr +
                                             start_idx +
                                             last_accepted_token_pos + 1)
                tl.store(
                    output_token_ids_ptr + req_idx * (max_spec_len + 1) +
                    last_accepted_token_pos + 1, recovered_token_id)
            else:
                # No rejection in the block: append the bonus token after
                # the accepted draft tokens.
                bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
                tl.store(
                    output_token_ids_ptr + req_idx * (max_spec_len + 1) +
                    num_draft_tokens, bonus_token_id)
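A note on the structure, as the code above suggests: each Triton program handles BLOCK_SIZE requests, and the [start, end) range bookkeeping is done in vector form, but verification is inherently sequential within a request (every position depends on the running pi and uniform_prob products), so the per-request work falls back to a scalar loop driven by tl.get_element.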