From d564a06cf3707ad1b071129e59e49d66eb21ff78 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sat, 13 Dec 2025 22:16:33 +0800
Subject: [PATCH 01/13] add multimodal gen profiling doc

---
 python/sglang/multimodal_gen/docs/profiling.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/python/sglang/multimodal_gen/docs/profiling.md b/python/sglang/multimodal_gen/docs/profiling.md
index 58a72f5fc3e0..b675816da8fd 100644
--- a/python/sglang/multimodal_gen/docs/profiling.md
+++ b/python/sglang/multimodal_gen/docs/profiling.md
@@ -47,6 +47,10 @@ By default, trace files are saved in the ./logs/ directory. The exact output fil
 ```bash
 [mm-dd hh:mm:ss] Saving profiler traces to: /sgl-workspace/sglang/logs/mocked_fake_id_for_offline_generate-5_steps-global-rank0.trace.json.gz
 ```
+{request_id}-{num_steps}_steps-global-rank{rank}.trace.json.gz
+```
+
+Example: `mocked_fake_id_for_offline_generate-5_steps-global-rank0.trace.json.gz`
 
 ### View Traces
 

From 430095afea5cb75a124af2a164e53383541aa474 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sat, 13 Dec 2025 22:47:16 +0800
Subject: [PATCH 02/13] refactor fuse qkv with QKVParallelLinear

---
 .../runtime/models/dits/flux.py       | 139 ++++++++---------
 .../runtime/models/dits/flux_2.py     | 133 ++++++++--------
 .../runtime/models/dits/qwen_image.py | 144 +++++++++---------
 .../runtime/models/dits/utils.py      |  43 ------
 4 files changed, 194 insertions(+), 265 deletions(-)
 delete mode 100644 python/sglang/multimodal_gen/runtime/models/dits/utils.py

diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux.py b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
index 40089bc10d28..334c4e277c6c 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Iterable
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -36,17 +37,17 @@
 # from sglang.multimodal_gen.runtime.layers.layernorm import LayerNorm as LayerNorm
 from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
-from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
+from sglang.multimodal_gen.runtime.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+)
 from sglang.multimodal_gen.runtime.layers.mlp import MLP
 from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
     NDRotaryEmbedding,
     _apply_rotary_emb,
 )
+from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader
 from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
-from sglang.multimodal_gen.runtime.models.dits.utils import (
-    delete_projection_layers,
-    fuse_linear_projections,
-)
 from sglang.multimodal_gen.runtime.platforms import (
     AttentionBackendEnum,
     current_platform,
@@ -56,45 +57,21 @@
 logger = init_logger(__name__)  # pylint: disable=invalid-name
 
 
-def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
-    query, _ = attn.to_q(hidden_states)
-    key, _ = attn.to_k(hidden_states)
-    value, _ = attn.to_v(hidden_states)
-
-    encoder_query = encoder_key = encoder_value = None
-    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
-        encoder_query, _ = attn.add_q_proj(encoder_hidden_states)
-        encoder_key, _ = attn.add_k_proj(encoder_hidden_states)
-        encoder_value, _ = attn.add_v_proj(encoder_hidden_states)
-
-    return query, key, value, encoder_query, encoder_key, encoder_value
-
-
-def _get_fused_projections(
+def _get_qkv_projections(
     attn: "FluxAttention", hidden_states, encoder_hidden_states=None
 ):
     qkv, _ = attn.to_qkv(hidden_states)
     query, key, value = qkv.chunk(3, dim=-1)
 
     encoder_query = encoder_key = encoder_value = None
-    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
         added_qkv, _ = attn.to_added_qkv(encoder_hidden_states)
         encoder_query, encoder_key, encoder_value = added_qkv.chunk(3, dim=-1)
 
     return query, key, value, encoder_query, encoder_key, encoder_value
 
 
-def _get_qkv_projections(
-    attn: "FluxAttention", hidden_states, encoder_hidden_states=None
-):
-    if attn.fused_projections:
-        return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
-    return _get_projections(attn, hidden_states, encoder_hidden_states)
-
-
 class FluxAttention(torch.nn.Module, AttentionModuleMixin):
-    _supports_qkv_fusion = True
-
     def __init__(
         self,
         query_dim: int,
@@ -125,11 +102,15 @@ def __init__(
         self.added_proj_bias = added_proj_bias
 
         self.norm_q = RMSNorm(dim_head, eps=eps)
-        self.norm_k = RMSNorm(dim_head, eps=eps)
-        self.to_q = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)
-        self.to_k = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)
-        self.to_v = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)
+
+        # Use QKVParallelLinear for fused QKV projections
+        self.to_qkv = QKVParallelLinear(
+            hidden_size=query_dim,
+            head_size=dim_head,
+            total_num_heads=num_heads,
+            bias=bias,
+        )
 
         if not self.pre_only:
             self.to_out = torch.nn.ModuleList([])
@@ -142,14 +123,12 @@ def __init__(
         if added_kv_proj_dim is not None:
             self.norm_added_q = RMSNorm(dim_head, eps=eps)
             self.norm_added_k = RMSNorm(dim_head, eps=eps)
-            self.add_q_proj = ReplicatedLinear(
-                added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
-            )
-            self.add_k_proj = ReplicatedLinear(
-                added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
-            )
-            self.add_v_proj = ReplicatedLinear(
-                added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
+            # Use QKVParallelLinear for added (encoder) QKV projections
+            self.to_added_qkv = QKVParallelLinear(
+                hidden_size=added_kv_proj_dim,
+                head_size=dim_head,
+                total_num_heads=num_heads,
+                bias=added_proj_bias,
             )
             self.to_add_out = ReplicatedLinear(self.inner_dim, query_dim, bias=out_bias)
@@ -166,30 +145,6 @@ def __init__(
             },
         )
 
-        self.fused_projections = False
-
-    @torch.no_grad()
-    def fuse_projections(self):
-        if self.fused_projections:
-            return
-
-        self.to_qkv = fuse_linear_projections(
-            self.to_q, self.to_k, self.to_v, self.use_bias, ReplicatedLinear
-        )
-        delete_projection_layers(self, ["to_q", "to_k", "to_v"])
-
-        if self.added_kv_proj_dim is not None:
-            self.to_added_qkv = fuse_linear_projections(
-                self.add_q_proj,
-                self.add_k_proj,
-                self.add_v_proj,
-                self.added_proj_bias,
-                ReplicatedLinear,
-            )
-            delete_projection_layers(self, ["add_q_proj", "add_k_proj", "add_v_proj"])
-
-        self.fused_projections = True
-
     def forward(
         self,
         x: torch.Tensor,
@@ -504,13 +459,49 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]) -> None:
         )
 
     def fuse_qkv_projections(self):
-        for block in list(self.transformer_blocks) + list(
-            self.single_transformer_blocks
-        ):
-            if hasattr(block.attn, "fuse_projections") and getattr(
-                block.attn, "_supports_qkv_fusion", True
-            ):
-                block.attn.fuse_projections()
+        # QKV projections are already fused in __init__, this method is kept for compatibility
+        pass
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """Load weights with mapping for q/k/v -> qkv fusion."""
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("to_qkv", "to_q", "q"),
+            ("to_qkv", "to_k", "k"),
+            ("to_qkv", "to_v", "v"),
+            ("to_added_qkv", "add_q_proj", "q"),
+            ("to_added_qkv", "add_k_proj", "k"),
+            ("to_added_qkv", "add_v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            # Handle q/k/v -> qkv mapping
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name:
+                    # Replace the weight name with the parameter name
+                    model_param_name = name.replace(weight_name, param_name)
+
+                    if model_param_name in params_dict:
+                        param = params_dict[model_param_name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight, shard_id)
+                        loaded_params.add(model_param_name)
+                    break
+            else:
+                # Use default weight loader for all other parameters
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(name)
+
+        return loaded_params
 
     def forward(
         self,
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
index 4e83e0cb0d79..bceefb2b051b 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Iterable
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -23,15 +24,13 @@
 from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig
 from sglang.multimodal_gen.runtime.layers.attention import USPAttention
 from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
+from sglang.multimodal_gen.runtime.layers.linear import QKVParallelLinear
 from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
     NDRotaryEmbedding,
     _apply_rotary_emb,
 )
+from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader
 from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
-from sglang.multimodal_gen.runtime.models.dits.utils import (
-    delete_projection_layers,
-    fuse_linear_projections,
-)
 from sglang.multimodal_gen.runtime.platforms import (
     AttentionBackendEnum,
     current_platform,
@@ -41,42 +40,20 @@
 logger = init_logger(__name__)  # pylint: disable=invalid-name
 
 
-def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
-    query = attn.to_q(hidden_states)
-    key = attn.to_k(hidden_states)
-    value = attn.to_v(hidden_states)
-
-    encoder_query = encoder_key = encoder_value = None
-    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
-        encoder_query = attn.add_q_proj(encoder_hidden_states)
-        encoder_key = attn.add_k_proj(encoder_hidden_states)
-        encoder_value = attn.add_v_proj(encoder_hidden_states)
-
-    return query, key, value, encoder_query, encoder_key, encoder_value
-
-
-def _get_fused_projections(
+def _get_qkv_projections(
     attn: "Flux2Attention", hidden_states, encoder_hidden_states=None
 ):
     qkv = attn.to_qkv(hidden_states)
     query, key, value = qkv.chunk(3, dim=-1)
 
     encoder_query = encoder_key = encoder_value = None
-    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
        added_qkv = attn.to_added_qkv(encoder_hidden_states)
        encoder_query, encoder_key, encoder_value = added_qkv.chunk(3, dim=-1)
 
     return query, key, value, encoder_query, encoder_key, encoder_value
 
 
-def _get_qkv_projections(
-    attn: "Flux2Attention", hidden_states, encoder_hidden_states=None
-):
-    if attn.fused_projections:
-        return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
-    return _get_projections(attn, hidden_states, encoder_hidden_states)
-
-
 class Flux2SwiGLU(nn.Module):
     """
     Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection
@@ -120,8 +97,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
-    _supports_qkv_fusion = True
-
     def __init__(
         self,
         query_dim: int,
@@ -150,9 +125,13 @@ def __init__(
         self.added_kv_proj_dim = added_kv_proj_dim
         self.added_proj_bias = added_proj_bias
 
-        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
-        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
-        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        # Use QKVParallelLinear for fused QKV projections
+        self.to_qkv = QKVParallelLinear(
+            hidden_size=query_dim,
+            head_size=dim_head,
+            total_num_heads=num_heads,
+            bias=bias,
+        )
 
         # QK Norm
         self.norm_q = RMSNorm(dim_head, eps=eps)
@@ -165,14 +144,12 @@ def __init__(
         if added_kv_proj_dim is not None:
             self.norm_added_q = RMSNorm(dim_head, eps=eps)
             self.norm_added_k = RMSNorm(dim_head, eps=eps)
-            self.add_q_proj = torch.nn.Linear(
-                added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
-            )
-            self.add_k_proj = torch.nn.Linear(
-                added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
-            )
-            self.add_v_proj = torch.nn.Linear(
-                added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
+            # Use QKVParallelLinear for added (encoder) QKV projections
+            self.to_added_qkv = QKVParallelLinear(
+                hidden_size=added_kv_proj_dim,
+                head_size=dim_head,
+                total_num_heads=num_heads,
+                bias=added_proj_bias,
             )
             self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
@@ -189,30 +166,6 @@ def __init__(
             },
         )
 
-        self.fused_projections = False
-
-    @torch.no_grad()
-    def fuse_projections(self):
-        if self.fused_projections:
-            return
-
-        self.to_qkv = fuse_linear_projections(
-            self.to_q, self.to_k, self.to_v, self.use_bias, torch.nn.Linear
-        )
-        delete_projection_layers(self, ["to_q", "to_k", "to_v"])
-
-        if self.added_kv_proj_dim is not None:
-            self.to_added_qkv = fuse_linear_projections(
-                self.add_q_proj,
-                self.add_k_proj,
-                self.add_v_proj,
-                self.added_proj_bias,
-                torch.nn.Linear,
-            )
-            delete_projection_layers(self, ["add_q_proj", "add_k_proj", "add_v_proj"])
-
-        self.fused_projections = True
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -747,13 +700,49 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]):
         self.gradient_checkpointing = False
 
     def fuse_qkv_projections(self):
-        for block in list(self.transformer_blocks) + list(
-            self.single_transformer_blocks
-        ):
-            if hasattr(block.attn, "fuse_projections") and getattr(
-                block.attn, "_supports_qkv_fusion", True
-            ):
-                block.attn.fuse_projections()
+        # QKV projections are already fused in __init__, this method is kept for compatibility
+        pass
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """Load weights with mapping for q/k/v -> qkv fusion."""
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("to_qkv", "to_q", "q"),
+            ("to_qkv", "to_k", "k"),
+            ("to_qkv", "to_v", "v"),
+            ("to_added_qkv", "add_q_proj", "q"),
+            ("to_added_qkv", "add_k_proj", "k"),
+            ("to_added_qkv", "add_v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            # Handle q/k/v -> qkv mapping
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name:
+                    # Replace the weight name with the parameter name
+                    model_param_name = name.replace(weight_name, param_name)
+
+                    if model_param_name in params_dict:
+                        param = params_dict[model_param_name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight, shard_id)
+                        loaded_params.add(model_param_name)
+                    break
+            else:
+                # Use default weight loader for all other parameters
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(name)
+
+        return loaded_params
 
     def forward(
         self,
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
index 6808460bca66..9502f458ce06 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import functools
+from collections.abc import Iterable
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -16,60 +17,36 @@
 from sglang.multimodal_gen.configs.models.dits.qwenimage import QwenImageDitConfig
 from sglang.multimodal_gen.runtime.layers.attention import USPAttention
 from sglang.multimodal_gen.runtime.layers.layernorm import LayerNorm, RMSNorm
-from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
+from sglang.multimodal_gen.runtime.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+)
 from sglang.multimodal_gen.runtime.layers.triton_ops import (
     apply_rotary_embedding,
     fuse_scale_shift_kernel,
 )
+from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader
 from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
-from sglang.multimodal_gen.runtime.models.dits.utils import (
-    delete_projection_layers,
-    fuse_linear_projections,
-)
 from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
 
 logger = init_logger(__name__)  # pylint: disable=invalid-name
 
 
-def _get_projections(
-    attn: "QwenImageCrossAttention", hidden_states, encoder_hidden_states=None
-):
-    img_query, _ = attn.to_q(hidden_states)
-    img_key, _ = attn.to_k(hidden_states)
-    img_value, _ = attn.to_v(hidden_states)
-
-    txt_query = txt_key = txt_value = None
-    if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
-        txt_query, _ = attn.add_q_proj(encoder_hidden_states)
-        txt_key, _ = attn.add_k_proj(encoder_hidden_states)
-        txt_value, _ = attn.add_v_proj(encoder_hidden_states)
-
-    return img_query, img_key, img_value, txt_query, txt_key, txt_value
-
-
-def _get_fused_projections(
+def _get_qkv_projections(
     attn: "QwenImageCrossAttention", hidden_states, encoder_hidden_states=None
 ):
     img_qkv, _ = attn.to_qkv(hidden_states)
     img_query, img_key, img_value = img_qkv.chunk(3, dim=-1)
 
     txt_query = txt_key = txt_value = None
-    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
        txt_qkv, _ = attn.to_added_qkv(encoder_hidden_states)
        txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1)
 
     return img_query, img_key, img_value, txt_query, txt_key, txt_value
 
 
-def _get_qkv_projections(
-    attn: "QwenImageCrossAttention", hidden_states, encoder_hidden_states=None
-):
-    if attn.fused_projections:
-        return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
-    return _get_projections(attn, hidden_states, encoder_hidden_states)
-
-
 class QwenTimestepProjEmbeddings(nn.Module):
     def __init__(self, embedding_dim):
         super().__init__()
@@ -260,8 +237,6 @@ def _compute_video_freqs(
 
 
 class QwenImageCrossAttention(nn.Module):
-    _supports_qkv_fusion = True
-
     def __init__(
         self,
         dim: int,  # query_dim
@@ -286,27 +261,31 @@ def __init__(
         self.qk_norm = qk_norm
         self.eps = eps
         self.parallel_attention = parallel_attention
+        self.added_kv_proj_dim = added_kv_proj_dim
+
+        # Use QKVParallelLinear for fused QKV projections
+        self.to_qkv = QKVParallelLinear(
+            hidden_size=dim,
+            head_size=head_dim,
+            total_num_heads=num_heads,
+            bias=False,
+        )
 
-        # layers
-        self.to_q = ReplicatedLinear(dim, dim)
-        self.to_k = ReplicatedLinear(dim, dim)
-        self.to_v = ReplicatedLinear(dim, dim)
         if self.qk_norm:
             self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
             self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
+
         self.inner_dim = out_dim if out_dim is not None else head_dim * num_heads
         self.inner_kv_dim = self.inner_dim
+
         if added_kv_proj_dim is not None:
-            self.add_k_proj = ReplicatedLinear(
-                added_kv_proj_dim, self.inner_kv_dim, bias=True
-            )
-            self.add_v_proj = ReplicatedLinear(
-                added_kv_proj_dim, self.inner_kv_dim, bias=True
+            # Use QKVParallelLinear for added (encoder) QKV projections
+            self.to_added_qkv = QKVParallelLinear(
+                hidden_size=added_kv_proj_dim,
+                head_size=head_dim,
+                total_num_heads=num_heads,
+                bias=True,
             )
-            if context_pre_only is not None:
-                self.add_q_proj = ReplicatedLinear(
-                    added_kv_proj_dim, self.inner_dim, bias=True
-                )
 
         if context_pre_only is not None and not context_pre_only:
             self.to_add_out = ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias)
@@ -338,31 +317,6 @@ def __init__(
             },
         )
 
-        self.fused_projections = False
-        self.added_kv_proj_dim_val = added_kv_proj_dim
-
-    @torch.no_grad()
-    def fuse_projections(self):
-        if self.fused_projections:
-            return
-
-        self.to_qkv = fuse_linear_projections(
-            self.to_q, self.to_k, self.to_v, use_bias=False, linear_cls=ReplicatedLinear
-        )
-        delete_projection_layers(self, ["to_q", "to_k", "to_v"])
-
-        if self.added_kv_proj_dim_val is not None and hasattr(self, "add_q_proj"):
-            self.to_added_qkv = fuse_linear_projections(
-                self.add_q_proj,
-                self.add_k_proj,
-                self.add_v_proj,
-                use_bias=True,
-                linear_cls=ReplicatedLinear,
-            )
-            delete_projection_layers(self, ["add_q_proj", "add_k_proj", "add_v_proj"])
-
-        self.fused_projections = True
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -626,11 +580,49 @@ def __init__(
         )
 
     def fuse_qkv_projections(self):
-        for block in self.transformer_blocks:
-            if hasattr(block.attn, "fuse_projections") and getattr(
-                block.attn, "_supports_qkv_fusion", True
-            ):
-                block.attn.fuse_projections()
+        # QKV projections are already fused in __init__, this method is kept for compatibility
+        pass
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """Load weights with mapping for q/k/v -> qkv fusion."""
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("to_qkv", "to_q", "q"),
+            ("to_qkv", "to_k", "k"),
+            ("to_qkv", "to_v", "v"),
+            ("to_added_qkv", "add_q_proj", "q"),
+            ("to_added_qkv", "add_k_proj", "k"),
+            ("to_added_qkv", "add_v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            # Handle q/k/v -> qkv mapping
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name:
+                    # Replace the weight name with the parameter name
+                    model_param_name = name.replace(weight_name, param_name)
+
+                    if model_param_name in params_dict:
+                        param = params_dict[model_param_name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight, shard_id)
+                        loaded_params.add(model_param_name)
+                    break
+            else:
+                # Use default weight loader for all other parameters
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                    loaded_params.add(name)
+
+        return loaded_params
 
     def forward(
         self,
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/utils.py b/python/sglang/multimodal_gen/runtime/models/dits/utils.py
deleted file mode 100644
index 74b2d4d67295..000000000000
--- a/python/sglang/multimodal_gen/runtime/models/dits/utils.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from typing import Union
-
-import torch
-import torch.nn as nn
-
-from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
-
-
-def fuse_linear_projections(
-    q_proj: Union[nn.Linear, ReplicatedLinear],
-    k_proj: Union[nn.Linear, ReplicatedLinear],
-    v_proj: Union[nn.Linear, ReplicatedLinear],
-    use_bias: bool,
-    linear_cls: type = None,
-) -> Union[nn.Linear, ReplicatedLinear]:
-    device = q_proj.weight.data.device
-    dtype = q_proj.weight.data.dtype
-
-    concatenated_weights = torch.cat(
-        [q_proj.weight.data, k_proj.weight.data, v_proj.weight.data]
-    )
-    in_features = concatenated_weights.shape[1]
-    out_features = concatenated_weights.shape[0]
-
-    if linear_cls is None:
-        linear_cls = type(q_proj)
-
-    fused_layer = linear_cls(in_features, out_features, bias=use_bias)
-    fused_layer.weight.data = concatenated_weights.to(device=device, dtype=dtype)
-
-    if use_bias:
-        concatenated_bias = torch.cat(
-            [q_proj.bias.data, k_proj.bias.data, v_proj.bias.data]
-        )
-        fused_layer.bias.data = concatenated_bias.to(device=device, dtype=dtype)
-
-    return fused_layer
-
-
-def delete_projection_layers(module: nn.Module, layer_names: list[str]) -> None:
-    for name in layer_names:
-        if hasattr(module, name):
-            delattr(module, name)

From c243607009c1f0b7b879b98c92a8d49359bcc96b Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sat, 13 Dec 2025 23:11:35 +0800
Subject: [PATCH 03/13] ud

---
 .../runtime/models/dits/flux.py       | 37 ++++++++++---------
 .../runtime/models/dits/flux_2.py     | 37 ++++++++++---------
 .../runtime/models/dits/qwen_image.py | 37 ++++++++++---------
 3 files changed, 57 insertions(+), 54 deletions(-)

diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux.py b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
index 334c4e277c6c..2228bfeadc0a 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
@@ -466,12 +466,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         """Load weights with mapping for q/k/v -> qkv fusion."""
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("to_qkv", "to_q", "q"),
-            ("to_qkv", "to_k", "k"),
-            ("to_qkv", "to_v", "v"),
-            ("to_added_qkv", "add_q_proj", "q"),
-            ("to_added_qkv", "add_k_proj", "k"),
-            ("to_added_qkv", "add_v_proj", "v"),
+            (".to_qkv", ".to_q", "q"),
+            (".to_qkv", ".to_k", "k"),
+            (".to_qkv", ".to_v", "v"),
+            (".to_added_qkv", ".add_q_proj", "q"),
+            (".to_added_qkv", ".add_k_proj", "k"),
+            (".to_added_qkv", ".add_v_proj", "v"),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
@@ -479,18 +479,19 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         for name, loaded_weight in weights:
             # Handle q/k/v -> qkv mapping
             for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name in name:
-                    # Replace the weight name with the parameter name
-                    model_param_name = name.replace(weight_name, param_name)
-
-                    if model_param_name in params_dict:
-                        param = params_dict[model_param_name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight, shard_id)
-                        loaded_params.add(model_param_name)
-                    break
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight, shard_id)
+                loaded_params.add(name)
+                break
             else:
                 # Use default weight loader for all other parameters
                 if name in params_dict:
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
index bceefb2b051b..507a2bb297b2 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
@@ -707,12 +707,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         """Load weights with mapping for q/k/v -> qkv fusion."""
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("to_qkv", "to_q", "q"),
-            ("to_qkv", "to_k", "k"),
-            ("to_qkv", "to_v", "v"),
-            ("to_added_qkv", "add_q_proj", "q"),
-            ("to_added_qkv", "add_k_proj", "k"),
-            ("to_added_qkv", "add_v_proj", "v"),
+            (".to_qkv", ".to_q", "q"),
+            (".to_qkv", ".to_k", "k"),
+            (".to_qkv", ".to_v", "v"),
+            (".to_added_qkv", ".add_q_proj", "q"),
+            (".to_added_qkv", ".add_k_proj", "k"),
+            (".to_added_qkv", ".add_v_proj", "v"),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
@@ -720,18 +720,19 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         for name, loaded_weight in weights:
             # Handle q/k/v -> qkv mapping
             for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name in name:
-                    # Replace the weight name with the parameter name
-                    model_param_name = name.replace(weight_name, param_name)
-
-                    if model_param_name in params_dict:
-                        param = params_dict[model_param_name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight, shard_id)
-                        loaded_params.add(model_param_name)
-                    break
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight, shard_id)
+                loaded_params.add(name)
+                break
             else:
                 # Use default weight loader for all other parameters
                 if name in params_dict:
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
index 9502f458ce06..ff69b0c29643 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
@@ -587,12 +587,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         """Load weights with mapping for q/k/v -> qkv fusion."""
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("to_qkv", "to_q", "q"),
-            ("to_qkv", "to_k", "k"),
-            ("to_qkv", "to_v", "v"),
-            ("to_added_qkv", "add_q_proj", "q"),
-            ("to_added_qkv", "add_k_proj", "k"),
-            ("to_added_qkv", "add_v_proj", "v"),
+            (".to_qkv", ".to_q", "q"),
+            (".to_qkv", ".to_k", "k"),
+            (".to_qkv", ".to_v", "v"),
+            (".to_added_qkv", ".add_q_proj", "q"),
+            (".to_added_qkv", ".add_k_proj", "k"),
+            (".to_added_qkv", ".add_v_proj", "v"),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
@@ -600,18 +600,19 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         for name, loaded_weight in weights:
             # Handle q/k/v -> qkv mapping
             for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name in name:
-                    # Replace the weight name with the parameter name
-                    model_param_name = name.replace(weight_name, param_name)
-
-                    if model_param_name in params_dict:
-                        param = params_dict[model_param_name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight, shard_id)
-                        loaded_params.add(model_param_name)
-                    break
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight, shard_id)
+                loaded_params.add(name)
+                break
             else:
                 # Use default weight loader for all other parameters
                 if name in params_dict:

From d44113430aea9a558e19a3ca755d1937952f76b3 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sat, 13 Dec 2025 23:18:22 +0800
Subject: [PATCH 04/13] ud

---
 .../configs/models/dits/flux.py      | 20 +++++++++++++++++++
 .../configs/models/dits/qwenimage.py | 20 +++++++++++++++++++
 2 files changed, 40 insertions(+)

diff --git a/python/sglang/multimodal_gen/configs/models/dits/flux.py b/python/sglang/multimodal_gen/configs/models/dits/flux.py
index f50c9b1e671f..f4c62d21341c 100644
--- a/python/sglang/multimodal_gen/configs/models/dits/flux.py
+++ b/python/sglang/multimodal_gen/configs/models/dits/flux.py
@@ -21,8 +21,28 @@ class FluxArchConfig(DiTArchConfig):
     guidance_embeds: bool = False
     axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)
 
+    stacked_params_mapping: list[tuple[str, str, str]] = field(
+        default_factory=lambda: [
+            # (param_name, shard_name, shard_id)
+            (".to_qkv", ".to_q", "q"),
+            (".to_qkv", ".to_k", "k"),
+            (".to_qkv", ".to_v", "v"),
+            (".to_added_qkv", ".add_q_proj", "q"),
+            (".to_added_qkv", ".add_k_proj", "k"),
+            (".to_added_qkv", ".add_v_proj", "v"),
+        ]
+    )
+
     param_names_mapping: dict = field(
         default_factory=lambda: {
+            # QKV fusion mappings
+            r"(.*)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
+            r"(.*)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
+            r"(.*)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
+            r"(.*)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
+            r"(.*)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
+            r"(.*)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
+            # General transformer prefix removal
             r"transformer\.(\w*)\.(.*)$": r"\1.\2",
         }
     )
diff --git a/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py b/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
index f9003264b721..782c399734c4 100644
--- a/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
+++ b/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
@@ -21,8 +21,28 @@ class QwenImageArchConfig(DiTArchConfig):
     guidance_embeds: bool = False
     axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)
 
+    stacked_params_mapping: list[tuple[str, str, str]] = field(
+        default_factory=lambda: [
+            # (param_name, shard_name, shard_id)
+            (".to_qkv", ".to_q", "q"),
+            (".to_qkv", ".to_k", "k"),
+            (".to_qkv", ".to_v", "v"),
+            (".to_added_qkv", ".add_q_proj", "q"),
+            (".to_added_qkv", ".add_k_proj", "k"),
+            (".to_added_qkv", ".add_v_proj", "v"),
+        ]
+    )
+
     param_names_mapping: dict = field(
         default_factory=lambda: {
+            # QKV fusion mappings
+            r"(.*)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
+            r"(.*)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
+            r"(.*)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
+            r"(.*)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
+            r"(.*)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
+            r"(.*)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
+            # LoRA mappings
             r"^(transformer_blocks\.\d+\.attn\..*\.lora_[AB])\.default$": r"\1",
         }
     )

From fe3a19a91e57734fa19ebd80c0c0977215383dc2 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sun, 14 Dec 2025 10:37:08 +0800
Subject: [PATCH 05/13] ud

---
 .../configs/models/dits/flux.py      | 16 ++++++++--------
 .../configs/models/dits/qwenimage.py | 14 +++++++-------
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/python/sglang/multimodal_gen/configs/models/dits/flux.py b/python/sglang/multimodal_gen/configs/models/dits/flux.py
index f4c62d21341c..ef12b6ab1cb0 100644
--- a/python/sglang/multimodal_gen/configs/models/dits/flux.py
+++ b/python/sglang/multimodal_gen/configs/models/dits/flux.py
@@ -35,15 +35,15 @@ class FluxArchConfig(DiTArchConfig):
 
     param_names_mapping: dict = field(
         default_factory=lambda: {
-            # QKV fusion mappings
-            r"(.*)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
-            r"(.*)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
-            r"(.*)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
-            r"(.*)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
-            r"(.*)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
-            r"(.*)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
+            # QKV fusion mappings - must come before general transformer mapping
+            r"^(.+)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
+            r"^(.+)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
+            r"^(.+)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
+            r"^(.+)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
+            r"^(.+)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
+            r"^(.+)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
             # General transformer prefix removal
-            r"transformer\.(\w*)\.(.*)$": r"\1.\2",
+            r"^transformer\.(\w+)\.(.+)$": r"\1.\2",
         }
     )
diff --git a/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py b/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
index 782c399734c4..7d20a3ef697c 100644
--- a/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
+++ b/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
@@ -35,13 +35,13 @@ class QwenImageArchConfig(DiTArchConfig):
 
     param_names_mapping: dict = field(
         default_factory=lambda: {
-            # QKV fusion mappings
-            r"(.*)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
-            r"(.*)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
-            r"(.*)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
-            r"(.*)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
-            r"(.*)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
-            r"(.*)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
+            # QKV fusion mappings - must come before LoRA mapping
+            r"^(.+)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
+            r"^(.+)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
+            r"^(.+)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
+            r"^(.+)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
+            r"^(.+)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
+            r"^(.+)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
             # LoRA mappings
             r"^(transformer_blocks\.\d+\.attn\..*\.lora_[AB])\.default$": r"\1",
         }
     )

From 286968438e20e5c2858ff2e7ff386ec69a858647 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sun, 14 Dec 2025 10:43:02 +0800
Subject: [PATCH 06/13] ud

---
 .../configs/models/dits/flux.py       | 17 ++++++++---------
 .../configs/models/dits/qwenimage.py  | 14 +++++++-------
 .../runtime/models/dits/flux.py       |  2 ++
 .../runtime/models/dits/flux_2.py     |  2 ++
 .../runtime/models/dits/qwen_image.py |  2 ++
 5 files changed, 21 insertions(+), 16 deletions(-)

diff --git a/python/sglang/multimodal_gen/configs/models/dits/flux.py b/python/sglang/multimodal_gen/configs/models/dits/flux.py
index ef12b6ab1cb0..ee5a9867a7dd 100644
--- a/python/sglang/multimodal_gen/configs/models/dits/flux.py
+++ b/python/sglang/multimodal_gen/configs/models/dits/flux.py
@@ -35,15 +35,14 @@ class FluxArchConfig(DiTArchConfig):
 
     param_names_mapping: dict = field(
         default_factory=lambda: {
-            # QKV fusion mappings - must come before general transformer mapping
-            r"^(.+)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
-            r"^(.+)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
-            r"^(.+)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
-            r"^(.+)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
-            r"^(.+)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
-            r"^(.+)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
-            # General transformer prefix removal
-            r"^transformer\.(\w+)\.(.+)$": r"\1.\2",
+            # QKV fusion mappings
+            r"(.*)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
+            r"(.*)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
+            r"(.*)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
+            r"(.*)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
+            r"(.*)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
+            r"(.*)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
+            r"transformer\.(\w*)\.(.*)$": r"\1.\2",
         }
     )
diff --git a/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py b/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
index 7d20a3ef697c..782c399734c4 100644
--- a/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
+++ b/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
@@ -35,13 +35,13 @@ class QwenImageArchConfig(DiTArchConfig):
 
     param_names_mapping: dict = field(
         default_factory=lambda: {
-            # QKV fusion mappings - must come before LoRA mapping
-            r"^(.+)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
-            r"^(.+)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
-            r"^(.+)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
-            r"^(.+)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
-            r"^(.+)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
-            r"^(.+)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
+            # QKV fusion mappings
+            r"(.*)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
+            r"(.*)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
+            r"(.*)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
+            r"(.*)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
+            r"(.*)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
+            r"(.*)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
             # LoRA mappings
             r"^(transformer_blocks\.\d+\.attn\..*\.lora_[AB])\.default$": r"\1",
         }
     )
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux.py b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
index 2228bfeadc0a..ee6e9e7d2158 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
@@ -400,6 +400,8 @@ class FluxTransformer2DModel(CachableDiT):
     Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
     """
 
+    param_names_mapping = FluxConfig().arch_config.param_names_mapping
+
     def __init__(self, config: FluxConfig, hf_config: dict[str, Any]) -> None:
         super().__init__(config=config, hf_config=hf_config)
         self.config = config.arch_config
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
index 507a2bb297b2..cb13b130fda9 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
@@ -608,6 +608,8 @@ class Flux2Transformer2DModel(CachableDiT):
     """
 
+    param_names_mapping = FluxConfig().arch_config.param_names_mapping
+
     def __init__(self, config: FluxConfig, hf_config: dict[str, Any]):
         super().__init__(config=config, hf_config=hf_config)
         patch_size: int = config.patch_size
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
index ff69b0c29643..837f6b14c45f 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
@@ -533,6 +533,8 @@ class QwenImageTransformer2DModel(CachableDiT):
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
     _repeated_blocks = ["QwenImageTransformerBlock"]
 
+    param_names_mapping = QwenImageDitConfig().arch_config.param_names_mapping
+
     def __init__(
         self,
         config: QwenImageDitConfig,

From b5399267108d8235bfa888e2bcbba487a4c9bef6 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sun, 14 Dec 2025 10:56:26 +0800
Subject: [PATCH 07/13] ud

---
 python/sglang/multimodal_gen/runtime/models/dits/flux_2.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
index cb13b130fda9..64025ce2ba73 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
@@ -43,12 +43,12 @@
 def _get_qkv_projections(
     attn: "Flux2Attention", hidden_states, encoder_hidden_states=None
 ):
-    qkv = attn.to_qkv(hidden_states)
+    qkv, _ = attn.to_qkv(hidden_states)
     query, key, value = qkv.chunk(3, dim=-1)
 
     encoder_query = encoder_key = encoder_value = None
     if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
-        added_qkv = attn.to_added_qkv(encoder_hidden_states)
+        added_qkv, _ = attn.to_added_qkv(encoder_hidden_states)
         encoder_query, encoder_key, encoder_value = added_qkv.chunk(3, dim=-1)
 
     return query, key, value, encoder_query, encoder_key, encoder_value

From f3590d13f8ecdde08ecc2f336d096c8d7ceef4ba Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sun, 14 Dec 2025 11:13:38 +0800
Subject: [PATCH 08/13] ud

---
 python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
index 837f6b14c45f..8715af40ab92 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
@@ -268,7 +268,7 @@ def __init__(
             hidden_size=dim,
             head_size=head_dim,
             total_num_heads=num_heads,
-            bias=False,
+            bias=True,
         )
 
         if self.qk_norm:

From 84c0007fbcaf3a98ca6c0ff2b6854465b828eb62 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sun, 14 Dec 2025 11:19:02 +0800
Subject: [PATCH 09/13] ud

---
 .../configs/models/dits/qwenimage.py  | 9 +++++----
 .../runtime/models/dits/qwen_image.py | 2 +-
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py b/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
index 782c399734c4..b43810d24d6d 100644
--- a/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
+++ b/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
@@ -35,10 +35,11 @@ class QwenImageArchConfig(DiTArchConfig):
 
     param_names_mapping: dict = field(
         default_factory=lambda: {
-            # QKV fusion mappings
-            r"(.*)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
-            r"(.*)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
-            r"(.*)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
+            # QKV fusion mappings - only weight, no bias for to_qkv
+            r"(.*)\.to_q\.weight$": (r"\1.to_qkv.weight", 0, 3),
+            r"(.*)\.to_k\.weight$": (r"\1.to_qkv.weight", 1, 3),
+            r"(.*)\.to_v\.weight$": (r"\1.to_qkv.weight", 2, 3),
+            # to_added_qkv has both weight and bias
             r"(.*)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
             r"(.*)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
             r"(.*)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
index 8715af40ab92..837f6b14c45f 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
@@ -268,7 +268,7 @@ def __init__(
             hidden_size=dim,
             head_size=head_dim,
             total_num_heads=num_heads,
-            bias=True,
+            bias=False,
         )
 
         if self.qk_norm:

From ab9577c4c0947a47ec56db7fbdc87f25efb01ac3 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sun, 14 Dec 2025 11:23:26 +0800
Subject: [PATCH 10/13] ud

---
 .../configs/models/dits/qwenimage.py  | 9 ++++-----
 .../runtime/models/dits/qwen_image.py | 2 +-
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py b/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
index b43810d24d6d..782c399734c4 100644
--- a/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
+++ b/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
@@ -35,11 +35,10 @@ class QwenImageArchConfig(DiTArchConfig):
 
     param_names_mapping: dict = field(
         default_factory=lambda: {
-            # QKV fusion mappings - only weight, no bias for to_qkv
-            r"(.*)\.to_q\.weight$": (r"\1.to_qkv.weight", 0, 3),
-            r"(.*)\.to_k\.weight$": (r"\1.to_qkv.weight", 1, 3),
-            r"(.*)\.to_v\.weight$": (r"\1.to_qkv.weight", 2, 3),
-            # to_added_qkv has both weight and bias
+            # QKV fusion mappings
+            r"(.*)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
+            r"(.*)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
+            r"(.*)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
             r"(.*)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
             r"(.*)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
             r"(.*)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
index 837f6b14c45f..8715af40ab92 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
@@ -268,7 +268,7 @@ def __init__(
             hidden_size=dim,
             head_size=head_dim,
             total_num_heads=num_heads,
-            bias=False,
+            bias=True,
         )
 
         if self.qk_norm:

From 4cd6f703396011b8db14b5ae05ad1e8fbd9277bc Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sun, 14 Dec 2025 11:25:58 +0800
Subject: [PATCH 11/13] clean

---
 .../runtime/loader/component_loader.py | 4 ----
 .../runtime/models/dits/flux.py        | 4 ----
 .../runtime/models/dits/flux_2.py      | 4 ----
 .../runtime/models/dits/qwen_image.py  | 4 ----
 4 files changed, 16 deletions(-)

diff --git a/python/sglang/multimodal_gen/runtime/loader/component_loader.py b/python/sglang/multimodal_gen/runtime/loader/component_loader.py
index 1b7888208a45..db71023a619d 100644
--- a/python/sglang/multimodal_gen/runtime/loader/component_loader.py
+++ b/python/sglang/multimodal_gen/runtime/loader/component_loader.py
@@ -692,10 +692,6 @@ def load_customized(
 
         model = model.eval()
 
-        if hasattr(model, "fuse_qkv_projections"):
-            logger.info("Fusing QKV projections for better performance")
-            model.fuse_qkv_projections()
-
         return model
 
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux.py b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
index ee6e9e7d2158..1979ba7fe6dd 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
@@ -460,10 +460,6 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]) -> None:
             bias=True,
         )
 
-    def fuse_qkv_projections(self):
-        # QKV projections are already fused in __init__, this method is kept for compatibility
-        pass
-
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         """Load weights with mapping for q/k/v -> qkv fusion."""
         stacked_params_mapping = [
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
index 64025ce2ba73..eb72b2073f45 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
@@ -701,10 +701,6 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]):
 
         self.gradient_checkpointing = False
 
-    def fuse_qkv_projections(self):
-        # QKV projections are already fused in __init__, this method is kept for compatibility
-        pass
-
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         """Load weights with mapping for q/k/v -> qkv fusion."""
         stacked_params_mapping = [
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
index 8715af40ab92..ab3d5e9ead3c 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
@@ -581,10 +581,6 @@ def __init__(
             self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
         )
 
-    def fuse_qkv_projections(self):
-        # QKV projections are already fused in __init__, this method is kept for compatibility
-        pass
-
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         """Load weights with mapping for q/k/v -> qkv fusion."""
        stacked_params_mapping = [

From 9dfa9570a61fb80a3a5cc8da203f674e95642d14 Mon Sep 17 00:00:00 2001
From: BBuf <1182563586@qq.com>
Date: Sun, 14 Dec 2025 22:46:35 +0800
Subject: [PATCH 12/13] dead code

---
 .../runtime/models/dits/flux.py       | 44 -------------------
 .../runtime/models/dits/flux_2.py     | 44 -------------------
 .../runtime/models/dits/qwen_image.py | 44 -------------------
 3 files changed, 132 deletions(-)

diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux.py b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
index 1979ba7fe6dd..436e60667ecf 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections.abc import Iterable
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -46,7 +45,6 @@
     NDRotaryEmbedding,
     _apply_rotary_emb,
 )
-from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader
 from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
 from sglang.multimodal_gen.runtime.platforms import (
     AttentionBackendEnum,
@@ -460,48 +458,6 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]) -> None:
             bias=True,
         )
 
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        """Load weights with mapping for q/k/v -> qkv fusion."""
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".to_qkv", ".to_q", "q"),
-            (".to_qkv", ".to_k", "k"),
-            (".to_qkv", ".to_v", "v"),
-            (".to_added_qkv", ".add_q_proj", "q"),
-            (".to_added_qkv", ".add_k_proj", "k"),
-            (".to_added_qkv", ".add_v_proj", "v"),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
-
-        for name, loaded_weight in weights:
-            # Handle q/k/v -> qkv mapping
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-
-                name = name.replace(weight_name, param_name)
-
-                if name not in params_dict:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight, shard_id)
-                loaded_params.add(name)
-                break
-            else:
-                # Use default weight loader for all other parameters
-                if name in params_dict:
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
-                    loaded_params.add(name)
-
-        return loaded_params
-
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
index eb72b2073f45..c64e47fddf26 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections.abc import Iterable
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -29,7 +28,6 @@
     NDRotaryEmbedding,
     _apply_rotary_emb,
 )
-from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader
 from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
 from sglang.multimodal_gen.runtime.platforms import (
     AttentionBackendEnum,
@@ -701,48 +699,6 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]):
 
         self.gradient_checkpointing = False
 
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        """Load weights with mapping for q/k/v -> qkv fusion."""
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".to_qkv", ".to_q", "q"),
-            (".to_qkv", ".to_k", "k"),
-            (".to_qkv", ".to_v", "v"),
-            (".to_added_qkv", ".add_q_proj", "q"),
-            (".to_added_qkv", ".add_k_proj", "k"),
-            (".to_added_qkv", ".add_v_proj", "v"),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
-
-        for name, loaded_weight in weights:
-            # Handle q/k/v -> qkv mapping
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-
-                name = name.replace(weight_name, param_name)
-
-                if name not in params_dict:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight, shard_id)
-                loaded_params.add(name)
-                break
-            else:
-                # Use default weight loader for all other parameters
-                if name in params_dict:
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
-                    loaded_params.add(name)
-
-        return loaded_params
-
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
index ab3d5e9ead3c..f8acd57bd56e 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py
@@ -3,7 +3,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import functools
-from collections.abc import Iterable
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -25,7 +24,6 @@
     apply_rotary_embedding,
     fuse_scale_shift_kernel,
 )
-from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader
 from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
 from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
@@ -581,48 +579,6 @@ def __init__(
             self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
         )
 
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        """Load weights with mapping for q/k/v -> qkv fusion."""
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            (".to_qkv", ".to_q", "q"),
-            (".to_qkv", ".to_k", "k"),
-            (".to_qkv", ".to_v", "v"),
-            (".to_added_qkv", ".add_q_proj", "q"),
-            (".to_added_qkv", ".add_k_proj", "k"),
-            (".to_added_qkv", ".add_v_proj", "v"),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
-
-        for name, loaded_weight in weights:
-            # Handle q/k/v -> qkv mapping
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-
-                name = name.replace(weight_name, param_name)
-
-                if name not in params_dict:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight, shard_id)
-                loaded_params.add(name)
-                break
-            else:
-                # Use default weight loader for all other parameters
-                if name in params_dict:
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
-                    loaded_params.add(name)
-
-        return loaded_params
-
     def forward(
         self,
         hidden_states: torch.Tensor,

From bf3499695d6d604572311fa23ee9f02db9f0cc7b Mon Sep 17 00:00:00 2001
From: Mick
Date: Sun, 14 Dec 2025 23:35:29 +0800
Subject: [PATCH 13/13] upd

---
 .../test/server/perf_baselines.json | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/sglang/multimodal_gen/test/server/perf_baselines.json b/python/sglang/multimodal_gen/test/server/perf_baselines.json
index f2890284c5c1..e93a9d1e4926 100644
--- a/python/sglang/multimodal_gen/test/server/perf_baselines.json
+++ b/python/sglang/multimodal_gen/test/server/perf_baselines.json
@@ -584,7 +584,7 @@
       "48": 247.48,
       "49": 247.54
     },
-    "expected_e2e_ms": 16563.83,
+    "expected_e2e_ms": 18382.19,
     "expected_avg_denoise_ms": 260.76,
     "expected_median_denoise_ms": 247.84
   },
@@ -706,7 +706,7 @@
       "TimestepPreparationStage": 2.9,
       "LatentPreparationStage": 1.25,
       "ImageVAEEncodingStage": 1655.89,
-      "DenoisingStage": 100544.98,
+      "DenoisingStage": 106972.82,
       "DecodingStage": 1355.52,
       "per_frame_generation": null
     },
@@ -753,7 +753,7 @@
       "39": 1599.78
     },
     "expected_e2e_ms": 123182.9887,
-    "expected_avg_denoise_ms": 2513.52,
+    "expected_avg_denoise_ms": 2831.00,
     "expected_median_denoise_ms": 1600.09
   },
   "wan2_1_i2v_14b_480P_2gpu": {
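
Note on the mechanism this series converges on: fused `QKVParallelLinear` modules are built directly in `__init__`, and the q/k/v-to-fused checkpoint-name translation is expressed declaratively through the `param_names_mapping` entries added in PATCH 04. The sketch below is a minimal, standalone illustration of how such a mapping can be consumed. The 3-tuple reading `(replacement, shard_index, total_shards)` and the helper names `map_param_name` / `copy_shard_into_fused` are assumptions for this sketch, not sglang's actual loader code.

```python
# Standalone sketch (hypothetical helpers, not sglang's loader): apply the
# regex-based param_names_mapping from PATCH 04 to checkpoint tensor names,
# then copy each q/k/v shard into its slice of the fused to_qkv parameter.
# Assumed semantics: (replacement, shard_index, total_shards) means "rename
# via re.sub, then write shard `shard_index` of `total_shards` along dim 0".
import re

import torch

PARAM_NAMES_MAPPING = {
    r"(.*)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
    r"(.*)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
    r"(.*)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
}


def map_param_name(name: str):
    """Return (target_name, shard_index, total_shards); the shard fields
    are None when the name does not belong to a fused parameter."""
    for pattern, (repl, shard_index, total_shards) in PARAM_NAMES_MAPPING.items():
        new_name, n_subs = re.subn(pattern, repl, name)
        if n_subs:
            return new_name, shard_index, total_shards
    return name, None, None


def copy_shard_into_fused(fused, shard, shard_index, total_shards):
    """Write one q/k/v shard into its slice of the fused qkv tensor (dim 0)."""
    size = fused.shape[0] // total_shards
    fused[shard_index * size : (shard_index + 1) * size].copy_(shard)


# A checkpoint key for to_k lands in the middle third of to_qkv.weight.
name, idx, total = map_param_name("transformer_blocks.0.attn.to_k.weight")
assert name == "transformer_blocks.0.attn.to_qkv.weight" and idx == 1

fused = torch.zeros(3 * 4, 4)  # toy fused qkv weight (out_features = 3 * 4)
copy_shard_into_fused(fused, torch.ones(4, 4), idx, total)
assert fused[4:8].eq(1).all()
```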
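Design-wise, the series first gave each DiT its own `load_weights` with a `stacked_params_mapping` (PATCH 02/03), then moved the same translation into the configs' `param_names_mapping` plus `param_names_mapping` class attributes on the models (PATCH 04/06), which in turn let PATCH 11/12 drop the loader-side `fuse_qkv_projections()` call and the three near-identical `load_weights` implementations. PATCH 13 updates the recorded perf baselines alongside the refactor.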