19 changes: 19 additions & 0 deletions python/sglang/multimodal_gen/configs/models/dits/flux.py
@@ -21,8 +21,27 @@ class FluxArchConfig(DiTArchConfig):
guidance_embeds: bool = False
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

stacked_params_mapping: list[tuple[str, str, str]] = field(
default_factory=lambda: [
# (param_name, shard_name, shard_id)
(".to_qkv", ".to_q", "q"),
(".to_qkv", ".to_k", "k"),
(".to_qkv", ".to_v", "v"),
(".to_added_qkv", ".add_q_proj", "q"),
(".to_added_qkv", ".add_k_proj", "k"),
(".to_added_qkv", ".add_v_proj", "v"),
]
)

param_names_mapping: dict = field(
default_factory=lambda: {
# QKV fusion mappings
r"(.*)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
r"(.*)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
r"(.*)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
r"(.*)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
r"(.*)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
r"(.*)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
r"transformer\.(\w*)\.(.*)$": r"\1.\2",
}
)
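Each value in `param_names_mapping` is either a plain rename or a `(new_name, shard_index, num_shards)` tuple that tells the weight loader which slice of a fused parameter the original tensor fills. A minimal sketch of how such a regex mapping can be applied to one checkpoint name (`remap_name` is a hypothetical helper for illustration, not part of this PR):

```python
import re

# Hypothetical helper, illustration only: apply a param_names_mapping-style
# dict to a single checkpoint parameter name.
def remap_name(name, mapping):
    for pattern, target in mapping.items():
        new_name, shard_idx, num_shards = (
            target if isinstance(target, tuple) else (target, None, None)
        )
        replaced, n = re.subn(pattern, new_name, name)
        if n:
            return replaced, shard_idx, num_shards
    return name, None, None

# ".to_q" weights land in slot 0 of the 3-way fused ".to_qkv" parameter.
print(remap_name(
    "transformer.transformer_blocks.0.attn.to_q.weight",
    {r"(.*)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3)},
))
# ('transformer.transformer_blocks.0.attn.to_qkv.weight', 0, 3)
```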
20 changes: 20 additions & 0 deletions python/sglang/multimodal_gen/configs/models/dits/qwenimage.py
@@ -21,8 +21,28 @@ class QwenImageArchConfig(DiTArchConfig):
guidance_embeds: bool = False
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

stacked_params_mapping: list[tuple[str, str, str]] = field(
default_factory=lambda: [
# (param_name, shard_name, shard_id)
(".to_qkv", ".to_q", "q"),
(".to_qkv", ".to_k", "k"),
(".to_qkv", ".to_v", "v"),
(".to_added_qkv", ".add_q_proj", "q"),
(".to_added_qkv", ".add_k_proj", "k"),
(".to_added_qkv", ".add_v_proj", "v"),
]
)

param_names_mapping: dict = field(
default_factory=lambda: {
# QKV fusion mappings
r"(.*)\.to_q\.(weight|bias)$": (r"\1.to_qkv.\2", 0, 3),
r"(.*)\.to_k\.(weight|bias)$": (r"\1.to_qkv.\2", 1, 3),
r"(.*)\.to_v\.(weight|bias)$": (r"\1.to_qkv.\2", 2, 3),
r"(.*)\.add_q_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 0, 3),
r"(.*)\.add_k_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 1, 3),
r"(.*)\.add_v_proj\.(weight|bias)$": (r"\1.to_added_qkv.\2", 2, 3),
# LoRA mappings
r"^(transformer_blocks\.\d+\.attn\..*\.lora_[AB])\.default$": r"\1",
}
)
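The extra LoRA entry only strips the trailing `.default` suffix that adapter frameworks such as PEFT typically append; a quick illustration (the parameter name below is made up):

```python
import re

# Illustration only: the LoRA mapping removes the ".default" adapter suffix.
pattern = r"^(transformer_blocks\.\d+\.attn\..*\.lora_[AB])\.default$"
name = "transformer_blocks.3.attn.to_q.lora_A.default"  # hypothetical name
print(re.sub(pattern, r"\1", name))
# transformer_blocks.3.attn.to_q.lora_A
```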
4 changes: 4 additions & 0 deletions python/sglang/multimodal_gen/docs/profiling.md
@@ -47,6 +47,10 @@ By default, trace files are saved in the ./logs/ directory. The exact output fil
```bash
[mm-dd hh:mm:ss] Saving profiler traces to: /sgl-workspace/sglang/logs/mocked_fake_id_for_offline_generate-5_steps-global-rank0.trace.json.gz
```

The trace file name follows the pattern:

```
{request_id}-{num_steps}_steps-global-rank{rank}.trace.json.gz
```

Example: `mocked_fake_id_for_offline_generate-5_steps-global-rank0.trace.json.gz`
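The name can be composed (or checked) from its parts; a minimal illustration in Python:

```python
# Illustration only: compose the expected trace file name from its parts.
request_id = "mocked_fake_id_for_offline_generate"
num_steps, rank = 5, 0
print(f"{request_id}-{num_steps}_steps-global-rank{rank}.trace.json.gz")
# mocked_fake_id_for_offline_generate-5_steps-global-rank0.trace.json.gz
```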

### View Traces

@@ -692,10 +692,6 @@ def load_customized(

model = model.eval()

if hasattr(model, "fuse_qkv_projections"):
logger.info("Fusing QKV projections for better performance")
model.fuse_qkv_projections()

return model


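With the explicit post-load fusion call removed, per-projection checkpoint weights are expected to be routed into the already-fused layers at load time. A rough, vLLM-style sketch of how a `stacked_params_mapping` is typically consumed (the function and the `weight_loader` contract are assumptions for illustration, not this repository's actual loader):

```python
# Illustrative sketch only (assumed names): route per-projection checkpoint
# weights into fused parameters using a stacked_params_mapping.
def load_weights(model, weights, stacked_params_mapping):
    params = dict(model.named_parameters())
    for name, loaded_weight in weights:
        for fused_suffix, shard_suffix, shard_id in stacked_params_mapping:
            if shard_suffix not in name:
                continue
            fused_name = name.replace(shard_suffix, fused_suffix)
            param = params[fused_name]
            # Fused layers (e.g. QKVParallelLinear) attach a weight_loader
            # that copies the slice into the q/k/v region of the fused tensor.
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params[name]
            param.data.copy_(loaded_weight)
```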
98 changes: 22 additions & 76 deletions python/sglang/multimodal_gen/runtime/models/dits/flux.py
@@ -36,17 +36,16 @@

# from sglang.multimodal_gen.runtime.layers.layernorm import LayerNorm as LayerNorm
from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
from sglang.multimodal_gen.runtime.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
)
from sglang.multimodal_gen.runtime.layers.mlp import MLP
from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
NDRotaryEmbedding,
_apply_rotary_emb,
)
from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
from sglang.multimodal_gen.runtime.models.dits.utils import (
delete_projection_layers,
fuse_linear_projections,
)
from sglang.multimodal_gen.runtime.platforms import (
AttentionBackendEnum,
current_platform,
@@ -56,45 +55,21 @@
logger = init_logger(__name__) # pylint: disable=invalid-name


def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
query, _ = attn.to_q(hidden_states)
key, _ = attn.to_k(hidden_states)
value, _ = attn.to_v(hidden_states)

encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_query, _ = attn.add_q_proj(encoder_hidden_states)
encoder_key, _ = attn.add_k_proj(encoder_hidden_states)
encoder_value, _ = attn.add_v_proj(encoder_hidden_states)

return query, key, value, encoder_query, encoder_key, encoder_value


def _get_fused_projections(
def _get_qkv_projections(
attn: "FluxAttention", hidden_states, encoder_hidden_states=None
):
qkv, _ = attn.to_qkv(hidden_states)
query, key, value = qkv.chunk(3, dim=-1)

encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
added_qkv, _ = attn.to_added_qkv(encoder_hidden_states)
encoder_query, encoder_key, encoder_value = added_qkv.chunk(3, dim=-1)

return query, key, value, encoder_query, encoder_key, encoder_value


def _get_qkv_projections(
attn: "FluxAttention", hidden_states, encoder_hidden_states=None
):
if attn.fused_projections:
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
return _get_projections(attn, hidden_states, encoder_hidden_states)
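The fused path assumes Q, K and V are concatenated along the output dimension, so one projection followed by `chunk(3, dim=-1)` recovers all three tensors; a standalone shape check (plain tensors, not the actual layer):

```python
import torch

# Illustration only: fused QKV weight stacked as [Wq; Wk; Wv] along dim 0.
hidden = torch.randn(2, 16, 64)      # (batch, seq, hidden)
w_qkv = torch.randn(3 * 64, 64)      # fused weight
qkv = hidden @ w_qkv.T               # (2, 16, 192)
q, k, v = qkv.chunk(3, dim=-1)       # each (2, 16, 64)
print(q.shape, k.shape, v.shape)
```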


class FluxAttention(torch.nn.Module, AttentionModuleMixin):
_supports_qkv_fusion = True

def __init__(
self,
query_dim: int,
@@ -125,11 +100,15 @@ def __init__(
self.added_proj_bias = added_proj_bias

self.norm_q = RMSNorm(dim_head, eps=eps)

self.norm_k = RMSNorm(dim_head, eps=eps)
self.to_q = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)
self.to_k = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)
self.to_v = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)

# Use QKVParallelLinear for fused QKV projections
self.to_qkv = QKVParallelLinear(
hidden_size=query_dim,
head_size=dim_head,
total_num_heads=num_heads,
bias=bias,
)

if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
@@ -142,14 +121,12 @@ def __init__(
if added_kv_proj_dim is not None:
self.norm_added_q = RMSNorm(dim_head, eps=eps)
self.norm_added_k = RMSNorm(dim_head, eps=eps)
self.add_q_proj = ReplicatedLinear(
added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
)
self.add_k_proj = ReplicatedLinear(
added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
)
self.add_v_proj = ReplicatedLinear(
added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
# Use QKVParallelLinear for added (encoder) QKV projections
self.to_added_qkv = QKVParallelLinear(
hidden_size=added_kv_proj_dim,
head_size=dim_head,
total_num_heads=num_heads,
bias=added_proj_bias,
)
self.to_add_out = ReplicatedLinear(self.inner_dim, query_dim, bias=out_bias)

@@ -166,30 +143,6 @@ def __init__(
},
)

self.fused_projections = False

@torch.no_grad()
def fuse_projections(self):
if self.fused_projections:
return

self.to_qkv = fuse_linear_projections(
self.to_q, self.to_k, self.to_v, self.use_bias, ReplicatedLinear
)
delete_projection_layers(self, ["to_q", "to_k", "to_v"])

if self.added_kv_proj_dim is not None:
self.to_added_qkv = fuse_linear_projections(
self.add_q_proj,
self.add_k_proj,
self.add_v_proj,
self.added_proj_bias,
ReplicatedLinear,
)
delete_projection_layers(self, ["add_q_proj", "add_k_proj", "add_v_proj"])

self.fused_projections = True

def forward(
self,
x: torch.Tensor,
@@ -445,6 +398,8 @@ class FluxTransformer2DModel(CachableDiT):
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
"""

param_names_mapping = FluxConfig().arch_config.param_names_mapping

def __init__(self, config: FluxConfig, hf_config: dict[str, Any]) -> None:
super().__init__(config=config, hf_config=hf_config)
self.config = config.arch_config
@@ -503,15 +458,6 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]) -> None:
bias=True,
)

def fuse_qkv_projections(self):
for block in list(self.transformer_blocks) + list(
self.single_transformer_blocks
):
if hasattr(block.attn, "fuse_projections") and getattr(
block.attn, "_supports_qkv_fusion", True
):
block.attn.fuse_projections()

def forward(
self,
hidden_states: torch.Tensor,