Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/sglang/multimodal_gen/runtime/models/dits/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ def __init__(self, theta: int, axes_dim: List[int]):

def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
pos = ids.float()
# freqs_cos, freqs_sin = self.rope.forward(positions=pos)
# TODO: potential mismatch: the diffusers Flux reference implementation iterates over n_axes = ids.shape[-1]
# see: https://github.com/huggingface/diffusers/blob/17c0e79dbdf53fb6705e9c09cc1a854b84c39249/src/diffusers/models/transformers/transformer_flux.py#L509
freqs_cos, freqs_sin = self.rope.forward_uncached(pos=pos)
return freqs_cos.contiguous().float(), freqs_sin.contiguous().float()

Expand Down
57 changes: 23 additions & 34 deletions python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.embeddings import (
TimestepEmbedding,
Timesteps,
get_1d_rotary_pos_embed,
)
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.normalization import AdaLayerNormContinuous

from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig
from sglang.multimodal_gen.runtime.layers.attention import USPAttention
from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
from sglang.multimodal_gen.runtime.layers.rotary_embedding import _apply_rotary_emb
from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
NDRotaryEmbedding,
_apply_rotary_emb,
)
from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
from sglang.multimodal_gen.runtime.models.dits.utils import (
delete_projection_layers,
fuse_linear_projections,
)
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
from sglang.multimodal_gen.runtime.platforms import (
AttentionBackendEnum,
current_platform,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -627,35 +629,22 @@ def forward(


class Flux2PosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: list[int]):
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
self.rope = NDRotaryEmbedding(
rope_dim_list=axes_dim,
rope_theta=theta,
use_real=False,
repeat_interleave_real=False,
dtype=torch.float32 if current_platform.is_mps() else torch.float64,
)

def forward(self, ids: torch.Tensor) -> torch.Tensor:
# Expected ids shape: [S, len(self.axes_dim)]
cos_out = []
sin_out = []
def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
pos = ids.float()
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
# Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
for i in range(len(self.axes_dim)):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[..., i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
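# TODO: potential mismatch: the diffusers Flux (v1) reference implementation iterates over n_axes = ids.shape[-1],
# whereas the Flux 2 code replaced here looped over len(self.axes_dim); verify that forward_uncached matches.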
# see: https://github.com/huggingface/diffusers/blob/17c0e79dbdf53fb6705e9c09cc1a854b84c39249/src/diffusers/models/transformers/transformer_flux.py#L509
freqs_cos, freqs_sin = self.rope.forward_uncached(pos=pos)
return freqs_cos.contiguous().float(), freqs_sin.contiguous().float()


class Flux2Transformer2DModel(CachableDiT):
Expand Down
Loading