diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux.py b/python/sglang/multimodal_gen/runtime/models/dits/flux.py index 3efb8623b78b..40089bc10d28 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/flux.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/flux.py @@ -432,7 +432,8 @@ def __init__(self, theta: int, axes_dim: List[int]): def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: pos = ids.float() - # freqs_cos, freqs_sin = self.rope.forward(positions=pos) + # TODO: potential error: flux use n_axes = ids.shape[-1] + # see: https://github.com/huggingface/diffusers/blob/17c0e79dbdf53fb6705e9c09cc1a854b84c39249/src/diffusers/models/transformers/transformer_flux.py#L509 freqs_cos, freqs_sin = self.rope.forward_uncached(pos=pos) return freqs_cos.contiguous().float(), freqs_sin.contiguous().float() diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py index bbb9d9cc9700..4e83e0cb0d79 100644 --- a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py +++ b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py @@ -12,28 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn from diffusers.models.attention import AttentionModuleMixin -from diffusers.models.embeddings import ( - TimestepEmbedding, - Timesteps, - get_1d_rotary_pos_embed, -) +from diffusers.models.embeddings import TimestepEmbedding, Timesteps from diffusers.models.normalization import AdaLayerNormContinuous from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig from sglang.multimodal_gen.runtime.layers.attention import USPAttention from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm -from sglang.multimodal_gen.runtime.layers.rotary_embedding import _apply_rotary_emb +from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( + NDRotaryEmbedding, + _apply_rotary_emb, +) from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT from sglang.multimodal_gen.runtime.models.dits.utils import ( delete_projection_layers, fuse_linear_projections, ) -from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.platforms import ( + AttentionBackendEnum, + current_platform, +) from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger logger = init_logger(__name__) # pylint: disable=invalid-name @@ -627,35 +629,22 @@ def forward( class Flux2PosEmbed(nn.Module): - # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 - def __init__(self, theta: int, axes_dim: list[int]): + def __init__(self, theta: int, axes_dim: List[int]): super().__init__() - self.theta = theta - self.axes_dim = axes_dim + self.rope = NDRotaryEmbedding( + rope_dim_list=axes_dim, + rope_theta=theta, + use_real=False, + repeat_interleave_real=False, + dtype=torch.float32 if current_platform.is_mps() else torch.float64, + ) - def forward(self, ids: torch.Tensor) -> torch.Tensor: - # Expected ids shape: [S, len(self.axes_dim)] - cos_out = [] - sin_out = [] + def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: pos = ids.float() - is_mps = ids.device.type == "mps" - is_npu = ids.device.type == "npu" - freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 - # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1] - for i in range(len(self.axes_dim)): - cos, sin = get_1d_rotary_pos_embed( - self.axes_dim[i], - pos[..., i], - theta=self.theta, - repeat_interleave_real=True, - use_real=True, - freqs_dtype=freqs_dtype, - ) - cos_out.append(cos) - sin_out.append(sin) - freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) - freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) - return freqs_cos, freqs_sin + # TODO: potential error: flux use n_axes = ids.shape[-1] + # see: https://github.com/huggingface/diffusers/blob/17c0e79dbdf53fb6705e9c09cc1a854b84c39249/src/diffusers/models/transformers/transformer_flux.py#L509 + freqs_cos, freqs_sin = self.rope.forward_uncached(pos=pos) + return freqs_cos.contiguous().float(), freqs_sin.contiguous().float() class Flux2Transformer2DModel(CachableDiT):