-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[model_runner_v2]optimize the performance of the _topk_log_softmax_kernel #7221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
45ac951
af7fb33
e9ee533
36b85d8
e852a16
e0ba987
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| import torch | ||
| import pytest | ||
| from vllm.triton_utils import triton | ||
| from vllm_ascend.worker.v2.sample.logprob import _topk_log_softmax_kernel | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("batch_size,vocab_size,num_logprobs", [ | ||
| (48, 102400, 50), | ||
| (96, 102400, 1), | ||
| (24, 151936, 8), | ||
| ]) | ||
| def test_topk_log_softmax_kernel(batch_size, vocab_size, num_logprobs): | ||
| """Test _topk_log_softmax_kernel for computing log probabilities | ||
| Args: | ||
| batch_size: Number of sequences in the batch | ||
| vocab_size: Size of the vocabulary | ||
| num_logprobs: Number of tokens to compute log probabilities for | ||
| """ | ||
| # ========== Setup test data ========== | ||
| torch.manual_seed(42) | ||
|
|
||
| # Generate random logits | ||
| logits = torch.randn(batch_size, vocab_size, device='npu', dtype=torch.float32) | ||
|
|
||
| # Generate token_ids for which to compute logprobs | ||
| token_ids = torch.randint(0, vocab_size, (batch_size, num_logprobs), | ||
| device='npu', dtype=torch.int64) | ||
|
|
||
| # ========== Execute test ========== | ||
| # Prepare output tensor | ||
| triton_output = torch.empty( | ||
| batch_size, num_logprobs, | ||
| dtype=torch.float32, | ||
| device='npu' | ||
| ) | ||
|
|
||
| # Invoke Triton kernel | ||
| _topk_log_softmax_kernel[(batch_size,)]( | ||
| triton_output, | ||
| logits, | ||
| logits.stride(0), | ||
| token_ids, | ||
| num_logprobs, | ||
| vocab_size, | ||
| BLOCK_SIZE=1024, | ||
| PADDED_TOPK=max(triton.next_power_of_2(num_logprobs), 2), | ||
| ) | ||
| torch.npu.synchronize() | ||
|
|
||
| # Compute reference values using PyTorch | ||
| torch_logprobs = torch.log_softmax(logits, dim=-1) | ||
|
|
||
| # Extract logprobs for each batch and token_id | ||
| ref_output = torch.zeros_like(triton_output) | ||
| for i in range(batch_size): | ||
| for j in range(num_logprobs): | ||
| token_id = token_ids[i, j] | ||
| ref_output[i, j] = torch_logprobs[i, token_id] | ||
|
|
||
| # ========== Verify results ========== | ||
| assert torch.allclose(triton_output, ref_output, rtol=1e-3, atol=1e-3), \ | ||
| f"Triton output differs from PyTorch reference.\n" \ | ||
| f"Max diff: {torch.max(torch.abs(triton_output - ref_output))}\n" \ | ||
| f"Mean diff: {torch.mean(torch.abs(triton_output - ref_output))}" | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,85 @@ | ||||||||||||||||||||||||||||||||||
| # Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/logprob.py. | ||||||||||||||||||||||||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||
| # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. | ||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||||||||||
| # This file is a part of the vllm-ascend project. | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||
| from vllm.triton_utils import tl, triton | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @triton.jit | ||||||||||||||||||||||||||||||||||
| def _topk_log_softmax_kernel( | ||||||||||||||||||||||||||||||||||
| output_ptr, | ||||||||||||||||||||||||||||||||||
| logits_ptr, | ||||||||||||||||||||||||||||||||||
| logits_stride, | ||||||||||||||||||||||||||||||||||
| topk_ids_ptr, | ||||||||||||||||||||||||||||||||||
| topk, | ||||||||||||||||||||||||||||||||||
| vocab_size, | ||||||||||||||||||||||||||||||||||
| BLOCK_SIZE: tl.constexpr, | ||||||||||||||||||||||||||||||||||
| PADDED_TOPK: tl.constexpr, | ||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||
| req_idx = tl.program_id(0) | ||||||||||||||||||||||||||||||||||
| row_ptr = logits_ptr + req_idx * logits_stride | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| max_val = float("-inf") | ||||||||||||||||||||||||||||||||||
| for i in range(0, vocab_size, BLOCK_SIZE): | ||||||||||||||||||||||||||||||||||
| block = i + tl.arange(0, BLOCK_SIZE) | ||||||||||||||||||||||||||||||||||
| logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf")) | ||||||||||||||||||||||||||||||||||
| max_val = tl.max(tl.maximum(logits, max_val)) | ||||||||||||||||||||||||||||||||||
| max_val = max_val.to(tl.float32) # type: ignore | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| se = 0.0 | ||||||||||||||||||||||||||||||||||
| for i in range(0, vocab_size, BLOCK_SIZE): | ||||||||||||||||||||||||||||||||||
| block = i + tl.arange(0, BLOCK_SIZE) | ||||||||||||||||||||||||||||||||||
| logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0) | ||||||||||||||||||||||||||||||||||
| # NOTE(woosuk): Make sure that logits and all following operations use FP32. | ||||||||||||||||||||||||||||||||||
| logits = logits.to(tl.float32) | ||||||||||||||||||||||||||||||||||
| # NOTE(wangx700): tl.where does not support int64 so we cast it to float32. | ||||||||||||||||||||||||||||||||||
| block = block.to(tl.float32) | ||||||||||||||||||||||||||||||||||
| e = tl.exp(logits - max_val) | ||||||||||||||||||||||||||||||||||
| e = tl.where(block < vocab_size, e, 0.0) | ||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add some comments to explain the difference between vllm and vllm-ascend.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok,I added some notes to explain the difference between vllm and vllm-ascend. |
||||||||||||||||||||||||||||||||||
| se += tl.sum(e) | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+46
to
+55
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The loop for calculating the sum of exponentials ( A better approach is to load with
Suggested change
|
||||||||||||||||||||||||||||||||||
| lse = tl.log(se) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| k_offset = tl.arange(0, PADDED_TOPK) | ||||||||||||||||||||||||||||||||||
| k_mask = k_offset < topk | ||||||||||||||||||||||||||||||||||
| topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| logits = tl.load(row_ptr + topk_ids, mask=k_mask) | ||||||||||||||||||||||||||||||||||
| logits = logits.to(tl.float32) | ||||||||||||||||||||||||||||||||||
| o = logits - max_val - lse | ||||||||||||||||||||||||||||||||||
| tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def compute_token_logprobs(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||
| batch_size, vocab_size = logits.shape | ||||||||||||||||||||||||||||||||||
| token_ids = token_ids.to(torch.int64) | ||||||||||||||||||||||||||||||||||
| num_logprobs = token_ids.shape[1] | ||||||||||||||||||||||||||||||||||
| logprobs = logits.new_empty((batch_size, num_logprobs), dtype=torch.float32) | ||||||||||||||||||||||||||||||||||
| _topk_log_softmax_kernel[(batch_size,)]( | ||||||||||||||||||||||||||||||||||
| logprobs, | ||||||||||||||||||||||||||||||||||
| logits, | ||||||||||||||||||||||||||||||||||
| logits.stride(0), | ||||||||||||||||||||||||||||||||||
| token_ids, | ||||||||||||||||||||||||||||||||||
| num_logprobs, | ||||||||||||||||||||||||||||||||||
| vocab_size, | ||||||||||||||||||||||||||||||||||
| BLOCK_SIZE=1024, # type: ignore | ||||||||||||||||||||||||||||||||||
| # NOTE(wangx700): PADDED_TOPK must be at least 2 to avoid | ||||||||||||||||||||||||||||||||||
| # num_logprobs=1 getting wrong results. | ||||||||||||||||||||||||||||||||||
| PADDED_TOPK=max(triton.next_power_of_2(num_logprobs), 2), | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| return logprobs | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reference output calculation uses nested Python loops, which is inefficient and less readable than a vectorized approach. You can achieve the same result more efficiently and concisely by using
torch.gather.