diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rejection_sample.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rejection_sample.py
index c8a857490b3..95c1157abe9 100644
--- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rejection_sample.py
+++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rejection_sample.py
@@ -4,8 +4,11 @@
     rejection_random_sample_kernel as original_rejection_random_sample_kernel
 from vllm_ascend.ops.triton.reject_sample import (
-    cal_grid_and_block_size, rejection_random_sample_kernel)
+    cal_grid_and_block_size, rejection_random_sample_block_verify_kernel,
+    rejection_random_sample_kernel)
 from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
+from vllm_ascend.sample.rejection_sampler import \
+    rejection_random_sample_block_verify_pytorch
 
 
 @pytest.fixture(scope="function", autouse=True)
@@ -93,3 +96,134 @@ def test_rejection_random_sample(max_spec_len, vocab_size, batch_size):
         BLOCK_SIZE=block_size)
     torch.npu.synchronize()
     assert torch.equal(original_output_token_ids, output_token_ids)
+
+
+DEVICE = "npu"
+BATCH_SIZE = 7
+MAX_SPEC_LEN = 3
+VOCAB_SIZE = 5
+CU_NUM_DRAFT_TOKENS = torch.tensor([2, 2, 5, 8, 11, 14, 15],
+                                   dtype=torch.int32,
+                                   device=DEVICE)
+DRAFT_TOKEN_IDS = torch.tensor([0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0],
+                               dtype=torch.int64,
+                               device=DEVICE)
+NUM_TOKENS = DRAFT_TOKEN_IDS.shape[0]
+DRAFT_PROBS = None
+TARGET_PROBS = torch.tensor(
+    [
+        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
+        [0.1, 0.9, 0.0, 0.0, 0.0],  # 1
+        [0.2, 0.1, 0.2, 0.4, 0.1],  # 0
+        [0.1, 0.4, 0.1, 0.1, 0.3],  # 0
+        [0.2, 0.1, 0.4, 0.1, 0.2],  # 0
+        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
+        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
+        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0
+        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
+        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
+        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0
+        [0.4, 0.4, 0.1, 0.0, 0.1],  # 1
+        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
+        [0.4, 0.0, 0.5, 0.0, 0.1],  # 1
+        [0.4, 0.1, 0.3, 0.1, 0.1],  # 1
+    ],
+    dtype=torch.float32,
+    device=DEVICE)
+UNIFORM_PROBS = torch.tensor([
+    0.9,
+    0.0,
+    0.9,
+    0.7,
+    0.8,
+    0.5,
+    0.45,
+    1.0,
+    0.5,
+    0.45,
+    1.0,
+    0.39,
+    0.4,
+    0.1,
+    0.3,
+],
+                             dtype=torch.float32,
+                             device=DEVICE)
+BONUS_TOKEN_IDS = torch.full((BATCH_SIZE, ),
+                             MAX_SPEC_LEN + 1,
+                             dtype=torch.int64,
+                             device=DEVICE)
+RECOVERED_TOKEN_IDS = torch.full((NUM_TOKENS, ),
+                                 MAX_SPEC_LEN,
+                                 dtype=torch.int64,
+                                 device=DEVICE)
+IS_GREEDY = torch.zeros(BATCH_SIZE, dtype=torch.bool, device=DEVICE)
+IS_GREEDY[4] = True
+
+
+@pytest.mark.parametrize("cu_num_draft_tokens", [CU_NUM_DRAFT_TOKENS])
+@pytest.mark.parametrize("draft_token_ids", [DRAFT_TOKEN_IDS])
+@pytest.mark.parametrize("draft_probs", [DRAFT_PROBS])
+@pytest.mark.parametrize("target_probs", [TARGET_PROBS])
+@pytest.mark.parametrize("bonus_token_ids", [BONUS_TOKEN_IDS])
+@pytest.mark.parametrize("recovered_token_ids", [RECOVERED_TOKEN_IDS])
+@pytest.mark.parametrize("uniform_probs", [UNIFORM_PROBS])
+@pytest.mark.parametrize("is_greedy", [IS_GREEDY])
+@pytest.mark.parametrize("batch_size", [BATCH_SIZE])
+@pytest.mark.parametrize("max_spec_len", [MAX_SPEC_LEN])
+@pytest.mark.parametrize("vocab_size", [VOCAB_SIZE])
+@torch.inference_mode()
+def test_rejection_sampler_block_verify_triton_kernel(
+        cu_num_draft_tokens,  # [batch_size]
+        draft_token_ids,  # [num_tokens]
+        draft_probs,  # [num_tokens, vocab_size] or None
+        target_probs,  # [num_tokens, vocab_size]
+        bonus_token_ids,  # [batch_size]
+        recovered_token_ids,  # [num_tokens]
+        uniform_probs,  # [num_tokens]
+        is_greedy,  # [batch_size]
+        batch_size,  # int
+        max_spec_len,  # int
+        vocab_size,  # int
+) -> None:
+
+    grid, block_size = cal_grid_and_block_size(batch_size)
+
+    output_token_ids_ref = torch.full((batch_size, max_spec_len + 1),
+                                      -1,
+                                      dtype=torch.int64,
+                                      device=DEVICE)
+
+    output_token_ids_triton = output_token_ids_ref.clone()
+
+    rejection_random_sample_block_verify_pytorch(
+        output_token_ids=output_token_ids_ref,
+        cu_num_draft_tokens=cu_num_draft_tokens,
+        draft_token_ids=draft_token_ids,
+        draft_probs=draft_probs,
+        target_probs=target_probs,
+        bonus_token_ids=bonus_token_ids,
+        recovered_token_ids=recovered_token_ids,
+        uniform_probs=uniform_probs,
+        is_greedy=is_greedy,
+        max_spec_len=max_spec_len,
+        vocab_size=vocab_size,
+        IS_NGRAM=draft_probs is None)
+
+    rejection_random_sample_block_verify_kernel[(grid, )](
+        output_token_ids_ptr=output_token_ids_triton,
+        cu_num_draft_tokens_ptr=cu_num_draft_tokens,
+        draft_token_ids_ptr=draft_token_ids,
+        draft_probs_ptr=draft_probs,
+        target_probs_ptr=target_probs,
+        bonus_token_ids_ptr=bonus_token_ids,
+        recovered_token_ids_ptr=recovered_token_ids,
+        uniform_probs_ptr=uniform_probs,
+        is_greedy_ptr=is_greedy,
+        max_spec_len=max_spec_len,
+        vocab_size=vocab_size,
+        vec_len=batch_size,
+        NO_DRAFT_PROBS=draft_probs is None,
+        BLOCK_SIZE=block_size)
+    torch.npu.synchronize()
+    assert torch.equal(output_token_ids_ref, output_token_ids_triton)
diff --git a/vllm_ascend/ops/triton/reject_sample.py b/vllm_ascend/ops/triton/reject_sample.py
index 6de1ae64e79..142815572ea 100644
--- a/vllm_ascend/ops/triton/reject_sample.py
+++ b/vllm_ascend/ops/triton/reject_sample.py
@@ -378,3 +378,84 @@ def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from,
         MAX_NUM_TOKENS=max_num_tokens,  # To avoid recompilation.
         BLOCK_SIZE=block_size,
     )
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_random_sample_block_verify_kernel(
+        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+        cu_num_draft_tokens_ptr,  # [batch_size]
+        draft_token_ids_ptr,  # [num_tokens]
+        draft_probs_ptr,  # [num_tokens, vocab_size] or None
+        target_probs_ptr,  # [num_tokens, vocab_size]
+        bonus_token_ids_ptr,  # [batch_size]
+        recovered_token_ids_ptr,  # [num_tokens]
+        uniform_probs_ptr,  # [num_tokens]
+        is_greedy_ptr,  # [batch_size]
+        max_spec_len,
+        vocab_size,
+        vec_len,
+        NO_DRAFT_PROBS: tl.constexpr,
+        BLOCK_SIZE: tl.constexpr):
+    block_idx = tl.program_id(0)
+    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < vec_len
+    is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
+    not_greedy_mask = is_greedy == 0
+    start_idxs = tl.where(
+        offsets == 0, 0,
+        tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
+    end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
+    n_num_draft_tokens = end_idxs - start_idxs
+    for req_i in range(BLOCK_SIZE):
+        not_greedy = tl.get_element(not_greedy_mask, (req_i, ))
+        if not_greedy:
+
+            rejected = False
+            pi = 1.0
+            uniform_prob = 1.0
+            last_accepted_token_pos = -1
+            start_idx = tl.get_element(start_idxs, (req_i, ))
+            req_idx = block_idx * BLOCK_SIZE + req_i
+            num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, ))
+
+            for pos in range(num_draft_tokens):
+                draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+                target_prob = tl.load(target_probs_ptr +
+                                      (start_idx + pos) * vocab_size +
+                                      draft_token_id)
+                tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
+                uniform_prob = uniform_prob * tmp_uniform_prob
+
+                if NO_DRAFT_PROBS:
+                    draft_prob = 1
+                else:
+                    draft_prob = tl.load(draft_probs_ptr +
+                                         (start_idx + pos) * vocab_size +
+                                         draft_token_id)
+
+                pi = min(pi * target_prob / draft_prob, 1.0)
+                if draft_prob > 0 and pi >= uniform_prob:
+                    last_accepted_token_pos = pos
+                    rejected = False
+                else:
+                    rejected = True
+
+            if last_accepted_token_pos > -1:
+                for pos in range(last_accepted_token_pos + 1):
+                    token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+                    tl.store(
+                        output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                        pos, token_id)
+
+            if rejected:
+                recovered_token_id = tl.load(recovered_token_ids_ptr +
+                                             start_idx +
+                                             last_accepted_token_pos + 1)
+                tl.store(
+                    output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                    last_accepted_token_pos + 1, recovered_token_id)
+            else:
+                bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+                tl.store(
+                    output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                    num_draft_tokens, bonus_token_id)
diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index 27adab6ae87..02159c670d8 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -9,8 +9,9 @@
 from vllm_ascend.ops.triton.reject_sample import (
     cal_grid_and_block_size, expand_triton,
-    rejection_greedy_sample_with_triton, rejection_random_sample_kernel,
-    sample_recovered_tokens_kernel)
+    rejection_greedy_sample_with_triton,
+    rejection_random_sample_block_verify_kernel,
+    rejection_random_sample_kernel, sample_recovered_tokens_kernel)
 from vllm_ascend.sample.sampler import apply_top_k_top_p
 
 PLACEHOLDER_TOKEN_ID = -1
@@ -108,6 +109,9 @@ def rejection_sample(
     assert bonus_token_ids.is_contiguous()
     assert target_probs.shape == (num_tokens, vocab_size)
 
+    # When num_speculative_tokens >= 3, use block verify.
+    using_block_verify = max_spec_len >= 3
+
     # Create output buffer.
     output_token_ids = torch.empty(
         (batch_size, max_spec_len + 1),
@@ -176,41 +180,74 @@ def rejection_sample(
         sampling_metadata,
         device,
     )
-
-    # Rejection sampling for random sampling requests.
-    if HAS_TRITON:
-        rejection_random_sample_kernel[(grid, )](
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            draft_probs,
-            target_probs,
-            bonus_token_ids,
-            recovered_token_ids,
-            uniform_probs.to(torch.float32),
-            is_greedy,
-            max_spec_len,
-            vocab_size,
-            batch_size,
-            NO_DRAFT_PROBS=draft_probs is None,
-            BLOCK_SIZE=block_size,
-        )
+    if not using_block_verify:
+        # Rejection sampling for random sampling requests.
+        if HAS_TRITON:
+            rejection_random_sample_kernel[(grid, )](
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                recovered_token_ids,
+                uniform_probs.to(torch.float32),
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                batch_size,
+                NO_DRAFT_PROBS=draft_probs is None,
+                BLOCK_SIZE=block_size,
+            )
+        else:
+            rejection_random_sample_pytorch(
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                recovered_token_ids,
+                uniform_probs,
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                IS_NGRAM=draft_probs is None,
+                # num_warps=1,
+            )
     else:
-        rejection_random_sample_pytorch(
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            draft_probs,
-            target_probs,
-            bonus_token_ids,
-            recovered_token_ids,
-            uniform_probs,
-            is_greedy,
-            max_spec_len,
-            vocab_size,
-            IS_NGRAM=draft_probs is None,
-            # num_warps=1,
-        )
+        # MagicMTP: Improving acceptance rate with Block Verify.
+        if HAS_TRITON:
+            rejection_random_sample_block_verify_kernel[(grid, )](
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                draft_probs,
+                target_probs,
+                bonus_token_ids,
+                recovered_token_ids,
+                uniform_probs.to(torch.float32),
+                is_greedy,
+                max_spec_len,
+                vocab_size,
+                batch_size,
+                NO_DRAFT_PROBS=draft_probs is None,
+                BLOCK_SIZE=block_size,
+            )
+        else:
+            rejection_random_sample_block_verify_pytorch(output_token_ids,
+                                                         cu_num_draft_tokens,
+                                                         draft_token_ids,
+                                                         draft_probs,
+                                                         target_probs,
+                                                         bonus_token_ids,
+                                                         recovered_token_ids,
+                                                         uniform_probs,
+                                                         is_greedy,
+                                                         max_spec_len,
+                                                         vocab_size,
+                                                         IS_NGRAM=draft_probs
+                                                         is None)
 
     return output_token_ids
@@ -680,3 +717,86 @@ def sample_recovered_tokens_pytorch(
     recovered_ids = torch.argmax(prob_over_q, dim=1)
 
     output_token_ids[:] = recovered_ids
+
+
+def rejection_random_sample_block_verify_pytorch(
+        output_token_ids,  # [batch_size, max_spec_len + 1]
+        cu_num_draft_tokens,  # [batch_size]
+        draft_token_ids,  # [num_tokens]
+        draft_probs,  # [num_tokens, vocab_size] or None
+        target_probs,  # [num_tokens, vocab_size]
+        bonus_token_ids,  # [batch_size]
+        recovered_token_ids,  # [num_tokens]
+        uniform_probs,  # [num_tokens]
+        is_greedy,  # [batch_size]
+        max_spec_len,
+        vocab_size,
+        IS_NGRAM=False,
+):
+    batch_size = output_token_ids.shape[0]
+    device = output_token_ids.device
+
+    zero_cpu = torch.tensor([0], pin_memory=True)
+    zero_device = zero_cpu.to(device, non_blocking=True)
+
+    cu_start = torch.cat([zero_device, cu_num_draft_tokens[:-1]])
+    cu_end = cu_num_draft_tokens
+    num_draft_per_batch = (cu_end - cu_start)[:, None]
+    pos_indices_cpu = torch.arange(max_spec_len, pin_memory=True)
+    pos_indices = pos_indices_cpu.to(device, non_blocking=True)[None, :]
+    valid_mask = pos_indices < num_draft_per_batch
+    global_token_indices = cu_start[:, None] + pos_indices
+    global_token_indices = global_token_indices.clamp(
+        0, draft_token_ids.shape[0] - 1)
+    draft_tokens = draft_token_ids[global_token_indices]
+
+    if IS_NGRAM:
+        ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
+        draft_token_probs = ones_cpu.to(
+            device, non_blocking=True).expand_as(draft_tokens)
+    else:
+        flat_indices = global_token_indices.flatten()
+        flat_draft_tokens = draft_tokens.flatten()
+        flat_draft_probs = draft_probs[flat_indices, flat_draft_tokens]
+        draft_token_probs = flat_draft_probs.view(batch_size, max_spec_len)
+
+    flat_indices = global_token_indices.flatten()
+    flat_draft_tokens = draft_tokens.flatten()
+    flat_target_probs = target_probs[flat_indices, flat_draft_tokens]
+    target_token_probs = flat_target_probs.view(batch_size, max_spec_len)
+    uniform_token_probs = uniform_probs[global_token_indices]
+    recovered_tokens = recovered_token_ids[global_token_indices]
+
+    pi = target_token_probs / draft_token_probs
+    pi = pi.clamp(max=1.0)
+    pi = torch.cumprod(pi, dim=-1)
+    uniform_token_probs = torch.cumprod(uniform_token_probs, dim=-1)
+    legal_mask = (draft_token_probs > 0) & (pi >= uniform_token_probs)
+    legal_mask = legal_mask & valid_mask
+
+    last_accept_pos = torch.where(
+        legal_mask.any(dim=-1, keepdim=True),
+        (max_spec_len -
+         legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1),
+        -1)
+    non_greedy_mask = (~is_greedy)[:, None]
+
+    accept_mask = (pos_indices
+                   <= last_accept_pos) & valid_mask & non_greedy_mask
+    output_token_ids[:, :max_spec_len] = torch.where(
+        accept_mask, draft_tokens, output_token_ids[:, :max_spec_len])
+
+    reject_mask = (pos_indices
+                   == last_accept_pos + 1) & valid_mask & non_greedy_mask
+    output_token_ids[:, :max_spec_len] = torch.where(
+        reject_mask, recovered_tokens, output_token_ids[:, :max_spec_len])
+
+    bonus_mask = (last_accept_pos + 1 >= num_draft_per_batch) & non_greedy_mask
+    all_positions_cpu = torch.arange(max_spec_len + 1, pin_memory=True)
+    all_positions = all_positions_cpu.to(device, non_blocking=True)[None, :]
+    bonus_pos_match = (all_positions == num_draft_per_batch)
+    bonus_mask = bonus_mask & bonus_pos_match
+    bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(
+        -1, max_spec_len + 1)
+    output_token_ids[:] = torch.where(bonus_mask, bonus_values_expanded,
+                                      output_token_ids)
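Note: the following is a minimal, CPU-only sketch (not part of the diff) of the block-verify acceptance rule that both rejection_random_sample_block_verify_kernel and rejection_random_sample_block_verify_pytorch above encode. The helper name block_verify_accept is illustrative only and assumes positive draft probabilities (1.0 per position when draft_probs is None); the point is the cumulative-product comparison.

# Sketch: per-position target probabilities of the draft tokens, per-position
# draft probabilities, and per-position uniform draws for one request.
def block_verify_accept(target_p, draft_p, uniform_u):
    """Return (last accepted position or -1, whether the next position
    should take the recovered token instead of the bonus token)."""
    pi = 1.0  # running product of min(target / draft, 1.0)
    u = 1.0  # running product of the uniform draws
    last_accepted = -1
    rejected = False
    for pos, (t, d, r) in enumerate(zip(target_p, draft_p, uniform_u)):
        pi = min(pi * t / d, 1.0)
        u = u * r
        if d > 0 and pi >= u:
            # The whole block up to and including `pos` is accepted.
            last_accepted = pos
            rejected = False
        else:
            rejected = True
    return last_accepted, rejected


# Mirrors request 0 of the unit test above (ngram path, so draft probs are
# 1.0): position 0 fails its check, but position 1 rescues the block, so both
# draft tokens are kept and the bonus token follows.
print(block_verify_accept([0.4, 0.9], [1.0, 1.0], [0.9, 0.0]))  # (1, False)

Because acceptance is judged on the cumulative block rather than token by token, a later position can recover an earlier miss; that is the sense in which the "MagicMTP" comment above says block verify improves the acceptance rate relative to the existing rejection_random_sample_kernel.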