[ROCm][BugFix] Remove the usage of device_info from aiter
#28383
Conversation
Signed-off-by: ganyi <[email protected]>
Code Review
This pull request correctly addresses a ModuleNotFoundError by removing a dependency on aiter.ops.triton.utils.device_info. However, the removal of the get_num_sms() call and its replacement with total_tokens for NUM_PRGMS might introduce a performance regression. The original logic was likely intended to optimize the number of launched Triton programs based on the available hardware compute units. I've provided a suggestion to restore this optimization using vLLM's platform abstraction layer, which should resolve the import error while preserving performance.
```diff
  num_heads = key_cache.shape[2]
- NUM_PRGMS = num_programs(total_tokens)
+ NUM_PRGMS = total_tokens
```
While removing the dependency on aiter.ops.triton.utils.device_info.get_num_sms fixes the ModuleNotFoundError, changing NUM_PRGMS to total_tokens could lead to a significant performance regression. The original logic min(total_tokens, get_num_sms()) capped the number of Triton programs to the number of streaming multiprocessors (SMs) or compute units (CUs) to optimize execution. By setting NUM_PRGMS = total_tokens, you might be launching an excessive number of programs (e.g., one per token), which can be inefficient.
A better approach would be to use vLLM's platform abstraction to get the number of compute units. You can replace get_num_sms() with current_platform.get_cu_count() to preserve the optimization.
```diff
- NUM_PRGMS = total_tokens
+ NUM_PRGMS = min(total_tokens, current_platform.get_cu_count())
```
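The effect of the suggested clamp can be illustrated with a minimal sketch (`pick_num_programs` and the CU count of 304 are hypothetical stand-ins for illustration, not code from this PR):

```python
def pick_num_programs(total_tokens: int, cu_count: int) -> int:
    """Cap the Triton grid size at the device compute-unit count.

    Launching more programs than the device has CUs adds launch overhead
    without extra parallelism, so the grid is clamped the way
    min(total_tokens, current_platform.get_cu_count()) would clamp it.
    """
    return min(total_tokens, cu_count)

print(pick_num_programs(50_000, 304))  # large batch: capped at the CU count, 304
print(pick_num_programs(128, 304))     # small batch: uses the token count, 128
```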
@ganyi1996ppo Does this advice help? If it doesn't, overall it looks good to me.
It seems Gemini's suggestion is correct. I have double-checked `get_num_sms()` from aiter against vLLM's `get_cu_count()`, and they are the same.
vLLM (line 454 in d0e186c):
```python
return torch.cuda.get_device_properties(device_id).multi_processor_count
```
and AITER:
Thanks for the comments, that's better indeed!
💡 Codex Review
Here are some automated review suggestions for this pull request.
```python
NUM_PRGMS = total_tokens
BLOCK_SIZE = block_size(key_cache, head_dim)
grid = lambda meta: (NUM_PRGMS,)
cp_mha_gather_cache_kernel[grid](
```
Avoid recompiling Triton kernel for every token count
The new NUM_PRGMS = total_tokens value is passed as tl.constexpr, so Triton specializes and caches a separate kernel for every distinct total_tokens encountered. During decoding the token count fluctuates almost every invocation, which now forces a JIT compilation on every call and will quickly thrash the compile cache and slow down inference. The previous code bounded NUM_PRGMS to the device SM count, keeping the number of compiled variants small and stable. Consider clamping NUM_PRGMS to a fixed upper limit (e.g., SMs or another constant) rather than the raw token count to avoid repeated compilations.
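The caching behavior described above can be modeled with a plain-Python sketch (the `compiled_cache` dict and `launch` function are hypothetical stand-ins for Triton's JIT specialization on `tl.constexpr` arguments, not Triton internals):

```python
# Model a JIT cache keyed on a tl.constexpr-like value: every distinct
# value compiles (and caches) a separate kernel variant.
compiled_cache = {}

def launch(num_prgms_constexpr: int) -> str:
    if num_prgms_constexpr not in compiled_cache:
        # stand-in for an expensive Triton JIT compilation
        compiled_cache[num_prgms_constexpr] = f"variant_{len(compiled_cache)}"
    return compiled_cache[num_prgms_constexpr]

CU_COUNT = 80  # hypothetical device SM/CU count

# Passing the raw token count: one compiled variant per distinct value.
for total_tokens in (17, 33, 129, 257):
    launch(total_tokens)
print(len(compiled_cache))  # 4 variants

# Clamping to the CU count: all counts above the cap share one variant.
compiled_cache.clear()
for total_tokens in (17, 33, 129, 257):
    launch(min(total_tokens, CU_COUNT))
print(len(compiled_cache))  # 3 variants (17, 33, and 80)
```

Dropping `NUM_PRGMS` from the `tl.constexpr` parameter list avoids the specialization entirely, since plain runtime arguments do not key the compile cache.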
Remove NUM_PRGMS from tl.constexpr
Signed-off-by: ganyi <[email protected]>
Signed-off-by: ganyi <[email protected]>
tjtanaa
left a comment
LGTM. Thank you @ganyi1996ppo
Thanks for the thoughtful review!
@ganyi1996ppo I will monitor for a while, as there is a discussion on two other PRs about whether it is appropriate to use
Signed-off-by: ganyi <[email protected]>
/gemini review
Code Review
This pull request addresses a critical bug where a ModuleNotFoundError for aiter.ops.triton.utils.device_info was causing crashes for users. The fix correctly removes this problematic dependency. The usage of get_num_sms is replaced with current_platform.get_cu_count(), which is a more robust way to get the number of compute units from within the vLLM framework. Additionally, the logic for calculating num_programs has been improved to be based on total_tokens instead of head_dim, which is more appropriate for token-level parallelization in the Triton kernel. The change from a constexpr to a runtime parameter for the number of programs is also a necessary correctness fix. The changes are well-implemented and resolve the reported issue while also improving the kernel's logic.
DarkLight1337
left a comment
This PR is in conflict with #27005. You need to update the import of `get_cu_count`.
…ject#28383) Signed-off-by: ganyi <[email protected]> Signed-off-by: George D. Torres <[email protected]>
…ject#28383) Signed-off-by: ganyi <[email protected]> Signed-off-by: Bram Wasti <[email protected]>
Purpose
Many users reported encountering the following error after #25763 was merged.
We noticed that many users' aiter installations do not have this module, so this PR removes its usage to maintain backward compatibility with aiter.
Test Plan
gsm8k
Test Result
gsm8k result
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.