
[Bugfix] Fix import gemm_afp4wfp4 failure on AMD#26068

Merged
yeqcharlotte merged 2 commits into vllm-project:main from zhewenl:fix-quark_w4a4_mxfp4-gating
Oct 3, 2025

Conversation

@zhewenl
Collaborator

@zhewenl zhewenl commented Oct 2, 2025

Purpose

We are seeing this error after #25135 on AMD MI300X:

  File "vllm/v1/engine/async_llm.py", line 231, in from_engine_args
    vllm_config = engine_args.create_engine_config(usage_context)
  File "vllm/engine/arg_utils.py", line 1142, in create_engine_config
    model_config = self.create_model_config()
  File "vllm/engine/arg_utils.py", line 994, in create_model_config
    return ModelConfig(
  File "pydantic/_internal/_dataclasses.py", line 123, in __init__
    s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s)
  File "vllm/config/model.py", line 648, in __post_init__
    self._verify_quantization()
  File "vllm/config/model.py", line 935, in _verify_quantization
    method = me_quant.get_quantization_config(name)
  File "vllm/model_executor/layers/quantization/__init__.py", line 86, in get_quantization_config
    from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
  File "vllm/model_executor/layers/quantization/quark/quark.py", line 19, in <module>
    from vllm.model_executor.layers.quantization.quark.schemes import (
  File "vllm/model_executor/layers/quantization/quark/schemes/__init__.py", line 5, in <module>
    from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
  File "vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py", line 28, in <module>
    from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
  File "aiter/ops/triton/gemm_afp4wfp4.py", line 40, in <module>
    def _gemm_afp4_wfp4_kernel(
  File "triton/runtime/jit.py", line 852, in jit
    return decorator(fn)
  File "triton/runtime/jit.py", line 840, in decorator
    return JITFunction(
  File "triton/runtime/jit.py", line 667, in __init__
    src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
AttributeError: 'NoneType' object has no attribute 'start'

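gemm_afp4wfp4 is a CUDA-only Triton kernel, so importing it crashes during JIT compilation on ROCm. The crash surfaces as an AttributeError rather than an ImportError, so the import guard has to catch both.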
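
The traceback ends in an AttributeError raised while the kernel module is being imported, so a guard that catches only ImportError would let it propagate. Below is a minimal sketch of the broader guard. It is not the vLLM code: `fake_broken_kernel`, `BROKEN_SOURCE`, and `load_optional` are hypothetical names made up for illustration, and the fake module only imitates the `re.search(...).start()` failure seen in `triton/runtime/jit.py`.

```python
import importlib
import os
import sys
import tempfile

# Hypothetical stand-in for a kernel module that fails at import time.
# re.search() finds no match and returns None, so .start() raises
# AttributeError, mirroring the last frame of the traceback.
BROKEN_SOURCE = '''
import re
re.search(r"^def", "no function here", re.MULTILINE).start()

def gemm_kernel():
    pass
'''


def load_optional(module_name, attr):
    """Return module_name.attr, or None if the import fails either way."""
    try:
        module = importlib.import_module(module_name)
        return getattr(module, attr)
    except (ImportError, AttributeError):
        # ImportError: the module is not installed.
        # AttributeError: the module exists but blew up while loading.
        return None


with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "fake_broken_kernel.py"), "w") as f:
        f.write(BROKEN_SOURCE)
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        broken = load_optional("fake_broken_kernel", "gemm_kernel")
    finally:
        sys.path.remove(tmp)

absent = load_optional("no_such_module_for_demo", "gemm_kernel")
present = load_optional("math", "sqrt")
```

Both failure modes collapse to `None`, so callers can test for availability before taking the kernel path.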

Test Plan

@mergify mergify bot added the rocm Related to AMD ROCm label Oct 2, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses a crash on AMD platforms caused by an attempt to JIT-compile a CUDA-only Triton kernel. The fix, which involves catching the resulting AttributeError, is correct and effectively resolves the immediate issue. In my review, I've also identified a latent bug within the same try-except block that could lead to a NameError under different import failure scenarios. I've provided a detailed explanation of this issue and recommended a refactoring approach to improve the robustness of this platform-specific import handling.

Comment on lines +100 to 101
except (ImportError, AttributeError):
    dynamic_mxfp4_quant = gemm_afp4wfp4 = None

high

While adding AttributeError correctly fixes the crash on AMD, this try-except block has a latent bug. It is too broad, covering lines 26 to 98, but the except block only handles dynamic_mxfp4_quant and gemm_afp4wfp4.

If an ImportError or AttributeError occurs during the import of other symbols like shuffle_weight, or the conditional imports of gemm_a4w4 and per_1x32_f4_quant_hip, they will be undefined. This can lead to a NameError later when they are used, for example in process_weights_after_loading or gemm_with_dynamic_quant.

A more robust approach would be to split this into smaller, more focused try-except blocks for different sets of imports (e.g., common, ROCm-specific, CUDA-specific) and to ensure every imported symbol is handled in its respective except block. This would also require adding checks in QuarkW4A4MXFP4.__init__ to ensure the required kernels are available for the selected execution path.
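
The refactoring suggested in this comment could look roughly like the sketch below. It is an illustration under stated assumptions, not the actual vLLM change: `try_import`, `DemoScheme`, and the module name `hypothetical_rocm_kernels` are invented for the example, and `math.sqrt` stands in for a kernel that is always available.

```python
import importlib


def try_import(module_name, *names):
    """Import names from one module; map every name to None on failure."""
    try:
        module = importlib.import_module(module_name)
        return {name: getattr(module, name) for name in names}
    except (ImportError, AttributeError):
        # Every symbol in the group is defined, so later code sees None
        # instead of hitting a NameError.
        return {name: None for name in names}


# Each group is guarded separately, so one failing import does not
# leave symbols from another group undefined.
COMMON = try_import("math", "sqrt")
ROCM_ONLY = try_import("hypothetical_rocm_kernels", "gemm_a4w4")


class DemoScheme:
    """Hypothetical scheme that validates its kernels up front."""

    def __init__(self, use_rocm_path):
        required = ROCM_ONLY if use_rocm_path else COMMON
        missing = [name for name, obj in required.items() if obj is None]
        if missing:
            raise RuntimeError(f"required kernels unavailable: {missing}")
        self.kernels = required
```

Checking availability in the constructor moves the failure to setup time with an explicit message, rather than to a later NameError or a call on `None`.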

Member

@yewentao256 yewentao256 left a comment

LGTM, thanks for the fix!

@zhewenl zhewenl force-pushed the fix-quark_w4a4_mxfp4-gating branch from a168f7d to 9700c13 Compare October 2, 2025 17:31
@zhewenl zhewenl added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 2, 2025
Signed-off-by: zhewenli <zhewenli@meta.com>
@zhewenl zhewenl force-pushed the fix-quark_w4a4_mxfp4-gating branch from 9700c13 to 1fbd9ae Compare October 2, 2025 21:43
@zhewenl
Collaborator Author

zhewenl commented Oct 3, 2025

@yeqcharlotte yeqcharlotte merged commit 711f485 into vllm-project:main Oct 3, 2025
53 checks passed
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: zhewenli <zhewenli@meta.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@zhewenl zhewenl deleted the fix-quark_w4a4_mxfp4-gating branch October 4, 2025 20:43
tomeras91 pushed a commit to tomeras91/vllm that referenced this pull request Oct 6, 2025
Signed-off-by: zhewenli <zhewenli@meta.com>
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
karan pushed a commit to karan/vllm that referenced this pull request Oct 6, 2025
Signed-off-by: zhewenli <zhewenli@meta.com>
Signed-off-by: Karan Goel <3261985+karan@users.noreply.github.com>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
Signed-off-by: zhewenli <zhewenli@meta.com>
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: zhewenli <zhewenli@meta.com>
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
Signed-off-by: zhewenli <zhewenli@meta.com>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: zhewenli <zhewenli@meta.com>