-
Notifications
You must be signed in to change notification settings - Fork 5.1k
[Diffusion] safe fallback for fused QK-norm #16329
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -434,31 +434,62 @@ def apply_qk_norm( | |||||||||||||||||||||||||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | ||||||||||||||||||||||||||
| """Apply QK normalization for query and key tensors. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Minimal multimodal_gen-only implementation: only the JIT fused inplace | ||||||||||||||||||||||||||
| QK-norm kernel path is supported (no fallback). | ||||||||||||||||||||||||||
| Prefer the fused inplace QK-norm kernel when applicable; | ||||||||||||||||||||||||||
| otherwise fall back to PyTorch norms. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| batch_size = q.size(0) | ||||||||||||||||||||||||||
| q_eps = q_norm.variance_epsilon | ||||||||||||||||||||||||||
| k_eps = k_norm.variance_epsilon | ||||||||||||||||||||||||||
| # Only try fused path on CUDA and when it won't introduce implicit copies. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def can_view_as_bnhd(x: torch.Tensor) -> bool: | ||||||||||||||||||||||||||
| """Whether `x` can be viewed as [batch, *, head_dim] without a copy.""" | ||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||
| x.dim() < 2 | ||||||||||||||||||||||||||
| or x.size(0) != batch_size | ||||||||||||||||||||||||||
| or x.size(-1) != head_dim | ||||||||||||||||||||||||||
| or x.stride(-1) != 1 | ||||||||||||||||||||||||||
| or x.stride(-2) != head_dim | ||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| x.view(batch_size, -1, head_dim) | ||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||
| except RuntimeError: | ||||||||||||||||||||||||||
| return False | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||
| q.is_cuda | ||||||||||||||||||||||||||
| and (torch.version.cuda is not None) | ||||||||||||||||||||||||||
| and allow_inplace | ||||||||||||||||||||||||||
| and (q_eps == k_eps) | ||||||||||||||||||||||||||
| and can_view_as_bnhd(q) | ||||||||||||||||||||||||||
| and can_view_as_bnhd(k) | ||||||||||||||||||||||||||
| and can_use_fused_inplace_qknorm(head_dim) | ||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||
| fused_inplace_qknorm( | ||||||||||||||||||||||||||
| q=q.view(batch_size, -1, head_dim), | ||||||||||||||||||||||||||
| k=k.view(batch_size, -1, head_dim), | ||||||||||||||||||||||||||
| q_weight=q_norm.weight, | ||||||||||||||||||||||||||
| k_weight=k_norm.weight, | ||||||||||||||||||||||||||
| head_dim=head_dim, | ||||||||||||||||||||||||||
| eps=q_eps, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| return q, k | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||||||||
| "apply_qk_norm: fused inplace QK-norm is not applicable " | ||||||||||||||||||||||||||
| "(expected CUDA, contiguous q/k, matching eps, and supported head_dim)" | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| fused_inplace_qknorm( | ||||||||||||||||||||||||||
| q=q.view(batch_size, -1, head_dim), | ||||||||||||||||||||||||||
| k=k.view(batch_size, -1, head_dim), | ||||||||||||||||||||||||||
| q_weight=q_norm.weight, | ||||||||||||||||||||||||||
| k_weight=k_norm.weight, | ||||||||||||||||||||||||||
| head_dim=head_dim, | ||||||||||||||||||||||||||
| eps=q_eps, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| return q, k | ||||||||||||||||||||||||||
| except RuntimeError as e: | ||||||||||||||||||||||||||
| if "QK-norm is not applicable" not in str(e): | ||||||||||||||||||||||||||
| raise | ||||||||||||||||||||||||||
|
Comment on lines
+480
to
+482
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The fallback logic relies on checking for the substring |
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Fallback: apply the PyTorch norms. | ||||||||||||||||||||||||||
| q_out = q_norm(q) | ||||||||||||||||||||||||||
| k_out = k_norm(k) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if allow_inplace and q.is_contiguous() and q_out.shape == q.shape: | ||||||||||||||||||||||||||
| q.copy_(q_out) | ||||||||||||||||||||||||||
| q_out = q | ||||||||||||||||||||||||||
| if allow_inplace and k.is_contiguous() and k_out.shape == k.shape: | ||||||||||||||||||||||||||
| k.copy_(k_out) | ||||||||||||||||||||||||||
| k_out = k | ||||||||||||||||||||||||||
|
Comment on lines
+488
to
+493
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The shape checks
Suggested change
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| return q_out, k_out | ||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -228,22 +228,46 @@ def apply_qk_norm( | |
| batch_size = q.size(0) | ||
| q_eps = q_norm.variance_epsilon | ||
| k_eps = k_norm.variance_epsilon | ||
|
|
||
| def can_view_as_bnhd(x: torch.Tensor) -> bool: | ||
| """Whether `x` can be viewed as [batch, *, head_dim] without a copy.""" | ||
| if ( | ||
| x.dim() < 2 | ||
| or x.size(0) != batch_size | ||
| or x.size(-1) != head_dim | ||
| or x.stride(-1) != 1 | ||
| or x.stride(-2) != head_dim | ||
| ): | ||
| return False | ||
| try: | ||
| x.view(batch_size, -1, head_dim) | ||
| return True | ||
| except RuntimeError: | ||
| return False | ||
|
Comment on lines
+232
to
+246
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| if ( | ||
| _is_cuda # TODO(dark): have not tested on ROCm or other backends | ||
| and allow_inplace # TODO(dark): this can be relaxed if needed | ||
| and (q_eps == k_eps) # TODO(dark): this can also be relaxed | ||
| _is_cuda | ||
| and (torch.version.cuda is not None) # avoid NVIDIA-only kernel on ROCm | ||
| and allow_inplace | ||
| and (q_eps == k_eps) | ||
| and not envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get() | ||
| and can_view_as_bnhd(q) | ||
| and can_view_as_bnhd(k) | ||
| and can_use_fused_inplace_qknorm(head_dim) | ||
| ): | ||
| fused_inplace_qknorm( | ||
| q=q.view(batch_size, -1, head_dim), | ||
| k=k.view(batch_size, -1, head_dim), | ||
| q_weight=q_norm.weight, | ||
| k_weight=k_norm.weight, | ||
| head_dim=head_dim, | ||
| eps=q_eps, | ||
| ) | ||
| return q, k | ||
| try: | ||
| fused_inplace_qknorm( | ||
| q=q.view(batch_size, -1, head_dim), | ||
| k=k.view(batch_size, -1, head_dim), | ||
| q_weight=q_norm.weight, | ||
| k_weight=k_norm.weight, | ||
| head_dim=head_dim, | ||
| eps=q_eps, | ||
| ) | ||
| return q, k | ||
| except RuntimeError as e: | ||
| if "QK-norm is not applicable" not in str(e): | ||
| raise | ||
|
Comment on lines
+268
to
+270
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The fallback logic relies on checking for the substring |
||
|
|
||
| if alt_stream is not None and get_is_capture_mode(): | ||
| current_stream = torch.cuda.current_stream() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The helper function
can_view_as_bnhdis also defined inpython/sglang/srt/models/utils.py. To avoid code duplication and improve maintainability, this function should be defined in a shared utility module and imported into both files.