6 changes: 1 addition & 5 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -31,16 +31,12 @@

 if current_platform.is_rocm():
     import aiter
-    from aiter.ops.triton.utils.device_info import get_num_sms

     from vllm.triton_utils import tl, triton

     def block_size(x, head_dim):
         return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))

-    def num_programs(head_dim):
-        return min(head_dim, get_num_sms())
-
     @triton.jit
     def cp_mha_gather_cache_kernel(
         key_cache_ptr,  # [num_blocks, page_size, num_head, head_size]
@@ -143,7 +139,7 @@ def cp_mha_gather_cache(
     page_size = key_cache.shape[1]
     num_heads = key_cache.shape[2]

-    NUM_PRGMS = num_programs(total_tokens)
+    NUM_PRGMS = total_tokens
Contributor (high):
While removing the dependency on aiter.ops.triton.utils.device_info.get_num_sms fixes the ModuleNotFoundError, changing NUM_PRGMS to total_tokens could lead to a significant performance regression. The original logic min(total_tokens, get_num_sms()) capped the number of Triton programs to the number of streaming multiprocessors (SMs) or compute units (CUs) to optimize execution. By setting NUM_PRGMS = total_tokens, you might be launching an excessive number of programs (e.g., one per token), which can be inefficient.

A better approach would be to use vLLM's platform abstraction to get the number of compute units. You can replace get_num_sms() with current_platform.get_cu_count() to preserve the optimization.

Suggested change:
-    NUM_PRGMS = total_tokens
+    NUM_PRGMS = min(total_tokens, current_platform.get_cu_count())
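For reference, a minimal self-contained sketch of the suggested clamping, assuming vLLM's current_platform.get_cu_count() is available on the running ROCm platform (the helper name clamped_num_programs is hypothetical; the PR inlines the expression directly):

```python
from vllm.platforms import current_platform


def clamped_num_programs(total_tokens: int) -> int:
    # Cap the number of Triton programs at the GPU's compute-unit count so
    # the launch does not scale linearly with the highly variable token count.
    return min(total_tokens, current_platform.get_cu_count())


# Example: with 8192 tokens on a 304-CU MI300X, only 304 programs are
# launched; each program is then expected to stride over multiple tokens
# inside the kernel.
grid = lambda meta: (clamped_num_programs(8192),)
```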

Collaborator @tjtanaa commented on Nov 10, 2025:

@ganyi1996ppo Does this advice help? If it doesn't, overall it looks good to me.

It seems Gemini's suggestion is correct. I have double-checked aiter's get_num_sms() and vLLM's get_cu_count(); they are the same:

vLLM:

return torch.cuda.get_device_properties(device_id).multi_processor_count

and AITER:

https://github.com/ROCm/aiter/blob/de14bec0ca5a9de94e10f5cad4dc1541ac558689/aiter/ops/triton/utils/device_info.py#L4-L9
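Both helpers reduce to the same PyTorch device property. A paraphrased sketch, not verbatim from either repository (exact signatures and device handling may differ):

```python
import torch


# vLLM: what current_platform.get_cu_count() returns on ROCm, as quoted above.
def get_cu_count(device_id: int = 0) -> int:
    return torch.cuda.get_device_properties(device_id).multi_processor_count


# AITER: aiter.ops.triton.utils.device_info.get_num_sms, per the linked file.
def get_num_sms() -> int:
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    return props.multi_processor_count
```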

Contributor Author:
Thanks for the comments, that's better indeed!

     BLOCK_SIZE = block_size(key_cache, head_dim)
     grid = lambda meta: (NUM_PRGMS,)
     cp_mha_gather_cache_kernel[grid](


P1: Avoid recompiling Triton kernel for every token count

The new NUM_PRGMS = total_tokens value is passed as tl.constexpr, so Triton specializes and caches a separate kernel for every distinct total_tokens encountered. During decoding the token count fluctuates almost every invocation, which now forces a JIT compilation on every call and will quickly thrash the compile cache and slow down inference. The previous code bounded NUM_PRGMS to the device SM count, keeping the number of compiled variants small and stable. Consider clamping NUM_PRGMS to a fixed upper limit (e.g., SMs or another constant) rather than the raw token count to avoid repeated compilations.
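As a toy illustration of that specialization behavior (not the vLLM kernel; the kernel body and launch shapes here are made up, and running it requires a ROCm or CUDA GPU), every distinct value bound to a tl.constexpr parameter yields its own compiled variant:

```python
import torch

from vllm.triton_utils import tl, triton


@triton.jit
def fill_kernel(out_ptr, n, NUM_PRGMS: tl.constexpr, BLOCK: tl.constexpr):
    # NUM_PRGMS is baked into the compiled binary, so each distinct value
    # seen at launch time triggers a fresh JIT compilation of this kernel.
    pid = tl.program_id(0)
    vals = tl.full((BLOCK,), 1.0, dtype=tl.float32)
    for start in range(pid * BLOCK, n, NUM_PRGMS * BLOCK):
        offs = start + tl.arange(0, BLOCK)
        tl.store(out_ptr + offs, vals, mask=offs < n)


x = torch.zeros(4096, device="cuda")
for total_tokens in (1024, 2048, 4096):
    # Three launches with three NUM_PRGMS values mean three separate compiled
    # variants, because constexpr values are part of the specialization key.
    fill_kernel[(total_tokens,)](x, x.numel(), NUM_PRGMS=total_tokens, BLOCK=256)
```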


Contributor Author:
Remove NUM_PRGMS from tl.constexpr
