Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions python/sglang/jit_kernel/diffusion/triton/npu_fallback.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import torch
import torch_npu

NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896


# TODO: remove this when triton ascend bug is fixed
Expand All @@ -18,6 +22,23 @@ def apply_rotary_embedding_native(
) -> torch.Tensor:
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)

if (
cos.dim() == 3
and x.dim() == 3
and x.shape[1] < NPU_ROTARY_MUL_MAX_NUM_HEADS
and x.shape[2] < NPU_ROTARY_MUL_MAX_HEAD_SIZE
):
if cos.size(-1) * 2 == x.size(-1):
cos = torch.cat([cos, cos], dim=-1)
sin = torch.cat([sin, sin], dim=-1)
cos = cos.unsqueeze(0)
sin = sin.unsqueeze(0)
x = x.unsqueeze(0)
x_embed = torch_npu.npu_rotary_mul(x, cos, sin)
x_embed = x_embed.squeeze(0)
return x_embed

x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
QuantizationConfig,
)
from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
from sglang.multimodal_gen.runtime.platforms import current_platform
from sglang.multimodal_gen.runtime.utils.layerwise_offload import OffloadableDiTMixin
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

Expand Down Expand Up @@ -521,8 +522,12 @@ def _init_freqs(self):
def patchify(
self, x: torch.Tensor, control_camera_latents_input: torch.Tensor | None = None
):
# NOTE(dhyu): avoid slow_conv
x = x.contiguous(memory_format=torch.channels_last_3d)
if current_platform.is_npu:
# torch.channels_last_3d is not supported on NPU
x = x.contiguous()
else:
# NOTE(dhyu): avoid slow_conv
x = x.contiguous(memory_format=torch.channels_last_3d)
x = self.patch_embedding(x)
grid_size = x.shape[2:]
x = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from sglang.multimodal_gen.utils import PRECISION_TO_TYPE
from sglang.srt.utils.common import get_compiler_backend

_is_npu = current_platform.is_npu()
logger = init_logger(__name__)


Expand Down Expand Up @@ -714,7 +715,14 @@ def inference_single_step(

# Build visual freqs for full sequence
visual_dit._init_freqs()
visual_freqs = tuple(freq.to(visual_x.device) for freq in visual_dit.freqs)
if _is_npu:
# TODO: remove this when torch.complex128 is supported for torch.cat on NPU
visual_freqs = tuple(
freq.to(device=visual_x.device, dtype=torch.complex64)
for freq in visual_dit.freqs
)
else:
visual_freqs = tuple(freq.to(visual_x.device) for freq in visual_dit.freqs)
visual_freqs = (
torch.cat(
[
Expand All @@ -734,18 +742,24 @@ def inference_single_step(

# Build audio freqs for full sequence
self.audio_dit._init_freqs()
audio_freqs = (
torch.cat(
[
self.audio_dit.freqs[0][:f].view(f, -1).expand(f, -1),
self.audio_dit.freqs[1][:f].view(f, -1).expand(f, -1),
self.audio_dit.freqs[2][:f].view(f, -1).expand(f, -1),
],
dim=-1,
if _is_npu:
# TODO: remove this when torch.complex128 is supported for torch.cat on NPU
audio_freqs = tuple(
freq.to(device=audio_x.device, dtype=torch.complex64)
for freq in self.audio_dit.freqs
)
.reshape(full_audio_seq_len, 1, -1)
.to(audio_x.device)
)
else:
audio_freqs = tuple(
freq.to(audio_x.device) for freq in self.audio_dit.freqs
)
audio_freqs = torch.cat(
[
audio_freqs[0][:f].view(f, -1).expand(f, -1),
audio_freqs[1][:f].view(f, -1).expand(f, -1),
audio_freqs[2][:f].view(f, -1).expand(f, -1),
],
dim=-1,
).reshape(full_audio_seq_len, 1, -1)

# Shard sequences for SP
visual_x, visual_pad_len = self._shard_sequence_for_sp(visual_x, dim=1)
Expand Down
Loading