diff --git a/python/sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py b/python/sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py
index 792c272d76fe..8f102fd73a9f 100644
--- a/python/sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py
+++ b/python/sglang/jit_kernel/diffusion/cutedsl/scale_residual_norm_scale_shift.py
@@ -257,11 +257,12 @@ def validate_scale_shift(t: torch.Tensor, B: int, S: int, D: int):
         (t.shape[0] not in (1, B)) or (t.shape[1] not in (1, S) or t.shape[2] != D)
     ):
         failed = True
-    elif t.ndim == 4 and (t.shape[0] != B or t.shape[2] != 1 or t.shape[3] != D):
+    elif t.ndim == 4:
         F = t.shape[1]
-        if S % F != 0:
+        if t.shape[0] != B or t.shape[2] != 1 or t.shape[3] != D:
+            failed = True
+        elif S % F != 0:
             raise ValueError(f"Validate failed: S({S}) must be divisible by F({F}).")
-        failed = True
     if failed:
         raise ValueError(f"Validate failed: unsupported tensor shape: {t.shape}.")
     if t.stride()[-1] != 1:
diff --git a/python/sglang/jit_kernel/diffusion/triton/npu_fallback.py b/python/sglang/jit_kernel/diffusion/triton/npu_fallback.py
index f01ea841d88a..c507a9b600f7 100644
--- a/python/sglang/jit_kernel/diffusion/triton/npu_fallback.py
+++ b/python/sglang/jit_kernel/diffusion/triton/npu_fallback.py
@@ -20,6 +20,9 @@ def fuse_scale_shift_native(
 def apply_rotary_embedding_native(
     x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
 ) -> torch.Tensor:
+    if interleaved and cos.shape[-1] == x.shape[-1]:
+        cos = cos[..., ::2]
+        sin = sin[..., ::2]
     cos = cos.unsqueeze(-2).to(x.dtype)
     sin = sin.unsqueeze(-2).to(x.dtype)
 
diff --git a/python/sglang/jit_kernel/diffusion/triton/torch_fallback.py b/python/sglang/jit_kernel/diffusion/triton/torch_fallback.py
index f4350103388a..de115747ba78 100644
--- a/python/sglang/jit_kernel/diffusion/triton/torch_fallback.py
+++ b/python/sglang/jit_kernel/diffusion/triton/torch_fallback.py
@@ -48,6 +48,9 @@ def apply_rotary_embedding_native(
     x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
 ) -> torch.Tensor:
     """Native fallback for rotary embedding (shared with NPU implementation)."""
+    if interleaved and cos.shape[-1] == x.shape[-1]:
+        cos = cos[..., ::2]
+        sin = sin[..., ::2]
     cos = cos.unsqueeze(-2).to(x.dtype)
     sin = sin.unsqueeze(-2).to(x.dtype)
     x1 = x[..., ::2]
diff --git a/python/sglang/jit_kernel/tests/diffusion/test_fused_norm_scale_shift.py b/python/sglang/jit_kernel/tests/diffusion/test_fused_norm_scale_shift.py
index 42c3371ff890..87cf70a5b7f8 100644
--- a/python/sglang/jit_kernel/tests/diffusion/test_fused_norm_scale_shift.py
+++ b/python/sglang/jit_kernel/tests/diffusion/test_fused_norm_scale_shift.py
@@ -9,6 +9,7 @@
 from sglang.jit_kernel.diffusion.cutedsl.scale_residual_norm_scale_shift import (
     fused_norm_scale_shift,
     fused_scale_residual_norm_scale_shift,
+    validate_scale_shift,
 )
 from sglang.test.ci.ci_register import register_cuda_ci
 
@@ -125,6 +126,16 @@ def _make_tensor(index_mode: str, shape: Tuple, dtype: torch.dtype):
     return torch.randn(*SHAPE_MAP[index_mode](*shape), device=DEVICE, dtype=dtype)
 
 
+def test_validate_scale_shift_rejects_non_divisible_frames():
+    with pytest.raises(ValueError, match=r"S\(10\) must be divisible by F\(4\)"):
+        validate_scale_shift(
+            torch.empty((1, 4, 1, 256), device=DEVICE, dtype=torch.float16),
+            1,
+            10,
+            256,
+        )
+
+
 @torch.no_grad()
 def run_norm_scale_shift(
     shape=SHAPES[0],