@@ -0,0 +1,64 @@
import torch
import pytest
from vllm.triton_utils import triton
from vllm_ascend.worker.v2.sample.logprob import _topk_log_softmax_kernel


@pytest.mark.parametrize("batch_size,vocab_size,num_logprobs", [
    (48, 102400, 50),
    (96, 102400, 1),
    (24, 151936, 8),
])
def test_topk_log_softmax_kernel(batch_size, vocab_size, num_logprobs):
    """Test _topk_log_softmax_kernel for computing log probabilities.

    Args:
        batch_size: Number of sequences in the batch
        vocab_size: Size of the vocabulary
        num_logprobs: Number of tokens to compute log probabilities for
    """
    # ========== Setup test data ==========
    torch.manual_seed(42)

    # Generate random logits
    logits = torch.randn(batch_size, vocab_size, device='npu', dtype=torch.float32)

    # Generate token_ids for which to compute logprobs
    token_ids = torch.randint(0, vocab_size, (batch_size, num_logprobs),
                              device='npu', dtype=torch.int64)

    # ========== Execute test ==========
    # Prepare output tensor
    triton_output = torch.empty(
        batch_size, num_logprobs,
        dtype=torch.float32,
        device='npu'
    )

    # Invoke Triton kernel
    _topk_log_softmax_kernel[(batch_size,)](
        triton_output,
        logits,
        logits.stride(0),
        token_ids,
        num_logprobs,
        vocab_size,
        BLOCK_SIZE=1024,
        PADDED_TOPK=max(triton.next_power_of_2(num_logprobs), 2),
    )
    torch.npu.synchronize()

    # Compute reference values using PyTorch
    torch_logprobs = torch.log_softmax(logits, dim=-1)

    # Extract logprobs for each batch and token_id
    ref_output = torch.zeros_like(triton_output)
    for i in range(batch_size):
        for j in range(num_logprobs):
            token_id = token_ids[i, j]
            ref_output[i, j] = torch_logprobs[i, token_id]
Comment on lines +54 to +58 (Contributor, severity: high):

The reference output calculation uses nested Python loops, which is inefficient and less readable than a vectorized approach. You can achieve the same result more efficiently and concisely by using torch.gather.

Suggested change:
-    ref_output = torch.zeros_like(triton_output)
-    for i in range(batch_size):
-        for j in range(num_logprobs):
-            token_id = token_ids[i, j]
-            ref_output[i, j] = torch_logprobs[i, token_id]
+    ref_output = torch.gather(torch_logprobs, 1, token_ids)
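
The equivalence is easy to verify standalone; a minimal CPU sketch (illustrative only, not part of this diff):

import torch

torch.manual_seed(0)
torch_logprobs = torch.log_softmax(torch.randn(4, 16), dim=-1)
token_ids = torch.randint(0, 16, (4, 3))
# Loop form, as in the test above
loop_out = torch.empty(4, 3)
for i in range(4):
    for j in range(3):
        loop_out[i, j] = torch_logprobs[i, token_ids[i, j]]
# Vectorized form: gather along the vocab dimension (dim=1)
gather_out = torch.gather(torch_logprobs, 1, token_ids)
assert torch.equal(loop_out, gather_out)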


    # ========== Verify results ==========
    assert torch.allclose(triton_output, ref_output, rtol=1e-3, atol=1e-3), \
        f"Triton output differs from PyTorch reference.\n" \
        f"Max diff: {torch.max(torch.abs(triton_output - ref_output))}\n" \
        f"Mean diff: {torch.mean(torch.abs(triton_output - ref_output))}"
vllm_ascend/worker/v2/sample/logprob.py (85 additions, 0 deletions)
@@ -0,0 +1,85 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/logprob.py.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import torch
from vllm.triton_utils import tl, triton


@triton.jit
def _topk_log_softmax_kernel(
    output_ptr,
    logits_ptr,
    logits_stride,
    topk_ids_ptr,
    topk,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    PADDED_TOPK: tl.constexpr,
):
    req_idx = tl.program_id(0)
    row_ptr = logits_ptr + req_idx * logits_stride

    # Pass 1: running maximum over the row, in BLOCK_SIZE chunks, for a
    # numerically stable softmax.
    max_val = float("-inf")
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
        max_val = tl.max(tl.maximum(logits, max_val))
    max_val = max_val.to(tl.float32)  # type: ignore

    # Pass 2: sum of exponentials, shifted by the row maximum.
    se = 0.0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
        # NOTE(woosuk): Make sure that logits and all following operations use FP32.
        logits = logits.to(tl.float32)
        # NOTE(wangx700): tl.where does not support int64 so we cast it to float32.
        block = block.to(tl.float32)
        e = tl.exp(logits - max_val)
        e = tl.where(block < vocab_size, e, 0.0)
Comment (Contributor): Add some comments to explain the difference between vllm and vllm-ascend.

Reply (Contributor, Author): OK, I added some notes to explain the difference between vllm and vllm-ascend.
        se += tl.sum(e)
Comment on lines +46 to +55 (Contributor, severity: high):

The loop for calculating the sum of exponentials (se) can be made more efficient and readable. Currently, it loads values with other=0.0, computes exp, and then uses tl.where to mask out-of-bound elements. This performs unnecessary computations.

A better approach is to load with other=float("-inf"). This way, tl.exp on out-of-bound elements will naturally result in 0.0, removing the need for an explicit tl.where call. Additionally, the type conversion block.to(tl.float32) is unnecessary, as block is only used for masking and pointer arithmetic.

Suggested change:
-    for i in range(0, vocab_size, BLOCK_SIZE):
-        block = i + tl.arange(0, BLOCK_SIZE)
-        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
-        # NOTE(woosuk): Make sure that logits and all following operations use FP32.
-        logits = logits.to(tl.float32)
-        block = block.to(tl.float32)
-        e = tl.exp(logits - max_val)
-        e = tl.where(block < vocab_size, e, 0.0)
-        se += tl.sum(e)
+    for i in range(0, vocab_size, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
+        # NOTE(woosuk): Make sure that logits and all following operations use FP32.
+        logits = logits.to(tl.float32)
+        e = tl.exp(logits - max_val)
+        se += tl.sum(e)
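
This works because exp(-inf) evaluates to 0.0 in IEEE-754 float arithmetic, so padded lanes drop out of the sum; a quick plain-PyTorch check of that identity (illustrative, outside this diff):

import torch

x = torch.tensor([1.0, 2.0, float("-inf")])
# exp(-inf - max) == exp(-inf) == 0.0: the padding contributes nothing to the sum.
print(torch.exp(x - x.max()))        # tensor([0.3679, 1.0000, 0.0000])
print(torch.exp(x - x.max()).sum())  # tensor(1.3679)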

    lse = tl.log(se)

    # Gather the requested token ids and compute logprob = logit - max - log-sum-exp.
    k_offset = tl.arange(0, PADDED_TOPK)
    k_mask = k_offset < topk
    topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)

    logits = tl.load(row_ptr + topk_ids, mask=k_mask)
    logits = logits.to(tl.float32)
    o = logits - max_val - lse
    tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)


def compute_token_logprobs(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    batch_size, vocab_size = logits.shape
    token_ids = token_ids.to(torch.int64)
    num_logprobs = token_ids.shape[1]
    logprobs = logits.new_empty((batch_size, num_logprobs), dtype=torch.float32)
    _topk_log_softmax_kernel[(batch_size,)](
        logprobs,
        logits,
        logits.stride(0),
        token_ids,
        num_logprobs,
        vocab_size,
        BLOCK_SIZE=1024,  # type: ignore
        # NOTE(wangx700): PADDED_TOPK must be at least 2 to avoid
        # num_logprobs=1 getting wrong results.
        PADDED_TOPK=max(triton.next_power_of_2(num_logprobs), 2),
    )
    return logprobs
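
A minimal usage sketch for the wrapper (assuming an Ascend NPU device is available; shapes and values here are illustrative):

import torch
from vllm_ascend.worker.v2.sample.logprob import compute_token_logprobs

logits = torch.randn(8, 32000, device="npu")               # (batch_size, vocab_size)
token_ids = torch.randint(0, 32000, (8, 5), device="npu")  # 5 logprobs per request
logprobs = compute_token_logprobs(logits, token_ids)       # (8, 5), float32
# Should match torch.log_softmax(logits.float(), dim=-1).gather(1, token_ids)
# up to FP32 rounding.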