diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_penality.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_penality.py index f6d2385d721..1d108bf4cda 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_penality.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_penality.py @@ -1,63 +1,47 @@ import pytest import torch -from vllm_ascend.worker.v2.sample.penalties import apply_penalties_and_temperature +from vllm_ascend.worker.v2.sample.penalties import apply_penalties DTYPES = [torch.bfloat16, torch.float16] -NUM_REQS = [2, 4, 8] +NUM_TOKENS = [2, 4, 8] VOCAB_SIZE = [151936] NUM_STATUS = [1, 4, 8, 16] SEEDS = [0] DEVICES = [f"npu:{0}"] +NUM_SPECULATIVE_TOKENS = [0, 1, 3] DEFAULT_ATOL = 1e-3 DEFAULT_RTOL = 1e-3 -class SamplingMetadata: - def __init__(self, - repetition_penalty: torch.Tensor, - frequency_penalty: torch.Tensor, - presence_penalty: torch.Tensor, - temperature: torch.Tensor, - idx_mapping: torch.Tensor, - prompt_bin_mask: torch.Tensor, - output_bin_counts: torch.Tensor): - self.repetition_penalty = repetition_penalty - self.frequency_penalty = frequency_penalty - self.presence_penalty = presence_penalty - self.temperature = temperature - self.idx_mapping = idx_mapping - self.prompt_bin_mask = prompt_bin_mask - self.output_bin_counts = output_bin_counts - - -def pytorch_apply_penalties_and_temperature( +def pytorch_apply_penalties( logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + idx_mapping: torch.Tensor, + token_ids: torch.Tensor, + expanded_local_pos: torch.Tensor, + repetition_penalty: torch.Tensor, + frequency_penalty: torch.Tensor, + presence_penalty: torch.Tensor, + prompt_bin_mask: torch.Tensor, + output_bin_counts: torch.Tensor, + num_speculative_tokens: int, ) -> torch.Tensor: """ Pytorch equivalent implementation """ - num_reqs, vocab_size = logits.shape + num_tokens, vocab_size = logits.shape device = logits.device dtype = logits.dtype logits_float 
= logits.float() - repetition_penalty = sampling_metadata.repetition_penalty - frequency_penalty = sampling_metadata.frequency_penalty - presence_penalty = sampling_metadata.presence_penalty - temperature = sampling_metadata.temperature - idx_mapping = sampling_metadata.idx_mapping - prompt_bin_mask = sampling_metadata.prompt_bin_mask - output_bin_counts = sampling_metadata.output_bin_counts - - temperature = torch.where(temperature == 0.0, torch.ones_like(temperature), temperature) - num_status = prompt_bin_mask.shape[0] num_packed = prompt_bin_mask.shape[1] - prompt_masks_unpacked = torch.zeros(num_status, vocab_size, dtype=torch.bool, device=device) + prompt_masks_unpacked = torch.zeros( + num_status, vocab_size, dtype=torch.bool, + device=device + ) for state_idx in range(num_status): for packed_idx in range(num_packed): @@ -69,82 +53,99 @@ def pytorch_apply_penalties_and_temperature( if (packed_val >> bit_pos) & 1: prompt_masks_unpacked[state_idx, start_idx + bit_pos] = True - for batch_idx in range(num_reqs): - req_state_idx = idx_mapping[batch_idx].item() - - rep_penalty = repetition_penalty[batch_idx].item() - freq_penalty = frequency_penalty[batch_idx].item() - pres_penalty = presence_penalty[batch_idx].item() - temp = temperature[batch_idx].item() + for token_idx in range(num_tokens): + req_state_idx = idx_mapping[token_idx].item() + + rep_penalty = repetition_penalty[req_state_idx].item() + freq_penalty = frequency_penalty[req_state_idx].item() + pres_penalty = presence_penalty[req_state_idx].item() use_rep_penalty = rep_penalty != 1.0 use_freq_penalty = freq_penalty != 0.0 use_pres_penalty = pres_penalty != 0.0 - use_penalty = (use_rep_penalty or use_freq_penalty) or use_pres_penalty - use_temperature = temp != 1.0 + use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty - if not (use_penalty or use_temperature): + if not use_penalty: continue current_prompt_mask = prompt_masks_unpacked[req_state_idx] - current_output_counts = 
output_bin_counts[req_state_idx] - output_bin_mask = current_output_counts > 0 + base_output_counts = output_bin_counts[req_state_idx] + + # Compute cumulative draft counts + pos = expanded_local_pos[token_idx].item() + start_idx_in_batch = token_idx - pos + draft_counts = torch.zeros(vocab_size, device=device, dtype=torch.int32) + + for prev_pos in range(num_speculative_tokens): + if prev_pos < pos: + prev_token = token_ids[start_idx_in_batch + prev_pos + 1].item() + draft_counts[prev_token] += 1 + + # Total counts = base output counts + cumulative draft counts + total_output_counts = base_output_counts + draft_counts + output_bin_mask = total_output_counts > 0 if use_rep_penalty: scale = torch.ones(vocab_size, device=device) mask = current_prompt_mask | output_bin_mask scale[mask] = rep_penalty - pos_mask = logits_float[batch_idx] > 0 + pos_mask = logits_float[token_idx] > 0 scale_factor = torch.where(pos_mask, 1.0 / scale, scale) - logits_float[batch_idx] *= scale_factor + logits_float[token_idx] *= scale_factor if use_freq_penalty: - logits_float[batch_idx] -= freq_penalty * current_output_counts.float() + logits_float[token_idx] -= freq_penalty * total_output_counts.float() if use_pres_penalty: - logits_float[batch_idx] -= pres_penalty * output_bin_mask.float() - - if use_temperature: - logits_float[batch_idx] /= temp + logits_float[token_idx] -= pres_penalty * output_bin_mask.float() return logits_float.to(dtype) def create_test_data( - num_reqs: int = 8, + num_tokens: int = 8, vocab_size: int = 51200, num_status: int = 16, + num_speculative_tokens: int = 3, device: str = "npu", dtype: torch.dtype = torch.bfloat16, seed: int = 42, ): - """Create test data for penalties and temperature""" + """Create test data for penalties""" torch.manual_seed(seed) - logits = torch.randn(num_reqs, vocab_size, device=device, dtype=dtype) + logits = torch.randn(num_tokens, vocab_size, device=device, dtype=dtype) - repetition_penalty = torch.ones(num_reqs, device=device, 
dtype=torch.float32) - for i in range(num_reqs): + repetition_penalty = torch.ones(num_status, device=device, dtype=torch.float32) + for i in range(num_status): if torch.rand(1) > 0.3: repetition_penalty[i] = torch.rand(1, device=device).item() * 0.8 + 0.6 - frequency_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32) - for i in range(num_reqs): + frequency_penalty = torch.zeros(num_status, device=device, dtype=torch.float32) + for i in range(num_status): if torch.rand(1) > 0.5: frequency_penalty[i] = torch.rand(1, device=device).item() * 0.2 - presence_penalty = torch.zeros(num_reqs, device=device, dtype=torch.float32) - for i in range(num_reqs): + presence_penalty = torch.zeros(num_status, device=device, dtype=torch.float32) + for i in range(num_status): if torch.rand(1) > 0.5: presence_penalty[i] = torch.rand(1, device=device).item() * 0.2 - temperature = torch.ones(num_reqs, device=device, dtype=torch.float32) - for i in range(num_reqs): - if torch.rand(1) > 0.2: - presence_penalty[i] = torch.rand(1, device=device).item() * 1.8 + 0.2 - - idx_mapping = torch.randint(0, num_status, (num_reqs,), device=device, dtype=torch.int32) + idx_mapping = torch.randint( + 0, num_status, (num_tokens,), device=device, + dtype=torch.int32 + ) + + # Create token_ids for speculative decoding + token_ids = torch.randint(0, vocab_size, (num_tokens,), device=device, dtype=torch.int32) + + # Create expanded_local_pos (position within speculative decoding window) + expanded_local_pos = torch.zeros(num_tokens, device=device, dtype=torch.int32) + for i in range(num_tokens): + expanded_local_pos[i] = torch.randint( + 0, num_speculative_tokens + 1, (1,) + ).item() num_packed = (vocab_size + 31) // 32 prompt_bin_mask = torch.zeros(num_status, num_packed, device=device, dtype=torch.int32) @@ -161,44 +162,60 @@ def create_test_data( output_bin_counts = torch.zeros(num_status, vocab_size, device=device, dtype=torch.int32) for state_idx in range(num_status): num_output_tokens 
= max(1, vocab_size // 20) - output_tokens = torch.randint(0, vocab_size, (num_output_tokens, )) + output_tokens = torch.randint(0, vocab_size, + (num_output_tokens, )) counts = torch.randint(1, 10, (num_output_tokens,)) for token, count in zip(output_tokens, counts): output_bin_counts[state_idx, token] = count - sampling_metadata = SamplingMetadata( - repetition_penalty=repetition_penalty, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - temperature=temperature, - idx_mapping=idx_mapping, - prompt_bin_mask=prompt_bin_mask, - output_bin_counts=output_bin_counts + return ( + logits, + idx_mapping, + token_ids, + expanded_local_pos, + repetition_penalty, + frequency_penalty, + presence_penalty, + prompt_bin_mask, + output_bin_counts, + num_speculative_tokens, ) - return logits, sampling_metadata - -@pytest.mark.parametrize("num_reqs", NUM_REQS) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("vocab_size", VOCAB_SIZE) @pytest.mark.parametrize("num_status", NUM_STATUS) +@pytest.mark.parametrize("num_speculative_tokens", NUM_SPECULATIVE_TOKENS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() -def test_apply_penalties_and_temperature( - num_reqs, +def test_apply_penalties( + num_tokens, vocab_size, num_status, + num_speculative_tokens, dtype, seed, device ): - logits_triton, sampling_metadata = create_test_data( - num_reqs=num_reqs, + ( + logits_triton, + idx_mapping, + token_ids, + expanded_local_pos, + repetition_penalty, + frequency_penalty, + presence_penalty, + prompt_bin_mask, + output_bin_counts, + num_spec_tokens, + ) = create_test_data( + num_tokens=num_tokens, vocab_size=vocab_size, num_status=num_status, + num_speculative_tokens=num_speculative_tokens, device=device, dtype=dtype, seed=seed @@ -206,14 +223,35 @@ def test_apply_penalties_and_temperature( logits_pytorch = logits_triton.clone() - 
apply_penalties_and_temperature(logits_triton, sampling_metadata) + apply_penalties( + logits_triton, + idx_mapping, + token_ids, + expanded_local_pos, + repetition_penalty, + frequency_penalty, + presence_penalty, + prompt_bin_mask, + output_bin_counts, + num_spec_tokens, + ) - logits_pytorch_result = pytorch_apply_penalties_and_temperature(logits_pytorch, - sampling_metadata) + logits_pytorch_result = pytorch_apply_penalties( + logits_pytorch, + idx_mapping, + token_ids, + expanded_local_pos, + repetition_penalty, + frequency_penalty, + presence_penalty, + prompt_bin_mask, + output_bin_counts, + num_spec_tokens, + ) atol = DEFAULT_ATOL rtol = DEFAULT_RTOL if dtype == torch.bfloat16: atol = 1e-02 rtol = 1e-02 - assert torch.allclose(logits_triton, logits_pytorch_result, atol=atol, rtol=rtol) + assert torch.allclose(logits_triton, logits_pytorch_result, atol=atol, rtol=rtol) \ No newline at end of file diff --git a/vllm_ascend/worker/v2/sample/penalties.py b/vllm_ascend/worker/v2/sample/penalties.py index f5100df25d1..8ff307a0177 100644 --- a/vllm_ascend/worker/v2/sample/penalties.py +++ b/vllm_ascend/worker/v2/sample/penalties.py @@ -20,122 +20,141 @@ import torch from vllm.triton_utils import tl, triton -from vllm.v1.sample.metadata import SamplingMetadata @triton.jit -def _penalties_and_temperature_kernel( +def _penalties_kernel( logits_ptr, logits_stride, - repetition_penalty_ptr, - frequency_penalty_ptr, - presence_penalty_ptr, - temperature_ptr, idx_mapping_ptr, + token_ids_ptr, + expanded_local_pos_ptr, + penalties_ptr, + penalties_stride, prompt_bin_mask_ptr, prompt_bin_mask_stride, output_bin_counts_ptr, output_bin_counts_stride, vocab_size, BLOCK_SIZE: tl.constexpr, + INNER_BLOCK_SIZE: tl.constexpr, + MAX_SPEC_LEN: tl.constexpr, ): - batch_idx = tl.program_id(0) - rep_penalty = tl.load(repetition_penalty_ptr + batch_idx) - freq_penalty = tl.load(frequency_penalty_ptr + batch_idx) - pres_penalty = tl.load(presence_penalty_ptr + batch_idx) - temperature = 
tl.load(temperature_ptr + batch_idx) - temperature = tl.where(temperature == 0.0, 1.0, temperature) + token_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + token_idx) + + # first load penalties once + rep_penalty = tl.load(penalties_ptr + req_state_idx * penalties_stride + 0) + freq_penalty = tl.load(penalties_ptr + req_state_idx * penalties_stride + 1) + pres_penalty = tl.load(penalties_ptr + req_state_idx * penalties_stride + 2) use_rep_penalty = rep_penalty != 1.0 use_freq_penalty = freq_penalty != 0.0 use_pres_penalty = pres_penalty != 0.0 - # NOTE(Ronald1995): vllm original grammar `use_rep_penalty or - # use_freq_penalty or use_pres_penalty`, - # change it to `(use_rep_penalty or use_freq_penalty) or use_pres_penalty`, - # because triton-ascend's compiler doesn't support chained boolean operator. - use_penalty = (use_rep_penalty or use_freq_penalty) or use_pres_penalty - use_temperature = temperature != 1.0 - if not (use_penalty or use_temperature): + + # NPU doesn't support chained 'or' operations like 'A or B or C' + use_penalty = use_rep_penalty or use_freq_penalty + use_penalty = use_penalty or use_pres_penalty + + if not use_penalty: # Early return to avoid loading logits. 
return + bit_masks = tl.full((INNER_BLOCK_SIZE // 32, 32), 1, dtype=tl.int32) << tl.arange(0, 32) block_idx = tl.program_id(1) - block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = block < vocab_size - logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask) - logits = logits.to(tl.float32) - - if use_penalty: - req_state_idx = tl.load(idx_mapping_ptr + batch_idx) - output_bin_counts = tl.load( - output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block, - mask=mask, + block_start = block_idx * BLOCK_SIZE + + pos = tl.load(expanded_local_pos_ptr + token_idx) + start_idx = token_idx - pos + + inv_rep = 1.0 / rep_penalty + + for inner_offset in tl.static_range(0, BLOCK_SIZE, INNER_BLOCK_SIZE): + inner_block_start = block_start + inner_offset + inner_block = inner_block_start + tl.arange(0, INNER_BLOCK_SIZE) + inner_mask = inner_block < vocab_size + + logits = tl.load(logits_ptr + token_idx * logits_stride + inner_block, mask=inner_mask, other=0.0) + logits = logits.to(tl.float32) + + base_output_counts = tl.load( + output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + inner_block, + mask=inner_mask, + other=0, ) - # to use vector core, if use > 0 will use scalar to slow down performance - output_bin_mask = output_bin_counts != 0 + + # Compute cumulative draft_counts from previous positions in this request + total_counts = base_output_counts.to(tl.int32) + for prev_pos in tl.static_range(MAX_SPEC_LEN): + if prev_pos < pos: + load_idx = start_idx + prev_pos + 1 + prev_token = tl.load(token_ids_ptr + load_idx) + total_counts += inner_block == prev_token + + is_present = total_counts != 0 # Apply repetition penalties. 
if use_rep_penalty: - packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32) - packed_mask = tl.load( + packed_inner_block_start = inner_block_start // 32 + packed_block = packed_inner_block_start + tl.arange(0, INNER_BLOCK_SIZE // 32) + valid_packed_mask = packed_block < tl.cdiv(vocab_size, 32) + + packed_mask_val = tl.load( prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block, - mask=packed_block < tl.cdiv(vocab_size, 32), + mask=valid_packed_mask, + other=0, ) - # the compiler itself does not optimize right-shift operations, so we change the same func - bit_masks = 1 << tl.arange(0, 32) - bit_masks_expanded = bit_masks[None, :] - - packed_expanded = packed_mask[:, None] - bits_matrix = (packed_expanded & bit_masks_expanded) != 0 + prompt_mask = ((packed_mask_val[:, None] & bit_masks) != 0).reshape(INNER_BLOCK_SIZE) - prompt_bin_mask = bits_matrix.reshape(BLOCK_SIZE) + needs_scaling = prompt_mask | is_present - prompt_bin_mask = prompt_bin_mask.to(tl.int1) - prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE) + base_factor = tl.where(logits > 0, inv_rep, rep_penalty) + logits = tl.where(needs_scaling, logits * base_factor, logits) - # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. - scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0) - # If logits are positive, divide by penalty, otherwise multiply by penalty. - logits *= tl.where(logits > 0, 1.0 / scale, scale) + freq_term = freq_penalty * total_counts.to(tl.float32) + pres_term = pres_penalty * is_present.to(tl.float32) - # Apply frequency penalties. - logits -= freq_penalty * output_bin_counts - # Apply presence penalties. - logits -= pres_penalty * output_bin_mask + logits = logits - freq_term - pres_term + # Store back to logits. + tl.store(logits_ptr + token_idx * logits_stride + inner_block, logits, mask=inner_mask) - # Apply temperature. - logits = logits / temperature - # Store back to logits. 
def apply_penalties(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    token_ids: torch.Tensor,
    expanded_local_pos: torch.Tensor,
    repetition_penalty: torch.Tensor,
    frequency_penalty: torch.Tensor,
    presence_penalty: torch.Tensor,
    prompt_bin_mask: torch.Tensor,
    output_bin_counts: torch.Tensor,
    num_speculative_tokens: int,
) -> None:
    """Apply repetition, frequency and presence penalties to ``logits`` in place.

    Launches ``_penalties_kernel`` on a ``(num_tokens, num_blocks)`` grid; the
    kernel writes the penalized values back into ``logits``, so nothing is
    returned.

    Args:
        logits: 2-D tensor unpacked as ``(num_tokens, vocab_size)``; modified
            in place.
        idx_mapping: Per-token request-state index. The kernel uses it to
            select the row of the penalty table, ``prompt_bin_mask`` and
            ``output_bin_counts`` for each logits row.
        token_ids: Flat token ids. For a row at local position ``pos`` the
            kernel reads entries ``token_idx - pos + 1 .. token_idx`` to count
            the draft tokens that precede it.
        expanded_local_pos: Per-token position inside its request's window.
            NOTE(review): ``token_idx - pos`` must not be negative, otherwise
            the kernel loads outside ``token_ids`` -- confirm callers
            guarantee this.
        repetition_penalty: 1-D penalty values indexed by request-state index
            (1.0 disables the penalty in the kernel).
        frequency_penalty: 1-D penalty values indexed by request-state index
            (0.0 disables it).
        presence_penalty: 1-D penalty values indexed by request-state index
            (0.0 disables it).
        prompt_bin_mask: Bit-packed prompt-token mask, one row per request
            state and 32 vocabulary entries per element.
        output_bin_counts: Per-request-state output token counts, one column
            per vocabulary entry.
        num_speculative_tokens: Forwarded as the kernel's ``MAX_SPEC_LEN``
            compile-time constant; it bounds how many preceding draft tokens
            are inspected per row.
    """
    num_tokens, vocab_size = logits.shape
    # Each program covers BLOCK_SIZE vocabulary entries and walks them in
    # INNER_BLOCK_SIZE tiles inside the kernel.
    BLOCK_SIZE = 8192
    INNER_BLOCK_SIZE = 4096
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)

    # Pack the three penalties into one (num_states, 3) table so the kernel
    # takes a single pointer and stride. The kernel addresses this table with
    # the request-state index loaded from idx_mapping, NOT with the token row,
    # so every state row has to be present: truncating the inputs to
    # [:num_tokens] would make the kernel read past the end of the table
    # whenever a state index is >= num_tokens.
    penalties = torch.stack(
        [repetition_penalty, frequency_penalty, presence_penalty], dim=1
    ).contiguous()
    penalties_stride = penalties.stride(0)

    _penalties_kernel[(num_tokens, num_blocks)](
        logits,
        logits.stride(0),
        idx_mapping,
        token_ids,
        expanded_local_pos,
        penalties,
        penalties_stride,
        prompt_bin_mask,
        prompt_bin_mask.stride(0),
        output_bin_counts,
        output_bin_counts.stride(0),
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
        INNER_BLOCK_SIZE=INNER_BLOCK_SIZE,
        MAX_SPEC_LEN=num_speculative_tokens,
    )