
[ROCm] Enable Triton ScaledMM fallback + kernel selection fix #26668

Merged: ProExpertProg merged 6 commits into vllm-project:main from shivampr:rocm-triton-fallback on Dec 12, 2025
Conversation

shivampr (Contributor) commented Oct 13, 2025

Purpose

Fixes #14397. `triton_scaled_mm` was never used on ROCm due to missing dispatch and checks.
This PR:

  • Enables the Triton fallback on ROCm when AITER is unavailable

  • Adds Triton fallback after CUTLASS on CUDA

  • Implements is_supported() checks for kernel selection

  • Adds a lightweight integration test validating ROCm dispatch logic
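The selection logic described above can be sketched as follows; the class names echo vLLM's kernels, but the method bodies and the `select_kernel` helper are illustrative simplifications, not the actual implementation:

```python
# Illustrative sketch of is_supported()-based kernel selection.

class ScaledMMKernel:
    @classmethod
    def is_supported(cls, platform: str) -> bool:
        raise NotImplementedError

class CutlassScaledMMKernel(ScaledMMKernel):
    @classmethod
    def is_supported(cls, platform: str) -> bool:
        return platform == "cuda"  # CUTLASS is CUDA-only

class AiterScaledMMKernel(ScaledMMKernel):
    @classmethod
    def is_supported(cls, platform: str) -> bool:
        # The real check also verifies the AITER package is importable.
        return platform == "rocm"

class TritonScaledMMKernel(ScaledMMKernel):
    @classmethod
    def is_supported(cls, platform: str) -> bool:
        # Triton runs on both backends, so it serves as the fallback.
        return platform in ("cuda", "rocm")

# Preference order: platform-specific kernels first, Triton last.
CANDIDATES = [CutlassScaledMMKernel, AiterScaledMMKernel, TritonScaledMMKernel]

def select_kernel(platform, disabled=frozenset()):
    for kernel in CANDIDATES:
        if kernel.__name__ in disabled:
            continue  # honors opt-outs such as VLLM_DISABLED_KERNELS
        if kernel.is_supported(platform):
            return kernel
    raise RuntimeError(f"no scaled-mm kernel available for {platform}")
```

With the Aiter kernel disabled, `select_kernel("rocm", {"AiterScaledMMKernel"})` falls through to the Triton kernel, which is the behavior this PR enables on ROCm.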


Test Plan

1. Mocked test (no GPU)

```bash
python3 mini_tests/select_triton_rocm.py
```

Result:

```text
Selected kernel: TritonScaledMMLinearKernel
OK: TritonScaledMMLinearKernel chosen on ROCm fallback.
```

2. MI300X (ROCm 7.0, vLLM built from this PR)

(a) Triton kernel functional test

```text
max_abs_err≈2.5e-01, max_rel_err≈3.9e-03
```
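For context, a functional test of this kind compares the quantized kernel's output against a full-precision reference. A CPU-only sketch of that comparison (pure Python, hypothetical per-tensor scales and tiny shapes, not the actual test) looks like:

```python
import random

random.seed(0)

def quantize_int8(x, scale):
    # Symmetric per-tensor quantization: round(v / scale), clamped to int8.
    return [max(-128, min(127, round(v / scale))) for v in x]

def matmul(x, y, m, k, n):
    # Row-major flat-list matmul.
    out = [0.0] * (m * n)
    for i in range(m):
        for j in range(n):
            out[i * n + j] = sum(x[i * k + t] * y[t * n + j] for t in range(k))
    return out

M, K, N = 4, 8, 4
a = [random.uniform(-1, 1) for _ in range(M * K)]
b = [random.uniform(-1, 1) for _ in range(K * N)]
scale_a, scale_b = 0.01, 0.01

ref = matmul(a, b, M, K, N)                      # full-precision reference
q = matmul(quantize_int8(a, scale_a),
           quantize_int8(b, scale_b), M, K, N)   # int8 path
deq = [v * scale_a * scale_b for v in q]         # dequantize the result

max_abs_err = max(abs(r - d) for r, d in zip(ref, deq))
max_rel_err = max(abs(r - d) / max(abs(r), 1e-3) for r, d in zip(ref, deq))
print(f"max_abs_err≈{max_abs_err:.1e}, max_rel_err≈{max_rel_err:.1e}")
```

With per-element quantization error bounded by `scale / 2`, the accumulated error over `K = 8` products stays well under 0.1 here, which is the same sanity-check shape as the numbers reported above.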

(b) OpenAI-compatible API test

```bash
python3 -m vllm.entrypoints.openai.api_server \
  --model Qwen/Qwen2.5-0.5B-Instruct --dtype bfloat16 --host 0.0.0.0 --port 8000
```

Then:

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model":"Qwen/Qwen2.5-0.5B-Instruct","messages":[{"role":"user","content":"Say hi from MI300X."}]}'
```

Response

"Hello! How can I assist you today?"

Confirms successful end-to-end inference on ROCm.

@mergify mergify bot added the rocm Related to AMD ROCm label Oct 13, 2025
mergify bot commented Oct 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @shivampr.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 13, 2025
gemini-code-assist bot (Contributor) left a comment:

Code Review

This pull request addresses an issue where triton_scaled_mm was not being used on ROCm by fixing the kernel selection logic. It correctly adds TritonScaledMMLinearKernel as a fallback for both ROCm and CUDA, and introduces an is_supported check to ensure kernels are compatible with the current platform. The changes are accompanied by a new integration test to verify the fix.

My review focuses on improving the robustness of the kernel selection. I've suggested making the get_min_capability check in the Triton kernel platform-aware to prevent it from being selected on unsupported ROCm hardware. Additionally, I've pointed out a confusing try-except block in the new test file that should be simplified for clarity and to avoid masking potential errors.

chatgpt-codex-connector bot left a comment:

💡 Codex Review

Here are some automated review suggestions for this pull request.


@shivampr shivampr force-pushed the rocm-triton-fallback branch 3 times, most recently from 99018da to 4d3a612 Compare October 13, 2025 05:09
@mergify mergify bot removed the needs-rebase label Oct 13, 2025
@shivampr shivampr force-pushed the rocm-triton-fallback branch 4 times, most recently from d0d088d to 9036316 Compare October 13, 2025 05:50
@shivampr shivampr force-pushed the rocm-triton-fallback branch from 9036316 to 2a6c86c Compare October 24, 2025 05:11
@shivampr shivampr force-pushed the rocm-triton-fallback branch from be28ac6 to d2591bf Compare November 4, 2025 15:07
@shivampr shivampr requested a review from WoosukKwon as a code owner November 4, 2025 15:07
ProExpertProg (Collaborator) commented:

Is this ready for review again?

shivampr (Contributor, Author) commented Nov 7, 2025:

@ProExpertProg yes!
Sorry, I'll ping you directly from now on when it's ready for review.

ProExpertProg (Collaborator) left a comment:

Just one note about is_supported

@shivampr shivampr requested a review from tjtanaa as a code owner December 10, 2025 05:03
… entry

Signed-off-by: Shivam <shivampr.dev@gmail.com>
Signed-off-by: Shivam <shivamprasad91@gmail.com>
Make is_supported() abstract in base class, remove get_min_capability(),
and implement is_supported() in all kernels. Move platform checks from
can_implement() to is_supported() in AiterScaledMMLinearKernel. Add
CPU-compatible tests for kernel selection validation.

Signed-off-by: Shivam <shivamprasad91@gmail.com>
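The refactor that commit describes, replacing get_min_capability() with an abstract is_supported(), might look roughly like this (simplified names and a hard-coded platform probe; the real classes live in vLLM's quantization kernel modules):

```python
from abc import ABC, abstractmethod

def current_platform():
    # Stand-in for vLLM's platform detection; hard-coded for illustration.
    return "rocm"

class ScaledMMLinearKernel(ABC):
    """Base class: every kernel declares its own platform support."""

    @classmethod
    @abstractmethod
    def is_supported(cls):
        """Return (supported, reason_if_not) for the current platform."""

class AiterScaledMMLinearKernel(ScaledMMLinearKernel):
    @classmethod
    def is_supported(cls):
        # Platform-level gating lives here; per-config checks (dtypes,
        # shapes) stay in can_implement(), per the commit message.
        if current_platform() != "rocm":
            return False, "AITER requires ROCm"
        return True, None

class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
    @classmethod
    def is_supported(cls):
        if current_platform() in ("cuda", "rocm"):
            return True, None
        return False, f"Triton scaled_mm unsupported on {current_platform()}"
```

Making is_supported() abstract forces every kernel subclass to state its platform constraints explicitly, instead of relying on a capability number that did not map cleanly onto ROCm hardware.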
@shivampr shivampr force-pushed the rocm-triton-fallback branch from 859585f to c8b0b83 Compare December 10, 2025 06:06
@github-project-automation github-project-automation bot moved this to In review in NVIDIA Dec 10, 2025
@ProExpertProg ProExpertProg added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 10, 2025
ProExpertProg (Collaborator) commented:

@shivampr could you report lm-eval results for a model that uses int8 Triton scaled mm to check this works?

…rd dependency for CI failure

Signed-off-by: Shivam <shivamprasad91@gmail.com>
@mergify mergify bot added the ci/build label Dec 10, 2025
shivampr (Contributor, Author) commented:

@ProExpertProg

I verified the ROCm int8 Triton ScaledMM path using lm-eval with vLLM.

Environment:

- Base image: RunPod ROCm vLLM 0.9.2 / ROCm 7.0
- Backend: `VLLM_TARGET_DEVICE=rocm`
- AITER disabled so we exercise the Triton fallback:

  ```bash
  export VLLM_TARGET_DEVICE=rocm
  export VLLM_DISABLED_KERNELS="AiterScaledMMLinearKernel"
  ```
Command:

```bash
lm_eval \
  --model vllm \
  --model_args pretrained="RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",dtype=auto,tensor_parallel_size=1,max_model_len=4096,gpu_memory_utilization=0.9,trust_remote_code=true \
  --tasks gsm8k \
  --num_fewshot 5 \
  --limit 50 \
  --batch_size auto
```
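`VLLM_DISABLED_KERNELS`, used above to skip the Aiter kernel, names kernel classes the selector should ignore. A minimal sketch of how such an env var is typically parsed (the comma-separated format is an assumption for illustration, not vLLM's actual parsing code):

```python
import os

def get_disabled_kernels():
    # Assumed format: comma-separated kernel class names, e.g.
    # VLLM_DISABLED_KERNELS="AiterScaledMMLinearKernel,CutlassScaledMMLinearKernel"
    raw = os.environ.get("VLLM_DISABLED_KERNELS", "")
    return {name.strip() for name in raw.split(",") if name.strip()}
```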

From the logs:

```text
INFO ... Automatically detected platform rocm.
INFO ... Using TritonScaledMMLinearKernel for CompressedTensorsW8A8Int8
```

lm-eval results:

| Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr   |
|-------|---------|------------------|--------|-------------|-------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.70  | ± 0.0655 |
| gsm8k | 3       | strict-match     | 5      | exact_match | 0.64  | ± 0.0686 |

Signed-off-by: Shivam <shivamprasad91@gmail.com>
mergify bot commented Dec 10, 2025

Hi @shivampr, the pre-commit checks have failed. Please run:

```bash
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
```bash
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint
```

@ProExpertProg ProExpertProg merged commit cd7740a into vllm-project:main Dec 12, 2025
55 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in NVIDIA Dec 12, 2025
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Dec 15, 2025
…roject#26668)

Signed-off-by: Shivam <shivampr.dev@gmail.com>
Signed-off-by: Shivam <shivamprasad91@gmail.com>
Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
…roject#26668)

Signed-off-by: Shivam <shivampr.dev@gmail.com>
Signed-off-by: Shivam <shivamprasad91@gmail.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…roject#26668)

Signed-off-by: Shivam <shivampr.dev@gmail.com>
Signed-off-by: Shivam <shivamprasad91@gmail.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>

Labels

ci/build · nvidia · ready (ONLY add when PR is ready to merge/full CI is needed) · rocm (Related to AMD ROCm)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Bug]: triton_scaled_mm never used on ROCm

4 participants