[BugFix][ROCm] Fix get_cu_count missing variable error #28608
Conversation
Signed-off-by: ganyi <[email protected]>
Code Review
This pull request correctly fixes a TypeError in the get_cu_count function by removing an unnecessary cls parameter. I've also identified a potential improvement within the same function. The current implementation could lead to unintended CUDA context initialization. I've provided a suggestion to use an existing utility function from the same file to make it safer in multiprocessing environments.
```python
def get_cu_count(device_id: int = 0) -> int:
    """Returns the total number of compute units (CU) on single GPU."""
    return torch.cuda.get_device_properties(device_id).multi_processor_count
```
While removing the unused cls parameter is correct, the function can be further improved. The direct call to torch.cuda.get_device_properties will initialize the CUDA context, which can cause issues in multiprocessing environments. This file provides a safer utility, cuda_get_device_properties, which avoids this side effect. Using it here would make the function more robust.
Suggested change:

```diff
 def get_cu_count(device_id: int = 0) -> int:
     """Returns the total number of compute units (CU) on single GPU."""
-    return torch.cuda.get_device_properties(device_id).multi_processor_count
+    return cuda_get_device_properties(device_id, ("multi_processor_count",))[0]
```
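The helper recommended above sidesteps CUDA context initialization by querying the device from a separate process, so the parent process never touches CUDA. The following is only a rough sketch of that pattern, not vLLM's actual implementation; the helper name and the torch snippet are illustrative assumptions:

```python
import subprocess
import sys


def run_isolated(snippet: str) -> str:
    """Run a Python snippet in a fresh interpreter and return its stdout.

    Because the snippet executes in a child process, any CUDA context it
    creates is torn down when the child exits; the parent stays CUDA-free,
    which matters for later fork-based multiprocessing.
    """
    result = subprocess.run(
        [sys.executable, "-c", snippet],
        capture_output=True, text=True, check=True,
    )
    return result.stdout.strip()


# Hypothetical usage mirroring the property query (requires torch and a GPU):
CU_COUNT_SNIPPET = (
    "import torch; "
    "print(torch.cuda.get_device_properties(0).multi_processor_count)"
)
# cu_count = int(run_isolated(CU_COUNT_SNIPPET))
```

The real `cuda_get_device_properties` utility in vLLM takes a device id and a tuple of property names, but the isolation idea is the same: do the query where the CUDA side effect cannot leak into the caller.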
@tjtanaa @wangxiyuan @gshtras @HAIAI please take a look
tjtanaa
left a comment
Thank you for the fix.
wangxiyuan
left a comment
Oh, my silly mistake
Haha, that's fine, I spotted it in time!
…t#28608) Signed-off-by: ganyi <[email protected]> Signed-off-by: George D. Torres <[email protected]>
…t#28608) Signed-off-by: ganyi <[email protected]> Signed-off-by: Bram Wasti <[email protected]>
Purpose
Got an error message after PR #27005 was merged: the `get_cu_count` function carried a leftover class variable (`cls`) in its signature. This PR removes that parameter to restore functionality on the ROCm platform.

Test Plan
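The failure mode can be sketched as follows, under the assumption (per the review) that the pre-fix signature kept a stray `cls` parameter on what is called as a free function; the body here is a stand-in, not the real implementation:

```python
# Sketch of the bug: a classmethod-style signature left on a free function.
def get_cu_count(cls, device_id: int = 0) -> int:  # hypothetical pre-fix signature
    """Returns the total number of compute units (CU) on single GPU."""
    # Stand-in for torch.cuda.get_device_properties(...).multi_processor_count
    return 64


try:
    get_cu_count()  # called without the stray 'cls' argument
except TypeError as exc:
    print(exc)  # missing 1 required positional argument: 'cls'
```

Dropping `cls` from the signature makes the plain `get_cu_count()` call well-formed again.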
Test Result
Essential Elements of an Effective PR Description Checklist
Update supported_models.md and examples for a new model.