```diff
@@ -552,9 +552,9 @@ def _supports_current_device() -> bool:
         cap = p.get_device_capability()
         if cap is None:
             return False
-        # (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
-        # and ROCm gfx942/gfx950 (which map to 9.4/9.5).
-        return (9, 0) <= (cap.major, cap.minor) < (11, 0)
+        # (9,0) <= cap < (13,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
+        # and ROCm gfx942/gfx950 (9.4/9.5) + gfx1200/gfx1201 (12.0).
+        return (9, 0) <= (cap.major, cap.minor) < (13, 0)
```
Comment on lines +555 to +557 (Contributor, severity: high)

The change to < (13, 0) enables support for devices with major capability versions 11 and 12. The PR description and updated comment only mention gfx12 (RDNA4). If gfx11 (RDNA3) is not intended to be supported or has not been tested, it would be safer to use a more specific condition to only enable gfx12. This prevents potential issues for users on gfx11 hardware.

If gfx11 is also supported, please consider updating the comment to reflect that.

Suggested change:

```diff
-        # (9,0) <= cap < (13,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
-        # and ROCm gfx942/gfx950 (9.4/9.5) + gfx1200/gfx1201 (12.0).
-        return (9, 0) <= (cap.major, cap.minor) < (13, 0)
+        # (9,0) <= cap < (11,0) or cap.major == 12 covers CUDA SM90 (Hopper),
+        # SM100+ (Blackwell), ROCm gfx942/gfx950 (9.4/9.5), and gfx12 (RDNA4).
+        return (9, 0) <= (cap.major, cap.minor) < (11, 0) or cap.major == 12
```
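To make the reviewer's point concrete, here is a minimal standalone sketch (not vLLM code; the `merged`/`suggested` names and the sample capability tuples are illustrative) comparing the merged predicate with the suggested one across the relevant capability range:

```python
# Standalone comparison of the two gating predicates; the function names
# and sample capability tuples are illustrative, not taken from vLLM.

def merged(cap: tuple[int, int]) -> bool:
    # Predicate as merged: admits everything from 9.0 up to (but not) 13.0,
    # which silently includes major version 11 (gfx11 / RDNA3).
    return (9, 0) <= cap < (13, 0)

def suggested(cap: tuple[int, int]) -> bool:
    # Suggested predicate: 9.x-10.x, plus major version 12 only.
    return (9, 0) <= cap < (11, 0) or cap[0] == 12

for cap in [(9, 0), (9, 4), (10, 0), (11, 0), (12, 0), (13, 0)]:
    print(cap, merged(cap), suggested(cap))

# The two predicates diverge only at major version 11:
# merged((11, 0)) is True while suggested((11, 0)) is False.
```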


@staticmethod
def _supports_no_act_and_mul() -> bool:
```diff
@@ -884,9 +884,9 @@ def _supports_current_device() -> bool:
         cap = p.get_device_capability()
         if cap is None:
             return False
-        # (9,0) <= cap < (11,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
-        # and ROCm gfx942/gfx950 (which map to 9.4/9.5).
-        return (9, 0) <= (cap.major, cap.minor) < (11, 0)
+        # (9,0) <= cap < (13,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
+        # and ROCm gfx942/gfx950 (9.4/9.5) + gfx1200/gfx1201 (12.0).
+        return (9, 0) <= (cap.major, cap.minor) < (13, 0)
```
Comment on lines +887 to +889 (Contributor, severity: high)

As in the previous comment, this change to < (13, 0) also enables support for gfx11. To be safer and enable only the tested gfx12 architecture, a more specific condition is recommended.

Suggested change:

```diff
-        # (9,0) <= cap < (13,0) covers CUDA SM90 (Hopper), SM100+ (Blackwell)
-        # and ROCm gfx942/gfx950 (9.4/9.5) + gfx1200/gfx1201 (12.0).
-        return (9, 0) <= (cap.major, cap.minor) < (13, 0)
+        # (9,0) <= cap < (11,0) or cap.major == 12 covers CUDA SM90 (Hopper),
+        # SM100+ (Blackwell), ROCm gfx942/gfx950 (9.4/9.5), and gfx12 (RDNA4).
+        return (9, 0) <= (cap.major, cap.minor) < (11, 0) or cap.major == 12
```


@staticmethod
def _supports_no_act_and_mul() -> bool:
vllm/model_executor/layers/fused_moe/oracle/mxfp4.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -202,7 +202,7 @@ def select_mxfp4_moe_backend(
     triton_kernels_supported = has_triton_kernels() and (
         9,
         0,
-    ) <= current_platform.get_device_capability() < (11, 0)
+    ) <= current_platform.get_device_capability() < (13, 0)
```
Comment on lines 202 to +205 (Contributor, severity: high)

This condition is now broad enough to enable Triton kernels for gfx11 as well, which is not mentioned in the PR. To avoid potential issues on untested hardware, it's better to make the condition more specific to the architectures that are known to be supported. The suggested change also improves clarity by using a local variable for the device capability.

```python
    cap = current_platform.get_device_capability()
    triton_kernels_supported = has_triton_kernels() and cap and (
        (9, 0) <= cap < (11, 0) or cap.major == 12)
```
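The guard above can be exercised with a stubbed capability type. In this sketch, `Capability` and `has_triton_kernels` are stand-ins rather than vLLM's actual objects, and it assumes the real capability object compares like a `(major, minor)` tuple:

```python
# Stubbed sketch of the suggested guard; Capability and has_triton_kernels
# are illustrative stand-ins, not vLLM's actual types.
from collections import namedtuple

Capability = namedtuple("Capability", ["major", "minor"])

def has_triton_kernels() -> bool:
    # Stub: pretend the triton kernels package is installed.
    return True

def triton_kernels_supported(cap) -> bool:
    # `cap and ...` short-circuits when no capability is reported (None),
    # so the tuple comparison is never attempted on None.
    return bool(has_triton_kernels() and cap
                and ((9, 0) <= cap < (11, 0) or cap.major == 12))

print(triton_kernels_supported(Capability(9, 4)))   # gfx942-style -> True
print(triton_kernels_supported(Capability(11, 0)))  # gfx11 excluded -> False
print(triton_kernels_supported(None))               # no device -> False
```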


# LoRA: separate experts backend path
if config.is_lora_enabled: