diff --git a/python/sglang/multimodal_gen/runtime/loader/component_loader.py b/python/sglang/multimodal_gen/runtime/loader/component_loader.py
index 5c311789e962..9de8eb586469 100644
--- a/python/sglang/multimodal_gen/runtime/loader/component_loader.py
+++ b/python/sglang/multimodal_gen/runtime/loader/component_loader.py
@@ -691,6 +691,11 @@ def load_customized(
         ), "Model dtype does not match default dtype"
 
         model = model.eval()
+
+        if hasattr(model, "fuse_qkv_projections"):  # only models exposing the hook are fused
+            logger.info("Fusing QKV projections for better performance")
+            model.fuse_qkv_projections()
+
         return model
 
 
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux.py b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
index fbb752429a97..d2a2f0304e04 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux.py
@@ -69,13 +69,13 @@ def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states
 def _get_fused_projections(
     attn: "FluxAttention", hidden_states, encoder_hidden_states=None
 ):
-    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+    qkv, _ = attn.to_qkv(hidden_states)  # call returns a pair; only the first item is used
+    query, key, value = qkv.chunk(3, dim=-1)
 
     encoder_query = encoder_key = encoder_value = None
     if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
-        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(
-            encoder_hidden_states
-        ).chunk(3, dim=-1)
+        added_qkv, _ = attn.to_added_qkv(encoder_hidden_states)
+        encoder_query, encoder_key, encoder_value = added_qkv.chunk(3, dim=-1)
 
     return query, key, value, encoder_query, encoder_key, encoder_value
 
@@ -89,6 +89,7 @@ def _get_qkv_projections(
 
 
 class FluxAttention(torch.nn.Module, AttentionModuleMixin):
+    _supports_qkv_fusion = True  # marker looked up with getattr() in fuse_qkv_projections()
 
     def __init__(
         self,
@@ -161,6 +162,61 @@ def __init__(
             },
         )
 
+        self.fused_projections = False  # flipped to True by fuse_projections()
+
+    @torch.no_grad()
+    def fuse_projections(self):  # builds to_qkv (and to_added_qkv) from the separate layers
+        if self.fused_projections:  # already fused: nothing to do
+            return
+
+        device = self.to_q.weight.data.device
+        dtype = self.to_q.weight.data.dtype
+
+        concatenated_weights = torch.cat(  # default dim=0: q, k, v stacked in that order
+            [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
+        )
+        in_features = concatenated_weights.shape[1]
+        out_features = concatenated_weights.shape[0]
+
+        self.to_qkv = ReplicatedLinear(in_features, out_features, bias=self.use_bias)
+        self.to_qkv.weight.data = concatenated_weights.to(device=device, dtype=dtype)
+        if self.use_bias:
+            concatenated_bias = torch.cat(
+                [self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]
+            )
+            self.to_qkv.bias.data = concatenated_bias.to(device=device, dtype=dtype)
+
+        if self.added_kv_proj_dim is not None:  # same fusion for add_q/k/v_proj
+            concatenated_weights = torch.cat(
+                [
+                    self.add_q_proj.weight.data,
+                    self.add_k_proj.weight.data,
+                    self.add_v_proj.weight.data,
+                ]
+            )
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_added_qkv = ReplicatedLinear(
+                in_features, out_features, bias=self.added_proj_bias
+            )
+            self.to_added_qkv.weight.data = concatenated_weights.to(
+                device=device, dtype=dtype
+            )
+            if self.added_proj_bias:
+                concatenated_bias = torch.cat(
+                    [
+                        self.add_q_proj.bias.data,
+                        self.add_k_proj.bias.data,
+                        self.add_v_proj.bias.data,
+                    ]
+                )
+                self.to_added_qkv.bias.data = concatenated_bias.to(
+                    device=device, dtype=dtype
+                )
+
+        self.fused_projections = True  # NOTE(review): the unfused layers are not removed here
+
     def forward(
         self,
         x: torch.Tensor,
@@ -473,6 +529,19 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]) -> None:
             bias=True,
         )
 
+    def fuse_qkv_projections(self):  # fuses every block.attn that offers fuse_projections()
+        for block in self.transformer_blocks:
+            if hasattr(block.attn, "fuse_projections") and getattr(
+                block.attn, "_supports_qkv_fusion", True  # missing marker counts as supported
+            ):
+                block.attn.fuse_projections()
+
+        for block in self.single_transformer_blocks:
+            if hasattr(block.attn, "fuse_projections") and getattr(
+                block.attn, "_supports_qkv_fusion", True  # missing marker counts as supported
+            ):
+                block.attn.fuse_projections()
+
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
index 3ff593f819b3..290765c93142 100644
--- a/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
+++ b/python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
@@ -52,13 +52,13 @@ def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_state
 def _get_fused_projections(
     attn: "Flux2Attention", hidden_states, encoder_hidden_states=None
 ):
-    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+    qkv = attn.to_qkv(hidden_states)  # used as a single tensor (to_qkv is a torch.nn.Linear)
+    query, key, value = qkv.chunk(3, dim=-1)
 
-    encoder_query = encoder_key = encoder_value = (None,)
+    encoder_query = encoder_key = encoder_value = None
     if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
-        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(
-            encoder_hidden_states
-        ).chunk(3, dim=-1)
+        added_qkv = attn.to_added_qkv(encoder_hidden_states)
+        encoder_query, encoder_key, encoder_value = added_qkv.chunk(3, dim=-1)
 
     return query, key, value, encoder_query, encoder_key, encoder_value
 
@@ -114,6 +114,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
+    _supports_qkv_fusion = True  # marker looked up with getattr() in fuse_qkv_projections()
 
     def __init__(
         self,
@@ -182,6 +183,61 @@ def __init__(
             },
         )
 
+        self.fused_projections = False  # flipped to True by fuse_projections()
+
+    @torch.no_grad()
+    def fuse_projections(self):  # builds to_qkv (and to_added_qkv) from the separate layers
+        if self.fused_projections:  # already fused: nothing to do
+            return
+
+        device = self.to_q.weight.data.device
+        dtype = self.to_q.weight.data.dtype
+
+        concatenated_weights = torch.cat(  # default dim=0: q, k, v stacked in that order
+            [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
+        )
+        in_features = concatenated_weights.shape[1]
+        out_features = concatenated_weights.shape[0]
+
+        self.to_qkv = torch.nn.Linear(in_features, out_features, bias=self.use_bias)
+        self.to_qkv.weight.data = concatenated_weights.to(device=device, dtype=dtype)
+        if self.use_bias:
+            concatenated_bias = torch.cat(
+                [self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]
+            )
+            self.to_qkv.bias.data = concatenated_bias.to(device=device, dtype=dtype)
+
+        if self.added_kv_proj_dim is not None:  # same fusion for add_q/k/v_proj
+            concatenated_weights = torch.cat(
+                [
+                    self.add_q_proj.weight.data,
+                    self.add_k_proj.weight.data,
+                    self.add_v_proj.weight.data,
+                ]
+            )
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_added_qkv = torch.nn.Linear(
+                in_features, out_features, bias=self.added_proj_bias
+            )
+            self.to_added_qkv.weight.data = concatenated_weights.to(
+                device=device, dtype=dtype
+            )
+            if self.added_proj_bias:
+                concatenated_bias = torch.cat(
+                    [
+                        self.add_q_proj.bias.data,
+                        self.add_k_proj.bias.data,
+                        self.add_v_proj.bias.data,
+                    ]
+                )
+                self.to_added_qkv.bias.data = concatenated_bias.to(
+                    device=device, dtype=dtype
+                )
+
+        self.fused_projections = True  # NOTE(review): the unfused layers are not removed here
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -728,6 +784,19 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]):
 
         self.gradient_checkpointing = False
 
+    def fuse_qkv_projections(self):  # fuses every block.attn that offers fuse_projections()
+        for block in self.transformer_blocks:
+            if hasattr(block.attn, "fuse_projections") and getattr(
+                block.attn, "_supports_qkv_fusion", True  # missing marker counts as supported
+            ):
+                block.attn.fuse_projections()
+
+        for block in self.single_transformer_blocks:
+            if hasattr(block.attn, "fuse_projections") and getattr(
+                block.attn, "_supports_qkv_fusion", True  # missing marker counts as supported
+            ):
+                block.attn.fuse_projections()
+
     def forward(
         self,
         hidden_states: torch.Tensor,