[Diffusion] safe fallback for fused QK-norm#16329

Closed
RubiaCx wants to merge 2 commits into sgl-project:main from RubiaCx:qknorm-fallback

Conversation

@RubiaCx
Collaborator

@RubiaCx RubiaCx commented Jan 3, 2026

Motivation

The fused QK-norm kernel used here was introduced in #16062, and this PR ensures it does not break on non-applicable layouts or ROCm (AMD CI) while keeping QK-norm behavior consistent when enabled.

Modifications

Remove _is_cuda() gating at call sites; always call apply_qk_norm when self.qk_norm is enabled.

In apply_qk_norm, only call the fused kernel on NVIDIA CUDA (q.is_cuda and torch.version.cuda is not None) and valid layouts; if it raises "QK-norm is not applicable", fall back to the RMSNorm path; re-raise other errors.
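The dispatch described above can be sketched as plain control flow. This is a hypothetical illustration, not the actual SGLang implementation: the real apply_qk_norm operates on torch tensors, and `fused_qk_norm`, `rmsnorm_fallback`, `on_nvidia_cuda`, and `layout_ok` are stand-in names for the kernel call, the RMSNorm path, and the gating conditions.

```python
def apply_qk_norm(q, k, *, on_nvidia_cuda, layout_ok,
                  fused_qk_norm, rmsnorm_fallback):
    """Try the fused kernel only when it can apply; otherwise use RMSNorm."""
    if on_nvidia_cuda and layout_ok:
        try:
            return fused_qk_norm(q, k)
        except RuntimeError as e:
            # Only swallow the "not applicable" signal; re-raise real errors
            # (e.g. out-of-memory) so they are not hidden by the fallback.
            if "QK-norm is not applicable" not in str(e):
                raise
    return rmsnorm_fallback(q, k)
```

The key property is that the fallback is taken both when the gating conditions fail up front and when the kernel itself reports it cannot handle the layout, while unrelated RuntimeErrors still propagate.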

Checklist

@RubiaCx RubiaCx requested a review from BBuf January 3, 2026 05:08
@RubiaCx RubiaCx added run-ci diffusion SGLang Diffusion labels Jan 3, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @RubiaCx, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the stability and compatibility of the fused QK-norm kernel by implementing a robust fallback mechanism. It ensures that the QK-norm operation functions reliably across various hardware and tensor configurations, specifically addressing issues with non-NVIDIA CUDA environments and non-applicable tensor layouts. The changes streamline the execution flow by conditionally applying the optimized fused kernel and reverting to standard PyTorch normalization when the fused path is not viable, thereby preventing runtime errors and maintaining consistent behavior. Additionally, it includes minor cleanups and optimizations in the Qwen image transformer model.

Highlights

  • Robustness for Fused QK-norm: Introduced a safe fallback mechanism for the fused QK-norm kernel, preventing crashes on non-NVIDIA CUDA devices (like ROCm) or with incompatible tensor layouts.
  • Consistent QK-norm Behavior: Ensures that the QK-norm functionality remains consistent across different environments and tensor configurations when enabled.
  • Conditional Fused Kernel Execution: The fused QK-norm kernel is now only invoked on NVIDIA CUDA devices with valid tensor layouts, utilizing a new can_view_as_bnhd helper function to check for compatible layouts.
  • Graceful Fallback: If the fused kernel encounters a 'QK-norm is not applicable' error, the system gracefully falls back to using standard PyTorch RMSNorm. Other errors are re-raised.
  • Simplified Call Site Logic: The _is_cuda() gating is removed from call sites, simplifying the logic for invoking apply_qk_norm.
  • QwenImageTransformer2DModel Refinements: Removed the to_hashable utility function and updated modulate_index calculation in the forward method for GPU efficiency, along with minor initialization changes for norm_q and norm_k.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a safe fallback for the fused QK-norm kernel, enhancing its robustness on different hardware and layouts. The changes correctly add a try-except block to fall back to a standard PyTorch RMSNorm implementation when the fused kernel is not applicable. Additionally, there are good optimizations in qwen_image.py that move computations from CPU to GPU. My review includes suggestions to improve maintainability by reducing code duplication and making the error handling more robust. I've also pointed out a small opportunity for code simplification in the fallback logic.

Comment on lines +445 to +459
def can_view_as_bnhd(x: torch.Tensor) -> bool:
    """Whether `x` can be viewed as [batch, *, head_dim] without a copy."""
    if (
        x.dim() < 2
        or x.size(0) != batch_size
        or x.size(-1) != head_dim
        or x.stride(-1) != 1
        or x.stride(-2) != head_dim
    ):
        return False
    try:
        x.view(batch_size, -1, head_dim)
        return True
    except RuntimeError:
        return False
Contributor


medium

The helper function can_view_as_bnhd is also defined in python/sglang/srt/models/utils.py. To avoid code duplication and improve maintainability, this function should be defined in a shared utility module and imported into both files.

Comment on lines +480 to +482
except RuntimeError as e:
    if "QK-norm is not applicable" not in str(e):
        raise
Contributor


medium

The fallback logic relies on checking for the substring "QK-norm is not applicable" in the RuntimeError message. This is fragile and could break if the error message from the underlying JIT kernel changes in the future. It would be more robust to catch a custom, specific exception type raised by the kernel. If that's not feasible, this is an accepted risk, but worth noting.
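The custom-exception approach the review suggests could look like the sketch below. `QKNormNotApplicableError` is a hypothetical name, not an existing SGLang class; the point is that a dedicated exception type gives the caller a structured signal instead of a string match on the message.

```python
class QKNormNotApplicableError(RuntimeError):
    """Raised by the fused kernel when the input layout is unsupported."""


def call_fused_or_fallback(run_fused, run_fallback):
    """Run the fused path, falling back only on the dedicated exception."""
    try:
        return run_fused()
    except QKNormNotApplicableError:
        # Structured signal: no fragile substring matching on the message.
        return run_fallback()
```

Because QKNormNotApplicableError subclasses RuntimeError, existing `except RuntimeError` handlers elsewhere would keep working unchanged.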

Comment on lines +488 to +493
if allow_inplace and q.is_contiguous() and q_out.shape == q.shape:
    q.copy_(q_out)
    q_out = q
if allow_inplace and k.is_contiguous() and k_out.shape == k.shape:
    k.copy_(k_out)
    k_out = k
Contributor


medium

The shape checks q_out.shape == q.shape and k_out.shape == k.shape are likely redundant. RMSNorm preserves the tensor shape, so the output shape of q_norm(q) and k_norm(k) should always be the same as their respective inputs. You could simplify this by removing these shape checks for clarity.

Suggested change

(current)
    if allow_inplace and q.is_contiguous() and q_out.shape == q.shape:
        q.copy_(q_out)
        q_out = q
    if allow_inplace and k.is_contiguous() and k_out.shape == k.shape:
        k.copy_(k_out)
        k_out = k
(suggested)
    if allow_inplace and q.is_contiguous():
        q.copy_(q_out)
        q_out = q
    if allow_inplace and k.is_contiguous():
        k.copy_(k_out)
        k_out = k
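The review's claim that RMSNorm preserves shape can be illustrated with a tiny pure-Python version on a flat vector. This is a toy stand-in for the real torch RMSNorm, just to show that the output has the same shape as the input by construction, so a `q_out.shape == q.shape` guard after the norm can never fire.

```python
import math


def rmsnorm(x, weight, eps=1e-6):
    """Toy RMSNorm on a flat vector: y_i = w_i * x_i / rms(x)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]


q = [1.0, 2.0, 3.0, 4.0]
out = rmsnorm(q, [1.0] * len(q))
assert len(out) == len(q)  # shape is preserved by construction
```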

Comment on lines +232 to +246
def can_view_as_bnhd(x: torch.Tensor) -> bool:
    """Whether `x` can be viewed as [batch, *, head_dim] without a copy."""
    if (
        x.dim() < 2
        or x.size(0) != batch_size
        or x.size(-1) != head_dim
        or x.stride(-1) != 1
        or x.stride(-2) != head_dim
    ):
        return False
    try:
        x.view(batch_size, -1, head_dim)
        return True
    except RuntimeError:
        return False
Contributor


medium

The helper function can_view_as_bnhd is duplicated from python/sglang/multimodal_gen/runtime/layers/layernorm.py. To avoid code duplication and improve maintainability, this function should be defined in a shared utility module and imported into both files.

Comment on lines +268 to +270
except RuntimeError as e:
    if "QK-norm is not applicable" not in str(e):
        raise
Contributor


medium

The fallback logic relies on checking for the substring "QK-norm is not applicable" in the RuntimeError message. This is fragile and could break if the error message from the underlying JIT kernel changes in the future. It would be more robust to catch a custom, specific exception type raised by the kernel. This same issue exists in python/sglang/multimodal_gen/runtime/layers/layernorm.py.

@RubiaCx
Collaborator Author

RubiaCx commented Jan 3, 2026

Already solved by #16287; closing.

@RubiaCx RubiaCx closed this Jan 3, 2026