
[Platform] Add current_platform.num_compute_units interface#35042

Merged
vllm-bot merged 10 commits into vllm-project:main from jikunshang:kunshang/num_sms
Feb 25, 2026

Conversation

@jikunshang
Collaborator

@jikunshang jikunshang commented Feb 22, 2026

Purpose

There are several torch.cuda.get_device_properties().multi_processor_count calls across the vLLM code base. We can unify them into a current_platform.num_compute_units interface to make the code cleaner and extensible for non-CUDA hardware such as XPU and NPU.

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@dosubot

dosubot bot commented Feb 22, 2026

Related Documentation

Checked 0 published document(s) in 1 knowledge base(s). No updates required.


@mergify mergify bot added the performance (Performance-related issues), nvidia, rocm (Related to AMD ROCm), and v1 labels Feb 22, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Feb 22, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces a new get_num_sm interface to abstract the retrieval of SM (Streaming Multiprocessor) counts across different platforms (CUDA, ROCm, XPU). This change unifies the way SM counts are obtained, making the codebase cleaner and more extensible for non-CUDA hardware. The existing torch.cuda.get_device_properties().multi_processor_count calls are replaced with the new platform-agnostic interface. This is a good step towards improving platform independence and maintainability.

@njhill
Member

njhill commented Feb 22, 2026

Thanks @jikunshang! LGTM, but I wonder if we should name it something more generic, since SM is a CUDA term AFAIK.

Maybe get_multi_processor_count?

@mergify

mergify bot commented Feb 22, 2026

Hi @jikunshang, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@jikunshang
Collaborator Author

@njhill it seems "compute unit" is a unified term across GPU vendors: https://chatgpt.com/share/699a69bd-1fd0-8002-aeb8-e12bf149533e
I'd prefer to rename it to get_num_compute_units. Thoughts?

@jikunshang jikunshang changed the title [Platform]Add current_platform.get_num_sm interface [Platform]Add current_platform.get_num_compute_units interface Feb 22, 2026

Member

@njhill njhill left a comment


I prefer to rename to get_num_compute_units, thoughts?

How about just compute_unit_count, since there is already device_count?

Or if not then I think just num_compute_units would be better

 if allspark_supported:
-    properties = torch.cuda.get_device_properties(b.device.index)
-    sm_count = properties.multi_processor_count
+    sm_count = current_platform.get_num_compute_units(b.device.index)
Member


Probably doesn't make sense to change it in places like this, which are CUDA-specific anyway.

Collaborator Author


I think it's OK to always use the current_platform API here. We should avoid/reduce using torch.cuda APIs since we are proposing this RFC (#30679) (though non-CUDA platforms will never hit this code path).
My understanding is that the property is used to check CUDA capability later; I feel we could also refactor that part into something like current_platform.supported_arch.

Member


I just think that in places like this we are already using torch.cuda.get_device_properties (the line above), so it's better to keep the code consistent (like you say, it can always be refactored for portability in the future if/when appropriate).

Similarly, if it's in CUDA-, ROCm-, or XPU-specific files/code, then I don't think there's a need to use the current_platform version, and it may actually be better not to, since it implies the code could be cross-platform, which could be misleading.

Collaborator Author


Got it, reverted :)

@jikunshang
Collaborator Author

I prefer to rename to get_num_compute_units, thoughts?

How about just compute_unit_count, since there is already device_count?

Or if not then I think just num_compute_units would be better

I feel "count" is for something countable where you can choose whether to use it, while we always want to use all compute units. So let's use num_compute_units :)

@njhill
Member

njhill commented Feb 23, 2026

Also remove platform_utils.get_cu_count() method now?

@jikunshang
Collaborator Author

Also remove platform_utils.get_cu_count() method now?

nice catch, done!

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Feb 25, 2026
@njhill njhill changed the title [Platform]Add current_platform.num_compute_units interface [Platform] Add current_platform.num_compute_units interface Feb 25, 2026
@njhill njhill enabled auto-merge (squash) February 25, 2026 01:49
@vllm-bot vllm-bot merged commit 8ad54a9 into vllm-project:main Feb 25, 2026
71 of 74 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Feb 25, 2026
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Feb 25, 2026
@robertgshaw2-redhat
Collaborator

Did this PR break Distributed Tests (2 GPUs)?

LucasWilkinson added a commit to neuralmagic/vllm that referenced this pull request Feb 25, 2026
torch.cuda.current_device() returns an int directly, not a device
object with an .index attribute. This was introduced in vllm-project#35042.

Co-Authored-By: Claude <noreply@anthropic.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@jikunshang
Collaborator Author

Did this PR break Distributed Tests (2 GPUs)?

Oh yes, sorry for that, and thanks @LucasWilkinson for fixing it.

haanjack pushed a commit to haanjack/vllm that referenced this pull request Feb 26, 2026
…ject#35042)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
tom-zju pushed a commit to tom-zju/vllm that referenced this pull request Feb 26, 2026
…ject#35042)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
flutist pushed a commit to flutist/vllm_custom_dataset_img_support_base64 that referenced this pull request Feb 28, 2026
…ject#35042)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Signed-off-by: xjx <493337577@qq.com>
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
…ject#35042)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
…ject#35042)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
askliar pushed a commit to askliar/vllm that referenced this pull request Mar 9, 2026
…ject#35042)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026
…ject#35042)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
AndreasKaratzas added a commit to ROCm/vllm that referenced this pull request Mar 21, 2026
DarkLight1337 pushed a commit that referenced this pull request Mar 22, 2026
…37764)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
yzong-rh pushed a commit to yzong-rh/vllm that referenced this pull request Mar 23, 2026
RhizoNymph pushed a commit to RhizoNymph/vllm that referenced this pull request Mar 26, 2026
HenryTangDev pushed a commit to HenryTangMain/vllm that referenced this pull request Mar 27, 2026
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
…ect#35042 (vllm-project#37764)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
nithinvc pushed a commit to nithinvc/vllm that referenced this pull request Mar 27, 2026
…ect#35042 (vllm-project#37764)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>

Signed-off-by: Nithin Chalapathi <nithin.ch10@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026

Labels

nvidia, performance (Performance-related issues), ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), v1

Projects

Status: Done
Status: Done

Development

Successfully merging this pull request may close these issues.

5 participants