diff --git a/python/sglang/jit_kernel/diffusion/triton/npu_fallback.py b/python/sglang/jit_kernel/diffusion/triton/npu_fallback.py index 9cb65a9d0b2f..534ed47eac67 100644 --- a/python/sglang/jit_kernel/diffusion/triton/npu_fallback.py +++ b/python/sglang/jit_kernel/diffusion/triton/npu_fallback.py @@ -1,4 +1,8 @@ import torch +import torch_npu + +NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000 +NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896 # TODO: remove this when triton ascend bug is fixed @@ -18,6 +22,23 @@ def apply_rotary_embedding_native( ) -> torch.Tensor: cos = cos.unsqueeze(-2).to(x.dtype) sin = sin.unsqueeze(-2).to(x.dtype) + + if ( + cos.dim() == 3 + and x.dim() == 3 + and x.shape[1] < NPU_ROTARY_MUL_MAX_NUM_HEADS + and x.shape[2] < NPU_ROTARY_MUL_MAX_HEAD_SIZE + ): + if cos.size(-1) * 2 == x.size(-1): + cos = torch.cat([cos, cos], dim=-1) + sin = torch.cat([sin, sin], dim=-1) + cos = cos.unsqueeze(0) + sin = sin.unsqueeze(0) + x = x.unsqueeze(0) + x_embed = torch_npu.npu_rotary_mul(x, cos, sin) + x_embed = x_embed.squeeze(0) + return x_embed + x1 = x[..., ::2] x2 = x[..., 1::2] o1 = x1 * cos - x2 * sin diff --git a/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py b/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py index 21de5b9b37bb..f6f520690e85 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/mova_video_dit.py @@ -34,6 +34,7 @@ QuantizationConfig, ) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT +from sglang.multimodal_gen.runtime.platforms import current_platform from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -521,8 +522,12 @@ def _init_freqs(self): def patchify( self, x: torch.Tensor, control_camera_latents_input: torch.Tensor | None = None ): - # NOTE(dhyu): avoid slow_conv - x = 
x.contiguous(memory_format=torch.channels_last_3d) + if current_platform.is_npu(): + # torch.channels_last_3d is not supported on NPU + x = x.contiguous() + else: + # NOTE(dhyu): avoid slow_conv + x = x.contiguous(memory_format=torch.channels_last_3d) x = self.patch_embedding(x) grid_size = x.shape[2:] x = rearrange(x, "b c f h w -> b (f h w) c").contiguous() diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py index afd43238e20a..912291bcd1cc 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py @@ -69,6 +69,7 @@ from sglang.multimodal_gen.utils import PRECISION_TO_TYPE from sglang.srt.utils.common import get_compiler_backend +_is_npu = current_platform.is_npu() logger = init_logger(__name__) @@ -714,7 +715,14 @@ def inference_single_step( # Build visual freqs for full sequence visual_dit._init_freqs() - visual_freqs = tuple(freq.to(visual_x.device) for freq in visual_dit.freqs) + if _is_npu: + # TODO: remove this when torch.complex128 is supported for torch.cat on NPU + visual_freqs = tuple( + freq.to(device=visual_x.device, dtype=torch.complex64) + for freq in visual_dit.freqs + ) + else: + visual_freqs = tuple(freq.to(visual_x.device) for freq in visual_dit.freqs) visual_freqs = ( torch.cat( [ @@ -734,18 +742,24 @@ # Build audio freqs for full sequence self.audio_dit._init_freqs() - audio_freqs = ( - torch.cat( - [ - self.audio_dit.freqs[0][:f].view(f, -1).expand(f, -1), - self.audio_dit.freqs[1][:f].view(f, -1).expand(f, -1), - self.audio_dit.freqs[2][:f].view(f, -1).expand(f, -1), - ], - dim=-1, + if _is_npu: + # TODO: remove this when torch.complex128 is supported for torch.cat on NPU + audio_freqs = tuple( + freq.to(device=audio_x.device, 
dtype=torch.complex64) + for freq in self.audio_dit.freqs ) - .reshape(full_audio_seq_len, 1, -1) - .to(audio_x.device) - ) + else: + audio_freqs = tuple( + freq.to(audio_x.device) for freq in self.audio_dit.freqs + ) + audio_freqs = torch.cat( + [ + audio_freqs[0][:f].view(f, -1).expand(f, -1), + audio_freqs[1][:f].view(f, -1).expand(f, -1), + audio_freqs[2][:f].view(f, -1).expand(f, -1), + ], + dim=-1, + ).reshape(full_audio_seq_len, 1, -1) # Shard sequences for SP visual_x, visual_pad_len = self._shard_sequence_for_sp(visual_x, dim=1)