Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
6f48bbd
add moe_wna16_marlin_gemm_v2
BBuf Nov 29, 2025
eeea208
Revert "add moe_wna16_marlin_gemm_v2"
BBuf Nov 29, 2025
83b3a76
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Nov 29, 2025
ef908c7
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Nov 29, 2025
ce08aed
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Nov 30, 2025
8a23117
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Nov 30, 2025
f8314e1
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Dec 1, 2025
7f4c11b
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Dec 1, 2025
6bdf938
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Dec 3, 2025
e7335fc
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Dec 3, 2025
418c633
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Dec 4, 2025
4e9009b
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Dec 5, 2025
6c61160
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Dec 5, 2025
66fb578
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Dec 5, 2025
dac9554
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Dec 5, 2025
6c5e299
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Dec 5, 2025
9a9b300
Merge branch 'main' of github.com:sgl-project/sglang
BBuf Dec 5, 2025
4b553e2
Add QKV fusion optimization for Flux models
BBuf Dec 5, 2025
210dd8d
fix ci
BBuf Dec 6, 2025
a45c708
Merge branch 'main' into add_qkv_fusion
mickqian Dec 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,11 @@ def load_customized(
), "Model dtype does not match default dtype"

model = model.eval()

if hasattr(model, "fuse_qkv_projections"):
logger.info("Fusing QKV projections for better performance")
model.fuse_qkv_projections()

return model


Expand Down
77 changes: 73 additions & 4 deletions python/sglang/multimodal_gen/runtime/models/dits/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states
def _get_fused_projections(
attn: "FluxAttention", hidden_states, encoder_hidden_states=None
):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
qkv, _ = attn.to_qkv(hidden_states)
query, key, value = qkv.chunk(3, dim=-1)

encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(
encoder_hidden_states
).chunk(3, dim=-1)
added_qkv, _ = attn.to_added_qkv(encoder_hidden_states)
encoder_query, encoder_key, encoder_value = added_qkv.chunk(3, dim=-1)

return query, key, value, encoder_query, encoder_key, encoder_value

Expand All @@ -89,6 +89,7 @@ def _get_qkv_projections(


class FluxAttention(torch.nn.Module, AttentionModuleMixin):
_supports_qkv_fusion = True

def __init__(
self,
Expand Down Expand Up @@ -161,6 +162,61 @@ def __init__(
},
)

self.fused_projections = False

@torch.no_grad()
def fuse_projections(self):
    """Fuse the separate Q/K/V projections into single QKV layers.

    Concatenates the weights (and biases, when enabled) of
    ``to_q``/``to_k``/``to_v`` along the output dimension into one
    ``to_qkv`` ReplicatedLinear, and does the same for the added (encoder)
    projections when ``added_kv_proj_dim`` is set. The unfused layers are
    deleted afterwards so their parameters can be freed. Idempotent: a
    second call returns immediately.
    """
    if self.fused_projections:
        return

    device = self.to_q.weight.data.device
    dtype = self.to_q.weight.data.dtype

    def _fuse(q_proj, k_proj, v_proj, use_bias):
        # Stack along the output dimension so one matmul yields [q | k | v],
        # which the fused path splits back with .chunk(3, dim=-1).
        weight = torch.cat(
            [q_proj.weight.data, k_proj.weight.data, v_proj.weight.data]
        )
        out_features, in_features = weight.shape
        fused = ReplicatedLinear(in_features, out_features, bias=use_bias)
        fused.weight.data = weight.to(device=device, dtype=dtype)
        if use_bias:
            bias = torch.cat(
                [q_proj.bias.data, k_proj.bias.data, v_proj.bias.data]
            )
            fused.bias.data = bias.to(device=device, dtype=dtype)
        return fused

    self.to_qkv = _fuse(self.to_q, self.to_k, self.to_v, self.use_bias)
    # Once fused, the per-projection layers are unreachable on the fused
    # path; drop them to free GPU memory.
    del self.to_q, self.to_k, self.to_v

    if self.added_kv_proj_dim is not None:
        self.to_added_qkv = _fuse(
            self.add_q_proj, self.add_k_proj, self.add_v_proj, self.added_proj_bias
        )
        del self.add_q_proj, self.add_k_proj, self.add_v_proj

    self.fused_projections = True
Comment on lines +168 to +218
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic for creating the fused to_qkv and to_added_qkv layers is very similar and contains duplicated code. This can be refactored into a helper method to improve maintainability and reduce redundancy. Additionally, after fusing the projection layers, the original layers (to_q, to_k, to_v, etc.) are no longer needed and can be deleted to free up GPU memory. This is an important optimization, especially for large models.

    def fuse_projections(self):
        if self.fused_projections:
            return

        device = self.to_q.weight.data.device
        dtype = self.to_q.weight.data.dtype

        def _fuse_linear(q_proj, k_proj, v_proj, use_bias):
            concatenated_weights = torch.cat(
                [q_proj.weight.data, k_proj.weight.data, v_proj.weight.data]
            )
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            fused_layer = ReplicatedLinear(in_features, out_features, bias=use_bias)
            fused_layer.weight.data = concatenated_weights.to(device=device, dtype=dtype)
            if use_bias:
                concatenated_bias = torch.cat(
                    [q_proj.bias.data, k_proj.bias.data, v_proj.bias.data]
                )
                fused_layer.bias.data = concatenated_bias.to(device=device, dtype=dtype)
            return fused_layer

        self.to_qkv = _fuse_linear(self.to_q, self.to_k, self.to_v, self.use_bias)
        del self.to_q, self.to_k, self.to_v

        if self.added_kv_proj_dim is not None:
            self.to_added_qkv = _fuse_linear(
                self.add_q_proj, self.add_k_proj, self.add_v_proj, self.added_proj_bias
            )
            del self.add_q_proj, self.add_k_proj, self.add_v_proj

        self.fused_projections = True


def forward(
self,
x: torch.Tensor,
Expand Down Expand Up @@ -473,6 +529,19 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]) -> None:
bias=True,
)

def fuse_qkv_projections(self):
    """Fuse per-block Q/K/V projections into single QKV projections.

    Walks every transformer block (double-stream and single-stream alike)
    and asks its attention module to fuse itself, provided the module both
    exposes ``fuse_projections`` and opts in via ``_supports_qkv_fusion``
    (absent attribute counts as opted in).
    """
    # The two block lists get identical treatment, so iterate them together.
    for block in (*self.transformer_blocks, *self.single_transformer_blocks):
        attn = block.attn
        if hasattr(attn, "fuse_projections") and getattr(
            attn, "_supports_qkv_fusion", True
        ):
            attn.fuse_projections()

Comment on lines +532 to +544
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The loops over self.transformer_blocks and self.single_transformer_blocks are identical. You can combine them to make the code more concise and avoid repetition.

    def fuse_qkv_projections(self):
        for block in list(self.transformer_blocks) + list(self.single_transformer_blocks):
            if hasattr(block.attn, "fuse_projections"):
                block.attn.fuse_projections()

def forward(
self,
hidden_states: torch.Tensor,
Expand Down
79 changes: 74 additions & 5 deletions python/sglang/multimodal_gen/runtime/models/dits/flux_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_state
def _get_fused_projections(
attn: "Flux2Attention", hidden_states, encoder_hidden_states=None
):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
qkv = attn.to_qkv(hidden_states)
query, key, value = qkv.chunk(3, dim=-1)

encoder_query = encoder_key = encoder_value = (None,)
encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(
encoder_hidden_states
).chunk(3, dim=-1)
added_qkv = attn.to_added_qkv(encoder_hidden_states)
encoder_query, encoder_key, encoder_value = added_qkv.chunk(3, dim=-1)

return query, key, value, encoder_query, encoder_key, encoder_value

Expand Down Expand Up @@ -114,6 +114,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
_supports_qkv_fusion = True

def __init__(
self,
Expand Down Expand Up @@ -182,6 +183,61 @@ def __init__(
},
)

self.fused_projections = False

@torch.no_grad()
def fuse_projections(self):
    """Fuse the separate Q/K/V projections into single QKV layers.

    Concatenates the weights (and biases, when enabled) of
    ``to_q``/``to_k``/``to_v`` along the output dimension into one
    ``to_qkv`` ``torch.nn.Linear``, and does the same for the added
    (encoder) projections when ``added_kv_proj_dim`` is set. The unfused
    layers are deleted afterwards so their parameters can be freed.
    Idempotent: a second call returns immediately.
    """
    if self.fused_projections:
        return

    device = self.to_q.weight.data.device
    dtype = self.to_q.weight.data.dtype

    def _fuse(q_proj, k_proj, v_proj, use_bias):
        # Stack along the output dimension so one matmul yields [q | k | v],
        # which the fused path splits back with .chunk(3, dim=-1).
        weight = torch.cat(
            [q_proj.weight.data, k_proj.weight.data, v_proj.weight.data]
        )
        out_features, in_features = weight.shape
        fused = torch.nn.Linear(in_features, out_features, bias=use_bias)
        fused.weight.data = weight.to(device=device, dtype=dtype)
        if use_bias:
            bias = torch.cat(
                [q_proj.bias.data, k_proj.bias.data, v_proj.bias.data]
            )
            fused.bias.data = bias.to(device=device, dtype=dtype)
        return fused

    self.to_qkv = _fuse(self.to_q, self.to_k, self.to_v, self.use_bias)
    # Once fused, the per-projection layers are unreachable on the fused
    # path; drop them to free GPU memory.
    del self.to_q, self.to_k, self.to_v

    if self.added_kv_proj_dim is not None:
        self.to_added_qkv = _fuse(
            self.add_q_proj, self.add_k_proj, self.add_v_proj, self.added_proj_bias
        )
        del self.add_q_proj, self.add_k_proj, self.add_v_proj

    self.fused_projections = True
Comment on lines +189 to +239
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the FluxAttention implementation, the logic for creating the fused to_qkv and to_added_qkv layers is duplicated. This can be refactored into a helper method to improve maintainability. Also, consider deleting the original projection layers after fusion to free up GPU memory, which is an important optimization.

    @torch.no_grad()
    def fuse_projections(self):
        if self.fused_projections:
            return

        device = self.to_q.weight.data.device
        dtype = self.to_q.weight.data.dtype

        def _fuse_linear(q_proj, k_proj, v_proj, use_bias):
            concatenated_weights = torch.cat(
                [q_proj.weight.data, k_proj.weight.data, v_proj.weight.data]
            )
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            fused_layer = torch.nn.Linear(in_features, out_features, bias=use_bias)
            fused_layer.weight.data = concatenated_weights.to(device=device, dtype=dtype)
            if use_bias:
                concatenated_bias = torch.cat(
                    [q_proj.bias.data, k_proj.bias.data, v_proj.bias.data]
                )
                fused_layer.bias.data = concatenated_bias.to(device=device, dtype=dtype)
            return fused_layer

        self.to_qkv = _fuse_linear(self.to_q, self.to_k, self.to_v, self.use_bias)
        del self.to_q, self.to_k, self.to_v

        if self.added_kv_proj_dim is not None:
            self.to_added_qkv = _fuse_linear(
                self.add_q_proj, self.add_k_proj, self.add_v_proj, self.added_proj_bias
            )
            del self.add_q_proj, self.add_k_proj, self.add_v_proj

        self.fused_projections = True


def forward(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -728,6 +784,19 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]):

self.gradient_checkpointing = False

def fuse_qkv_projections(self):
    """Fuse per-block Q/K/V projections into single QKV projections.

    Walks every transformer block (double-stream and single-stream alike)
    and asks its attention module to fuse itself, provided the module both
    exposes ``fuse_projections`` and opts in via ``_supports_qkv_fusion``
    (absent attribute counts as opted in).
    """
    # The two block lists get identical treatment, so iterate them together.
    for block in (*self.transformer_blocks, *self.single_transformer_blocks):
        attn = block.attn
        if hasattr(attn, "fuse_projections") and getattr(
            attn, "_supports_qkv_fusion", True
        ):
            attn.fuse_projections()

Comment on lines +787 to +799
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The loops over self.transformer_blocks and self.single_transformer_blocks are identical. You can combine them to make the code more concise and avoid repetition.

    def fuse_qkv_projections(self):
        for block in list(self.transformer_blocks) + list(self.single_transformer_blocks):
            if hasattr(block.attn, "fuse_projections"):
                block.attn.fuse_projections()

def forward(
self,
hidden_states: torch.Tensor,
Expand Down
Loading