diff --git a/python/sglang/multimodal_gen/configs/models/dits/flux.py b/python/sglang/multimodal_gen/configs/models/dits/flux.py
index ee5a9867a7dd..fde2eddccefd 100644
--- a/python/sglang/multimodal_gen/configs/models/dits/flux.py
+++ b/python/sglang/multimodal_gen/configs/models/dits/flux.py
@@ -21,27 +21,10 @@ class FluxArchConfig(DiTArchConfig):
     guidance_embeds: bool = False
     axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)
 
-    stacked_params_mapping: list[tuple[str, str, str]] = field(
-        default_factory=lambda: [
-            # (param_name, shard_name, shard_id)
-            (".to_qkv", ".to_q", "q"),
-            (".to_qkv", ".to_k", "k"),
-            (".to_qkv", ".to_v", "v"),
-            (".to_added_qkv", ".add_q_proj", "q"),
-            (".to_added_qkv", ".add_k_proj", "k"),
-            (".to_added_qkv", ".add_v_proj", "v"),
-        ]
-    )
+    stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=list)
 
     param_names_mapping: dict = field(
         default_factory=lambda: {
-            # QKV fusion mappings
-            r"(.*)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
-            r"(.*)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
-            r"(.*)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
-            r"(.*)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
-            r"(.*)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
-            r"(.*)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
             r"transformer\.(\w*)\.(.*)$": r"\1.\2",
         }
     )
diff --git a/python/sglang/multimodal_gen/configs/models/dits/zimage.py b/python/sglang/multimodal_gen/configs/models/dits/zimage.py
index fc0b178307a8..de32758526fc 100644
--- a/python/sglang/multimodal_gen/configs/models/dits/zimage.py
+++ b/python/sglang/multimodal_gen/configs/models/dits/zimage.py
@@ -29,9 +29,6 @@ class ZImageArchConfig(DiTArchConfig):
     stacked_params_mapping: list[tuple[str, str, str]] = field(
         default_factory=lambda: [
             # (param_name, shard_name, shard_id)
-            (".to_qkv", ".to_q", "q"),
-            (".to_qkv", ".to_k", "k"),
-            (".to_qkv", ".to_v", "v"),
             (".feed_forward.w13", ".feed_forward.w1", "gate"),
             (".feed_forward.w13", ".feed_forward.w3", "up"),
         ]
     )
@@ -39,9 +36,6 @@ class ZImageArchConfig(DiTArchConfig):
     param_names_mapping: dict = field(
         default_factory=lambda: {
-            r"(.*)\.to_q\.weight$": (r"\1.to_qkv.weight", 0, 3),
-            r"(.*)\.to_k\.weight$": (r"\1.to_qkv.weight", 1, 3),
-            r"(.*)\.to_v\.weight$": (r"\1.to_qkv.weight", 2, 3),
             r"(.*)\.feed_forward\.w1\.weight$": (r"\1.feed_forward.w13.weight", 0, 2),
             r"(.*)\.feed_forward\.w3\.weight$": (r"\1.feed_forward.w13.weight", 1, 2),
         }
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux.py b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
index c0d1f991a2f5..dc88afa50d6c 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
@@ -52,13 +52,15 @@ def _get_qkv_projections(
     attn: "FluxAttention", hidden_states, encoder_hidden_states=None
 ):
-    qkv, _ = attn.to_qkv(hidden_states)
-    query, key, value = qkv.chunk(3, dim=-1)
+    query, _ = attn.to_q(hidden_states)
+    key, _ = attn.to_k(hidden_states)
+    value, _ = attn.to_v(hidden_states)
 
     encoder_query = encoder_key = encoder_value = None
     if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
-        added_qkv, _ = attn.to_added_qkv(encoder_hidden_states)
-        encoder_query, encoder_key, encoder_value = added_qkv.chunk(3, dim=-1)
+        encoder_query, _ = attn.add_q_proj(encoder_hidden_states)
+        encoder_key, _ = attn.add_k_proj(encoder_hidden_states)
+        encoder_value, _ = attn.add_v_proj(encoder_hidden_states)
 
     return query, key, value, encoder_query, encoder_key, encoder_value
@@ -96,8 +98,14 @@ def __init__(
         self.norm_q = RMSNorm(dim_head, eps=eps)
         self.norm_k = RMSNorm(dim_head, eps=eps)
 
-        self.to_qkv = ColumnParallelLinear(
-            query_dim, self.inner_dim * 3, bias=bias, gather_output=True
+        self.to_q = ColumnParallelLinear(
+            query_dim, self.inner_dim, bias=bias, gather_output=True
+        )
+        self.to_k = ColumnParallelLinear(
+            query_dim, self.inner_dim, bias=bias, gather_output=True
+        )
+        self.to_v = ColumnParallelLinear(
+            query_dim, self.inner_dim, bias=bias, gather_output=True
         )
 
         if not self.pre_only:
@@ -113,9 +121,21 @@ def __init__(
         if added_kv_proj_dim is not None:
             self.norm_added_q = RMSNorm(dim_head, eps=eps)
             self.norm_added_k = RMSNorm(dim_head, eps=eps)
-            self.to_added_qkv = ColumnParallelLinear(
+            self.add_q_proj = ColumnParallelLinear(
+                added_kv_proj_dim,
+                self.inner_dim,
+                bias=added_proj_bias,
+                gather_output=True,
+            )
+            self.add_k_proj = ColumnParallelLinear(
+                added_kv_proj_dim,
+                self.inner_dim,
+                bias=added_proj_bias,
+                gather_output=True,
+            )
+            self.add_v_proj = ColumnParallelLinear(
                 added_kv_proj_dim,
-                self.inner_dim * 3,
+                self.inner_dim,
                 bias=added_proj_bias,
                 gather_output=True,
             )
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
index 636ff6158e4a..b99f893b9189 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
@@ -38,13 +38,15 @@ def _get_qkv_projections(
     attn: "Flux2Attention", hidden_states, encoder_hidden_states=None
 ):
-    qkv, _ = attn.to_qkv(hidden_states)
-    query, key, value = qkv.chunk(3, dim=-1)
+    query, _ = attn.to_q(hidden_states)
+    key, _ = attn.to_k(hidden_states)
+    value, _ = attn.to_v(hidden_states)
 
     encoder_query = encoder_key = encoder_value = None
     if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
-        added_qkv, _ = attn.to_added_qkv(encoder_hidden_states)
-        encoder_query, encoder_key, encoder_value = added_qkv.chunk(3, dim=-1)
+        encoder_query, _ = attn.add_q_proj(encoder_hidden_states)
+        encoder_key, _ = attn.add_k_proj(encoder_hidden_states)
+        encoder_value, _ = attn.add_v_proj(encoder_hidden_states)
 
     return query, key, value, encoder_query, encoder_key, encoder_value
@@ -120,8 +122,9 @@ def __init__(
         self.added_kv_proj_dim = added_kv_proj_dim
         self.added_proj_bias = added_proj_bias
 
-        # Use ReplicatedLinear for fused QKV projections
-        self.to_qkv = ReplicatedLinear(query_dim, self.inner_dim * 3, bias=bias)
+        self.to_q = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)
 
         # QK Norm
         self.norm_q = RMSNorm(dim_head, eps=eps)
@@ -134,9 +137,14 @@ def __init__(
         if added_kv_proj_dim is not None:
             self.norm_added_q = RMSNorm(dim_head, eps=eps)
             self.norm_added_k = RMSNorm(dim_head, eps=eps)
-            # Use ReplicatedLinear for added (encoder) QKV projections
-            self.to_added_qkv = ReplicatedLinear(
-                added_kv_proj_dim, self.inner_dim * 3, bias=added_proj_bias
+            self.add_q_proj = ReplicatedLinear(
+                added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
+            )
+            self.add_k_proj = ReplicatedLinear(
+                added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
+            )
+            self.add_v_proj = ReplicatedLinear(
+                added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
             )
             self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/zimage.py b/python/sglang/multimodal_gen/runtime/models/dits/zimage.py
index 73f3061fccd3..9fb166b69cb9 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/zimage.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/zimage.py
@@ -11,7 +11,6 @@ from sglang.multimodal_gen.runtime.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
-    QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
@@ -110,16 +109,13 @@ def __init__(
         super().__init__()
         self.dim = dim
         self.head_dim = dim // num_heads
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
         self.qk_norm = qk_norm
 
-        # Use QKVParallelLinear for QKV projection (fused)
-        self.to_qkv = QKVParallelLinear(
-            hidden_size=dim,
-            head_size=self.head_dim,
-            total_num_heads=num_heads,
-            total_num_kv_heads=num_kv_heads,
-            bias=False,
-        )
+        self.to_q = ReplicatedLinear(dim, dim, bias=False)
+        self.to_k = ReplicatedLinear(dim, self.head_dim * num_kv_heads, bias=False)
+        self.to_v = ReplicatedLinear(dim, self.head_dim * num_kv_heads, bias=False)
 
         if self.qk_norm:
             self.norm_q = RMSNorm(self.head_dim, eps=eps)
@@ -146,14 +142,13 @@ def forward(
         hidden_states: torch.Tensor,
         freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ):
-        qkv, _ = self.to_qkv(hidden_states)
-        q_dim = self.to_qkv.num_heads * self.head_dim
-        kv_dim = self.to_qkv.num_kv_heads * self.head_dim
-        q, k, v = torch.split(qkv, [q_dim, kv_dim, kv_dim], dim=-1)
-
-        q = q.view(*q.shape[:-1], self.to_qkv.num_heads, self.head_dim)
-        k = k.view(*k.shape[:-1], self.to_qkv.num_kv_heads, self.head_dim)
-        v = v.view(*v.shape[:-1], self.to_qkv.num_kv_heads, self.head_dim)
+        q, _ = self.to_q(hidden_states)
+        k, _ = self.to_k(hidden_states)
+        v, _ = self.to_v(hidden_states)
+
+        q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
+        k = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
+        v = v.view(*v.shape[:-1], self.num_kv_heads, self.head_dim)
 
         if self.norm_q is not None:
             q = self.norm_q(q)
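For context, the sketch below illustrates the projection layout this diff moves to: separate to_q/to_k/to_v linear layers (optionally with fewer KV heads than query heads) instead of a single fused QKV matrix. It is a minimal stand-alone approximation, not the repo's implementation: plain torch.nn.Linear stands in for ColumnParallelLinear/ReplicatedLinear (which additionally return an (output, bias) tuple), and the UnfusedQKV module name is hypothetical.

# Minimal sketch (assumptions noted above): separate Q/K/V projections,
# mirroring the un-fused layout introduced by this diff.
import torch
import torch.nn as nn


class UnfusedQKV(nn.Module):
    def __init__(self, dim: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        self.head_dim = dim // num_heads
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        # One independent weight per projection instead of one fused QKV matrix.
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, self.head_dim * num_kv_heads, bias=False)
        self.to_v = nn.Linear(dim, self.head_dim * num_kv_heads, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        q = self.to_q(hidden_states)
        k = self.to_k(hidden_states)
        v = self.to_v(hidden_states)
        # Split the last dimension into (heads, head_dim), as in ZImageAttention.forward.
        q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
        k = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim)
        v = v.view(*v.shape[:-1], self.num_kv_heads, self.head_dim)
        return q, k, v


# Example: 16 query heads sharing 4 KV heads over a 1024-wide hidden state.
x = torch.randn(2, 128, 1024)
q, k, v = UnfusedQKV(1024, num_heads=16, num_kv_heads=4)(x)
print(q.shape, k.shape, v.shape)  # (2, 128, 16, 64) (2, 128, 4, 64) (2, 128, 4, 64)

Keeping the projections separate means checkpoint weights load one-to-one without the QKV fusion regexes and stacked_params_mapping entries that the config changes above remove.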