Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 48 additions & 17 deletions python/sglang/multimodal_gen/runtime/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,31 +434,62 @@ def apply_qk_norm(
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply QK normalization for query and key tensors.

Minimal multimodal_gen-only implementation: only the JIT fused inplace
QK-norm kernel path is supported (no fallback).
Prefer the fused inplace QK-norm kernel when applicable;
otherwise fall back to PyTorch norms.
"""

batch_size = q.size(0)
q_eps = q_norm.variance_epsilon
k_eps = k_norm.variance_epsilon
# Only try fused path on CUDA and when it won't introduce implicit copies.

def can_view_as_bnhd(x: torch.Tensor) -> bool:
    """Report whether `x` can be viewed as [batch, *, head_dim] without a copy.

    `batch_size` and `head_dim` are read from the enclosing scope.
    Returns True only when `x` has at least two dimensions, its leading
    size equals `batch_size`, its trailing size equals `head_dim`, the last
    two strides are (`head_dim`, 1), and `x.view(batch_size, -1, head_dim)`
    succeeds.
    """
    # Rank must be checked first: the stride(-2) lookup below needs >= 2 dims.
    if x.dim() < 2:
        return False

    layout_matches = (
        x.size(0) == batch_size
        and x.size(-1) == head_dim
        and x.stride(-1) == 1
        and x.stride(-2) == head_dim
    )
    if not layout_matches:
        return False

    # `view` raises RuntimeError when the reshape would require a copy.
    try:
        x.view(batch_size, -1, head_dim)
    except RuntimeError:
        return False
    return True
Comment on lines +445 to +459
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

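The helper function can_view_as_bnhd is also defined in python/sglang/srt/models/utils.py. To avoid code duplication and improve maintainability, define it once in a shared utility module and import it into both files. Note that it currently closes over batch_size and head_dim from the enclosing apply_qk_norm, so the shared version would need to take them as explicit parameters.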


if (
q.is_cuda
and (torch.version.cuda is not None)
and allow_inplace
and (q_eps == k_eps)
and can_view_as_bnhd(q)
and can_view_as_bnhd(k)
and can_use_fused_inplace_qknorm(head_dim)
):
fused_inplace_qknorm(
q=q.view(batch_size, -1, head_dim),
k=k.view(batch_size, -1, head_dim),
q_weight=q_norm.weight,
k_weight=k_norm.weight,
head_dim=head_dim,
eps=q_eps,
)
return q, k

raise RuntimeError(
"apply_qk_norm: fused inplace QK-norm is not applicable "
"(expected CUDA, contiguous q/k, matching eps, and supported head_dim)"
)
try:
fused_inplace_qknorm(
q=q.view(batch_size, -1, head_dim),
k=k.view(batch_size, -1, head_dim),
q_weight=q_norm.weight,
k_weight=k_norm.weight,
head_dim=head_dim,
eps=q_eps,
)
return q, k
except RuntimeError as e:
if "QK-norm is not applicable" not in str(e):
raise
Comment on lines +480 to +482
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The fallback logic relies on checking for the substring "QK-norm is not applicable" in the RuntimeError message. This is fragile and could break if the error message from the underlying JIT kernel changes in the future. It would be more robust to catch a custom, specific exception type raised by the kernel. If that's not feasible, this is an accepted risk, but worth noting.


# Fallback: apply the PyTorch norms.
q_out = q_norm(q)
k_out = k_norm(k)

if allow_inplace and q.is_contiguous() and q_out.shape == q.shape:
q.copy_(q_out)
q_out = q
if allow_inplace and k.is_contiguous() and k_out.shape == k.shape:
k.copy_(k_out)
k_out = k
Comment on lines +488 to +493
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The shape checks q_out.shape == q.shape and k_out.shape == k.shape are likely redundant. RMSNorm preserves the tensor shape, so the output shape of q_norm(q) and k_norm(k) should always be the same as their respective inputs. You could simplify this by removing these shape checks for clarity.

Suggested change
if allow_inplace and q.is_contiguous() and q_out.shape == q.shape:
q.copy_(q_out)
q_out = q
if allow_inplace and k.is_contiguous() and k_out.shape == k.shape:
k.copy_(k_out)
k_out = k
if allow_inplace and q.is_contiguous():
q.copy_(q_out)
q_out = q
if allow_inplace and k.is_contiguous():
k.copy_(k_out)
k_out = k


return q_out, k_out
23 changes: 14 additions & 9 deletions python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import functools
from math import prod
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -485,6 +486,9 @@ def __init__(
if self.qk_norm:
self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
else:
self.norm_q = None
self.norm_k = None

if added_kv_proj_dim is not None:
self.add_q_proj = ReplicatedLinear(
Expand Down Expand Up @@ -793,12 +797,6 @@ def forward(
return encoder_hidden_states, hidden_states


def to_hashable(obj):
    """Recursively convert lists and tuples in `obj` into plain tuples.

    Lists are unhashable, so a nested list structure cannot be used as a
    dictionary or cache key. Tuples are walked as well because a tuple that
    contains a list is itself unhashable; leaving it untouched would still
    fail when the result is hashed.

    Args:
        obj: Any object. Lists and tuples (including subclasses) are rebuilt
            as plain tuples with their elements converted recursively.

    Returns:
        A plain tuple for list/tuple input; any other object is returned
        unchanged.
    """
    if isinstance(obj, (list, tuple)):
        return tuple(to_hashable(item) for item in obj)
    return obj


class QwenImageTransformer2DModel(CachableDiT, OffloadableDiTMixin):
"""
The Transformer model introduced in Qwen.
Expand Down Expand Up @@ -948,9 +946,16 @@ def forward(
timestep = (timestep / 1000).to(hidden_states.dtype)

if self.zero_cond_t:
timestep = torch.cat([timestep, self.timestep_zero], dim=0)
device = timestep.device
modulate_index = self.build_modulate_index(to_hashable(img_shapes), device)
timestep = torch.cat([timestep, timestep * 0], dim=0)
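# Build one index row per sample as plain Python lists (0 for every element of the first shape, 1 for the remaining shapes), then create the tensor on timestep's device.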
modulate_index = torch.tensor(
[
[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]])
for sample in img_shapes
],
device=timestep.device,
dtype=torch.int,
)
else:
modulate_index = None

Expand Down
48 changes: 36 additions & 12 deletions python/sglang/srt/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,22 +228,46 @@ def apply_qk_norm(
batch_size = q.size(0)
q_eps = q_norm.variance_epsilon
k_eps = k_norm.variance_epsilon

def can_view_as_bnhd(x: torch.Tensor) -> bool:
    """Report whether `x` can be viewed as [batch, *, head_dim] without a copy.

    Uses `batch_size` and `head_dim` from the enclosing scope. The result is
    True only if `x` has two or more dimensions, `x.size(0)` equals
    `batch_size`, `x.size(-1)` equals `head_dim`, the innermost stride is 1,
    the second-to-last stride equals `head_dim`, and
    `x.view(batch_size, -1, head_dim)` does not raise.
    """
    # Guard the rank first so that stride(-2) is always a valid lookup.
    if x.dim() < 2:
        return False
    if x.size(0) != batch_size or x.size(-1) != head_dim:
        return False
    if x.stride(-1) != 1 or x.stride(-2) != head_dim:
        return False

    # A RuntimeError from `view` means the layout cannot be reinterpreted
    # in place.
    try:
        x.view(batch_size, -1, head_dim)
    except RuntimeError:
        return False
    else:
        return True
Comment on lines +232 to +246
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

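The helper function can_view_as_bnhd is duplicated from python/sglang/multimodal_gen/runtime/layers/layernorm.py. To avoid code duplication and improve maintainability, define it once in a shared utility module and import it into both files. Note that it currently closes over batch_size and head_dim from the enclosing apply_qk_norm, so the shared version would need to take them as explicit parameters.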


if (
_is_cuda # TODO(dark): have not tested on ROCm or other backends
and allow_inplace # TODO(dark): this can be relaxed if needed
and (q_eps == k_eps) # TODO(dark): this can also be relaxed
_is_cuda
and (torch.version.cuda is not None) # avoid NVIDIA-only kernel on ROCm
and allow_inplace
and (q_eps == k_eps)
and not envs.SGLANG_ENABLE_DETERMINISTIC_INFERENCE.get()
and can_view_as_bnhd(q)
and can_view_as_bnhd(k)
and can_use_fused_inplace_qknorm(head_dim)
):
fused_inplace_qknorm(
q=q.view(batch_size, -1, head_dim),
k=k.view(batch_size, -1, head_dim),
q_weight=q_norm.weight,
k_weight=k_norm.weight,
head_dim=head_dim,
eps=q_eps,
)
return q, k
try:
fused_inplace_qknorm(
q=q.view(batch_size, -1, head_dim),
k=k.view(batch_size, -1, head_dim),
q_weight=q_norm.weight,
k_weight=k_norm.weight,
head_dim=head_dim,
eps=q_eps,
)
return q, k
except RuntimeError as e:
if "QK-norm is not applicable" not in str(e):
raise
Comment on lines +268 to +270
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The fallback logic relies on checking for the substring "QK-norm is not applicable" in the RuntimeError message. This is fragile and could break if the error message from the underlying JIT kernel changes in the future. It would be more robust to catch a custom, specific exception type raised by the kernel. This same issue exists in python/sglang/multimodal_gen/runtime/layers/layernorm.py.


if alt_stream is not None and get_is_capture_mode():
current_stream = torch.cuda.current_stream()
Expand Down