
Conversation

@ganyi1996ppo ganyi1996ppo (Contributor) commented Nov 10, 2025

Purpose

Many users reported the following error after PR #25763 was merged:

from aiter.ops.triton.utils.device_info import get_num_sms
(EngineCore_DP0 pid=230496) ModuleNotFoundError: No module named 'aiter.ops.triton.utils.device_info'

We noticed that many users' aiter installations do not include this module, so this PR removes its usage to maintain backward compatibility with aiter.

Test Plan

gsm8k

Test Result

gsm8k result

Tasks   Version   Filter             n-shot   Metric        Value    Stderr
gsm8k   3         flexible-extract   5        exact_match   0.8203   ± 0.0106
gsm8k   3         strict-match       5        exact_match   0.8901   ± 0.0086

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the rocm (Related to AMD ROCm) and v1 labels on Nov 10, 2025
@gemini-code-assist gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request correctly addresses a ModuleNotFoundError by removing a dependency on aiter.ops.triton.utils.device_info. However, the removal of the get_num_sms() call and its replacement with total_tokens for NUM_PRGMS might introduce a performance regression. The original logic was likely intended to optimize the number of launched Triton programs based on the available hardware compute units. I've provided a suggestion to restore this optimization using vLLM's platform abstraction layer, which should resolve the import error while preserving performance.

 num_heads = key_cache.shape[2]

-NUM_PRGMS = num_programs(total_tokens)
+NUM_PRGMS = total_tokens
@gemini-code-assist gemini-code-assist bot (Contributor) commented on this change (severity: high):

While removing the dependency on aiter.ops.triton.utils.device_info.get_num_sms fixes the ModuleNotFoundError, changing NUM_PRGMS to total_tokens could lead to a significant performance regression. The original logic min(total_tokens, get_num_sms()) capped the number of Triton programs to the number of streaming multiprocessors (SMs) or compute units (CUs) to optimize execution. By setting NUM_PRGMS = total_tokens, you might be launching an excessive number of programs (e.g., one per token), which can be inefficient.

A better approach would be to use vLLM's platform abstraction to get the number of compute units. You can replace get_num_sms() with current_platform.get_cu_count() to preserve the optimization.

Suggested change:
-NUM_PRGMS = total_tokens
+NUM_PRGMS = min(total_tokens, current_platform.get_cu_count())
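For context, current_platform is vLLM's platform abstraction layer. A minimal sketch of the suggested launch-side logic, assuming get_cu_count() is reachable via the import shown below (its import location was later moved, see the post-merge note about #27005 further down); the helper name is illustrative, not part of the PR:

from vllm.platforms import current_platform

def pick_num_programs(total_tokens: int) -> int:
    # Cap the Triton launch grid at the device's compute-unit (SM/CU) count,
    # so large batches don't launch one program per token.
    return min(total_tokens, current_platform.get_cu_count())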

@tjtanaa tjtanaa (Collaborator) commented Nov 10, 2025


@ganyi1996ppo Does this advice help? If it doesn't, the PR overall looks good to me.

It seems Gemini's suggestion is correct. I have double-checked aiter's get_num_sms() and vLLM's get_cu_count(), and they are the same:

vLLM:

return torch.cuda.get_device_properties(device_id).multi_processor_count

and AITER:

https://github.com/ROCm/aiter/blob/de14bec0ca5a9de94e10f5cad4dc1541ac558689/aiter/ops/triton/utils/device_info.py#L4-L9
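For completeness, the linked aiter helper amounts to the same device-property query; a paraphrased sketch, not verbatim from the aiter source:

import torch

# aiter.ops.triton.utils.device_info.get_num_sms (paraphrased sketch)
def get_num_sms() -> int:
    # Same property that vLLM's get_cu_count() returns above.
    return torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count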

@ganyi1996ppo ganyi1996ppo (Contributor, Author) replied

Thanks for the comments, that's better indeed!

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 142 to 145
NUM_PRGMS = total_tokens
BLOCK_SIZE = block_size(key_cache, head_dim)
grid = lambda meta: (NUM_PRGMS,)
cp_mha_gather_cache_kernel[grid](


P1: Avoid recompiling Triton kernel for every token count

The new NUM_PRGMS = total_tokens value is passed as tl.constexpr, so Triton specializes and caches a separate kernel for every distinct total_tokens encountered. During decoding the token count fluctuates almost every invocation, which now forces a JIT compilation on every call and will quickly thrash the compile cache and slow down inference. The previous code bounded NUM_PRGMS to the device SM count, keeping the number of compiled variants small and stable. Consider clamping NUM_PRGMS to a fixed upper limit (e.g., SMs or another constant) rather than the raw token count to avoid repeated compilations.
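As an illustration (a standalone, hypothetical kernel, not the actual vLLM cp_mha_gather_cache_kernel): when the program count is passed as a plain runtime argument and each program walks a grid-stride loop, varying token counts do not force a fresh compile for every distinct value, and the launch grid can still be clamped to the device's compute-unit count.

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(dst_ptr, src_ptr, n_elems, NUM_PRGMS, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    num_blocks = tl.cdiv(n_elems, BLOCK)
    # Grid-stride loop: a small, bounded NUM_PRGMS still covers every block,
    # so n_elems can change per call while the compiled variants stay few.
    for blk in range(pid, num_blocks, NUM_PRGMS):
        offs = blk * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elems
        tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

src = torch.arange(10_000, device="cuda", dtype=torch.float32)
dst = torch.empty_like(src)
# Clamp the grid to the compute-unit count rather than the raw element count.
num_prgms = min(src.numel(), torch.cuda.get_device_properties(0).multi_processor_count)
copy_kernel[(num_prgms,)](dst, src, src.numel(), num_prgms, BLOCK=1024)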


@ganyi1996ppo ganyi1996ppo (Contributor, Author) replied

Removed NUM_PRGMS from tl.constexpr.

@tjtanaa tjtanaa (Collaborator) left a comment


LGTM. Thank you @ganyi1996ppo

@ganyi1996ppo ganyi1996ppo (Contributor, Author) replied

> LGTM. Thank you @ganyi1996ppo

Thanks for the thoughtful review!

@tjtanaa tjtanaa (Collaborator) commented Nov 12, 2025

@ganyi1996ppo I will monitor this for a while, as there is an ongoing discussion on two other PRs about whether it is appropriate to use current_platform.get_cu_count(). Could we also get your advice on this question? I have tagged you in one of the discussion threads in #28311 (comment).

Signed-off-by: ganyi <[email protected]>
@tjtanaa tjtanaa (Collaborator) commented Nov 12, 2025

/gemini review

@gemini-code-assist gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request addresses a critical bug where a ModuleNotFoundError for aiter.ops.triton.utils.device_info was causing crashes for users. The fix correctly removes this problematic dependency. The usage of get_num_sms is replaced with current_platform.get_cu_count(), which is a more robust way to get the number of compute units from within the vLLM framework. Additionally, the logic for calculating num_programs has been improved to be based on total_tokens instead of head_dim, which is more appropriate for token-level parallelization in the Triton kernel. The change from a constexpr to a runtime parameter for the number of programs is also a necessary correctness fix. The changes are well-implemented and resolve the reported issue while also improving the kernel's logic.

@tjtanaa tjtanaa added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Nov 12, 2025
@tjtanaa tjtanaa merged commit ca00b1b into vllm-project:main Nov 13, 2025
49 checks passed
@DarkLight1337 DarkLight1337 (Member) left a comment


This PR is in conflict with #27005. You need to update the import of get_cu_count

tjtanaa added a commit that referenced this pull request Nov 13, 2025
geodavic pushed a commit to geodavic/vllm that referenced this pull request Nov 16, 2025
bwasti pushed a commit to bwasti/vllm that referenced this pull request Nov 17, 2025