19 changes: 1 addition & 18 deletions python/sglang/multimodal_gen/configs/models/dits/flux.py
@@ -21,27 +21,10 @@ class FluxArchConfig(DiTArchConfig):
guidance_embeds: bool = False
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

stacked_params_mapping: list[tuple[str, str, str]] = field(
default_factory=lambda: [
# (param_name, shard_name, shard_id)
(".to_qkv", ".to_q", "q"),
(".to_qkv", ".to_k", "k"),
(".to_qkv", ".to_v", "v"),
(".to_added_qkv", ".add_q_proj", "q"),
(".to_added_qkv", ".add_k_proj", "k"),
(".to_added_qkv", ".add_v_proj", "v"),
]
)
stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=list)

param_names_mapping: dict = field(
default_factory=lambda: {
# QKV fusion mappings
r"(.*)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
r"(.*)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
r"(.*)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
r"(.*)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
r"(.*)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
r"(.*)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
r"transformer\.(\w*)\.(.*)$": r"\1.\2",
}
)
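With the QKV fusion entries removed, the only rule left in FluxArchConfig's `param_names_mapping` strips the `transformer.` prefix from checkpoint keys. Below is a minimal sketch of how such a regex-keyed mapping is typically consumed during weight loading; the `remap_name` helper is hypothetical rather than the repo's actual loader, and the tuple branch only illustrates how the removed `(target, index, total)` entries would have been interpreted.

```python
import re

# Only the prefix-strip rule remains after this change.
param_names_mapping = {
    r"transformer\.(\w*)\.(.*)$": r"\1.\2",
}

def remap_name(name: str):
    """Hypothetical helper: rename a checkpoint key via the regex mapping."""
    for pattern, target in param_names_mapping.items():
        if re.match(pattern, name):
            if isinstance(target, tuple):
                # A (new_name, shard_index, num_shards) entry would mark one
                # shard of a fused parameter; no such entries remain here.
                new_name, shard_index, num_shards = target
                return re.sub(pattern, new_name, name), (shard_index, num_shards)
            return re.sub(pattern, target, name), None
    return name, None

print(remap_name("transformer.single_transformer_blocks.0.attn.to_q.weight"))
# ('single_transformer_blocks.0.attn.to_q.weight', None)
```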
6 changes: 0 additions & 6 deletions python/sglang/multimodal_gen/configs/models/dits/zimage.py
@@ -29,19 +29,13 @@ class ZImageArchConfig(DiTArchConfig):
stacked_params_mapping: list[tuple[str, str, str]] = field(
default_factory=lambda: [
# (param_name, shard_name, shard_id)
(".to_qkv", ".to_q", "q"),
(".to_qkv", ".to_k", "k"),
(".to_qkv", ".to_v", "v"),
(".feed_forward.w13", ".feed_forward.w1", "gate"),
(".feed_forward.w13", ".feed_forward.w3", "up"),
]
)

param_names_mapping: dict = field(
default_factory=lambda: {
r"(.*)\.to_q\.weight$": (r"\1.to_qkv.weight", 0, 3),
r"(.*)\.to_k\.weight$": (r"\1.to_qkv.weight", 1, 3),
r"(.*)\.to_v\.weight$": (r"\1.to_qkv.weight", 2, 3),
r"(.*)\.feed_forward\.w1\.weight$": (r"\1.feed_forward.w13.weight", 0, 2),
r"(.*)\.feed_forward\.w3\.weight$": (r"\1.feed_forward.w13.weight", 1, 2),
}
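Z-Image keeps the `w1`/`w3` to `w13` feed-forward fusion while dropping the QKV fusion. A minimal sketch of what the remaining `(target, index, total)` entries describe, assuming the fused tensor is split into equal chunks along its output dimension; `load_fused_shard` is an illustrative stand-in for the merged layer's weight loader, not code from the repo.

```python
import torch

def load_fused_shard(fused, shard, index, total):
    # Copy `shard` into chunk `index` of `total` equal chunks along dim 0
    # (0 -> gate/w1, 1 -> up/w3 for the fused w13 projection).
    chunk = fused.shape[0] // total
    fused[index * chunk:(index + 1) * chunk].copy_(shard)

dim, hidden = 64, 256
w13 = torch.empty(2 * hidden, dim)            # fused gate+up weight
w1 = torch.randn(hidden, dim)                 # checkpoint feed_forward.w1
w3 = torch.randn(hidden, dim)                 # checkpoint feed_forward.w3
load_fused_shard(w13, w1, index=0, total=2)   # mapping entry (..., 0, 2)
load_fused_shard(w13, w3, index=1, total=2)   # mapping entry (..., 1, 2)
assert torch.equal(w13[:hidden], w1) and torch.equal(w13[hidden:], w3)
```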
36 changes: 28 additions & 8 deletions python/sglang/multimodal_gen/runtime/models/dits/flux.py
@@ -52,13 +52,15 @@
def _get_qkv_projections(
attn: "FluxAttention", hidden_states, encoder_hidden_states=None
):
qkv, _ = attn.to_qkv(hidden_states)
query, key, value = qkv.chunk(3, dim=-1)
query, _ = attn.to_q(hidden_states)
key, _ = attn.to_k(hidden_states)
value, _ = attn.to_v(hidden_states)

encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
added_qkv, _ = attn.to_added_qkv(encoder_hidden_states)
encoder_query, encoder_key, encoder_value = added_qkv.chunk(3, dim=-1)
encoder_query, _ = attn.add_q_proj(encoder_hidden_states)
encoder_key, _ = attn.add_k_proj(encoder_hidden_states)
encoder_value, _ = attn.add_v_proj(encoder_hidden_states)

return query, key, value, encoder_query, encoder_key, encoder_value

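The unfused path stays numerically equivalent to the old fused one as long as the fused weight was the concatenation of the per-projection weights. A standalone sanity sketch using plain `nn.Linear` stand-ins (not the repo's ColumnParallelLinear):

```python
import torch
import torch.nn as nn

dim, inner = 32, 48
to_qkv = nn.Linear(dim, inner * 3, bias=False)          # old fused layer
to_q, to_k, to_v = (nn.Linear(dim, inner, bias=False) for _ in range(3))

# Carve the fused weight into the three separate projections.
wq, wk, wv = to_qkv.weight.chunk(3, dim=0)
to_q.weight.data, to_k.weight.data, to_v.weight.data = wq, wk, wv

x = torch.randn(2, 7, dim)
q_fused, k_fused, v_fused = to_qkv(x).chunk(3, dim=-1)   # old path
assert torch.allclose(q_fused, to_q(x))                  # new path matches
assert torch.allclose(k_fused, to_k(x))
assert torch.allclose(v_fused, to_v(x))
```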
@@ -96,8 +98,14 @@ def __init__(
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)

self.to_qkv = ColumnParallelLinear(
query_dim, self.inner_dim * 3, bias=bias, gather_output=True
self.to_q = ColumnParallelLinear(
query_dim, self.inner_dim, bias=bias, gather_output=True
)
self.to_k = ColumnParallelLinear(
query_dim, self.inner_dim, bias=bias, gather_output=True
)
self.to_v = ColumnParallelLinear(
query_dim, self.inner_dim, bias=bias, gather_output=True
)

if not self.pre_only:
@@ -113,9 +121,21 @@ def __init__(
if added_kv_proj_dim is not None:
self.norm_added_q = RMSNorm(dim_head, eps=eps)
self.norm_added_k = RMSNorm(dim_head, eps=eps)
self.to_added_qkv = ColumnParallelLinear(
self.add_q_proj = ColumnParallelLinear(
added_kv_proj_dim,
self.inner_dim,
bias=added_proj_bias,
gather_output=True,
)
self.add_k_proj = ColumnParallelLinear(
added_kv_proj_dim,
self.inner_dim,
bias=added_proj_bias,
gather_output=True,
)
self.add_v_proj = ColumnParallelLinear(
added_kv_proj_dim,
self.inner_dim * 3,
self.inner_dim,
bias=added_proj_bias,
gather_output=True,
)
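Each of the three separate projections is still a `ColumnParallelLinear` with `gather_output=True`, so under tensor parallelism every rank computes a slice of the output features and the slices are gathered back to the full `inner_dim`. A simplified single-process sketch of that behavior (an assumption about the layer's semantics, not its implementation):

```python
import torch

def column_parallel_forward(x, weight, tp_size=2, gather_output=True):
    # weight: (out_features, in_features); shard the out_features dim,
    # do the local matmuls, then gather (concatenate) the partial outputs.
    shards = weight.chunk(tp_size, dim=0)
    partials = [x @ w.t() for w in shards]
    return torch.cat(partials, dim=-1) if gather_output else partials

x = torch.randn(4, 64)
w_q = torch.randn(128, 64)                    # one unfused projection weight
gathered = column_parallel_forward(x, w_q)
assert torch.allclose(gathered, x @ w_q.t(), atol=1e-5)
```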
26 changes: 17 additions & 9 deletions python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
@@ -38,13 +38,15 @@
def _get_qkv_projections(
attn: "Flux2Attention", hidden_states, encoder_hidden_states=None
):
qkv, _ = attn.to_qkv(hidden_states)
query, key, value = qkv.chunk(3, dim=-1)
query, _ = attn.to_q(hidden_states)
key, _ = attn.to_k(hidden_states)
value, _ = attn.to_v(hidden_states)

encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
added_qkv, _ = attn.to_added_qkv(encoder_hidden_states)
encoder_query, encoder_key, encoder_value = added_qkv.chunk(3, dim=-1)
encoder_query, _ = attn.add_q_proj(encoder_hidden_states)
encoder_key, _ = attn.add_k_proj(encoder_hidden_states)
encoder_value, _ = attn.add_v_proj(encoder_hidden_states)

return query, key, value, encoder_query, encoder_key, encoder_value

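The consistent `out, _ = layer(x)` unpacking suggests these linear layers return an `(output, optional_bias)` tuple, as vLLM-style parallel linear layers commonly do; a toy stand-in for illustration (assumed API, not the repo's `ReplicatedLinear`):

```python
import torch
import torch.nn as nn

class ReplicatedLinearLike(nn.Module):
    """Toy stand-in: a full, non-sharded linear that mimics the tuple API."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        # The second slot (a skipped bias) is unused here, which is why
        # call sites unpack it into `_`.
        return self.linear(x), None

layer = ReplicatedLinearLike(64, 64, bias=False)
y, _ = layer(torch.randn(2, 64))
```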
@@ -120,8 +122,9 @@ def __init__(
self.added_kv_proj_dim = added_kv_proj_dim
self.added_proj_bias = added_proj_bias

# Use ReplicatedLinear for fused QKV projections
self.to_qkv = ReplicatedLinear(query_dim, self.inner_dim * 3, bias=bias)
self.to_q = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)
self.to_k = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)
self.to_v = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)

# QK Norm
self.norm_q = RMSNorm(dim_head, eps=eps)
@@ -134,9 +137,14 @@ def __init__(
if added_kv_proj_dim is not None:
self.norm_added_q = RMSNorm(dim_head, eps=eps)
self.norm_added_k = RMSNorm(dim_head, eps=eps)
# Use ReplicatedLinear for added (encoder) QKV projections
self.to_added_qkv = ReplicatedLinear(
added_kv_proj_dim, self.inner_dim * 3, bias=added_proj_bias
self.add_q_proj = ReplicatedLinear(
added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
)
self.add_k_proj = ReplicatedLinear(
added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
)
self.add_v_proj = ReplicatedLinear(
added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

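A practical consequence of unfusing: attribute names like `add_q_proj` now match typical diffusers-style checkpoint keys one-to-one, so loading reduces to a plain name match with no fusion-time remapping. Illustrative sketch only (hypothetical `TinyAttn`, not the repo's loader):

```python
import torch
import torch.nn as nn

class TinyAttn(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.add_q_proj = nn.Linear(dim, dim, bias=False)
        self.add_k_proj = nn.Linear(dim, dim, bias=False)
        self.add_v_proj = nn.Linear(dim, dim, bias=False)

# Checkpoint keys already line up with the module attributes.
checkpoint = {f"add_{n}_proj.weight": torch.randn(8, 8) for n in "qkv"}
TinyAttn().load_state_dict(checkpoint)
```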
29 changes: 12 additions & 17 deletions python/sglang/multimodal_gen/runtime/models/dits/zimage.py
@@ -11,7 +11,6 @@
from sglang.multimodal_gen.runtime.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
@@ -110,16 +109,13 @@ def __init__(
super().__init__()
self.dim = dim
self.head_dim = dim // num_heads
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.qk_norm = qk_norm

# Use QKVParallelLinear for QKV projection (fused)
self.to_qkv = QKVParallelLinear(
hidden_size=dim,
head_size=self.head_dim,
total_num_heads=num_heads,
total_num_kv_heads=num_kv_heads,
bias=False,
)
self.to_q = ReplicatedLinear(dim, dim, bias=False)
self.to_k = ReplicatedLinear(dim, self.head_dim * num_kv_heads, bias=False)
self.to_v = ReplicatedLinear(dim, self.head_dim * num_kv_heads, bias=False)

if self.qk_norm:
self.norm_q = RMSNorm(self.head_dim, eps=eps)
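For the grouped-query layout above, `to_q` keeps the full width while `to_k`/`to_v` are sized by `num_kv_heads`. A quick shape check with assumed example sizes (not taken from the actual Z-Image config):

```python
dim, num_heads, num_kv_heads = 3072, 24, 8      # assumed example values
head_dim = dim // num_heads                     # 128
to_q_out = dim                                  # to_q: dim -> num_heads * head_dim
to_kv_out = head_dim * num_kv_heads             # to_k / to_v: dim -> 1024
assert to_q_out == num_heads * head_dim and to_kv_out < to_q_out
```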
@@ -146,14 +142,13 @@ def forward(
hidden_states: torch.Tensor,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
qkv, _ = self.to_qkv(hidden_states)
q_dim = self.to_qkv.num_heads * self.head_dim
kv_dim = self.to_qkv.num_kv_heads * self.head_dim
q, k, v = torch.split(qkv, [q_dim, kv_dim, kv_dim], dim=-1)

q = q.view(*q.shape[:-1], self.to_qkv.num_heads, self.head_dim)
k = k.view(*k.shape[:-1], self.to_qkv.num_kv_heads, self.head_dim)
v = v.view(*v.shape[:-1], self.to_qkv.num_kv_heads, self.head_dim)
q, _ = self.to_q(hidden_states)
k, _ = self.to_k(hidden_states)
v, _ = self.to_v(hidden_states)

q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
k = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
v = v.view(*v.shape[:-1], self.num_kv_heads, self.head_dim)

if self.norm_q is not None:
q = self.norm_q(q)
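The hunk ends before the attention call, so the following only illustrates the usual pattern after such a reshape when `num_kv_heads < num_heads` (grouped-query attention): key/value heads are repeated to match the query heads before scaled dot-product attention. A self-contained sketch, not the repo's code:

```python
import torch
import torch.nn.functional as F

bsz, seq, num_heads, num_kv_heads, head_dim = 2, 16, 24, 8, 128
q = torch.randn(bsz, seq, num_heads, head_dim)
k = torch.randn(bsz, seq, num_kv_heads, head_dim)
v = torch.randn(bsz, seq, num_kv_heads, head_dim)

# Move heads to dim 1 and repeat k/v so every query head has a kv head.
q, k, v = (t.transpose(1, 2) for t in (q, k, v))
k = k.repeat_interleave(num_heads // num_kv_heads, dim=1)
v = v.repeat_interleave(num_heads // num_kv_heads, dim=1)

out = F.scaled_dot_product_attention(q, k, v)   # (bsz, num_heads, seq, head_dim)
```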