2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -92,11 +92,13 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);

#ifndef USE_ROCM
void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
int64_t num_heads_k, int64_t num_heads_v,
int64_t head_dim, double eps, torch::Tensor& q_weight,
torch::Tensor& k_weight, torch::Tensor& cos_sin_cache,
bool is_neox, torch::Tensor& position_ids);
#endif
Comment on lines +95 to +101

P1: Guard removes declaration but binding still references op

The new #ifndef USE_ROCM guard in csrc/ops.h (lines 95-101) removes the declaration of fused_qk_norm_rope on ROCm builds, but csrc/torch_bindings.cpp still unconditionally registers the custom op at lines 178-184 via ops.impl("fused_qk_norm_rope", torch::kCUDA, &fused_qk_norm_rope);. When building with USE_ROCM defined, the compiler no longer sees any declaration of that symbol before it is used, so torch_bindings.cpp fails to compile on AMD/ROCm even though the function definition still exists in fused_qknorm_rope_kernel.cu. Either the declaration needs to remain available, or the binding needs to be wrapped in the same guard; otherwise every ROCm build is broken.

void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask,
5 changes: 4 additions & 1 deletion vllm/compilation/fix_functionalization.py
@@ -133,7 +133,10 @@ def __call__(self, graph: torch.fx.Graph):
),
)
# Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
elif at_target == torch.ops._C.fused_qk_norm_rope.default:
elif (
current_platform.is_cuda()
and at_target == torch.ops._C.fused_qk_norm_rope.default
):
mutated_args = {1: "qkv"}
args = (
"qkv",
2 changes: 1 addition & 1 deletion vllm/compilation/pass_manager.py
@@ -17,9 +17,9 @@
from .activation_quant_fusion import ActivationQuantFusionPass
from .fusion import RMSNormQuantFusionPass
from .fusion_attn import AttnFusionPass
from .qk_norm_rope_fusion import QKNormRoPEFusionPass

if current_platform.is_cuda():
from .qk_norm_rope_fusion import QKNormRoPEFusionPass
Comment on lines -21 to +22

Severity: high

While moving this import under is_cuda() is correct to fix the crash on ROCm, it introduces a potential inconsistency: the configuration validation for this fusion in vllm/config/compilation.py still uses is_cuda_alike().

This means a user on a ROCm platform could set enable_qk_norm_rope_fusion=True and pass the configuration check, only to hit a NameError here, since QKNormRoPEFusionPass would never have been imported.

To fix this, please also update the check in vllm/config/compilation.py (line 187) to use current_platform.is_cuda() instead of current_platform.is_cuda_alike().
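The point is that the platform gate used by the config validation should match the one guarding the import. A minimal sketch with a hypothetical `FakePlatform` stand-in (the real helpers are `current_platform.is_cuda()` and `is_cuda_alike()` in vllm):

```python
class FakePlatform:
    """Hypothetical stand-in for vllm's current_platform object."""

    def __init__(self, device_type: str):
        self.device_type = device_type

    def is_cuda(self) -> bool:
        return self.device_type == "cuda"

    def is_cuda_alike(self) -> bool:
        # True for both CUDA and ROCm -- the source of the mismatch:
        # ROCm passes this check but the fusion pass is only imported
        # behind is_cuda().
        return self.device_type in ("cuda", "rocm")


def validate_qk_norm_rope_fusion(platform: FakePlatform, enabled: bool) -> bool:
    # Using is_cuda() here mirrors the import guard in pass_manager.py,
    # so a ROCm user setting enable_qk_norm_rope_fusion=True is rejected
    # at config time instead of hitting a NameError later.
    return enabled and platform.is_cuda()


rocm = FakePlatform("rocm")
assert rocm.is_cuda_alike() and not rocm.is_cuda()
assert not validate_qk_norm_rope_fusion(rocm, enabled=True)

cuda = FakePlatform("cuda")
assert validate_qk_norm_rope_fusion(cuda, enabled=True)
```

With the `is_cuda_alike()` check in vllm/config/compilation.py changed to `is_cuda()`, the validation and the import guard reject the same set of platforms.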

from .collective_fusion import AllReduceFusionPass, AsyncTPPass

from .fix_functionalization import FixFunctionalizationPass