diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index a14e875c06c..d32a6d8acac 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -25,6 +25,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) @@ -157,21 +158,12 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - self.gate_proj = ColumnParallelLinear( + self.gate_up_proj = MergedColumnParallelLinear( hidden_size, - intermediate_size, - bias=False, - gather_output=False, - quant_config=quant_config, - prefix=f"{prefix}.gate_proj", - ) - self.up_proj = ColumnParallelLinear( - hidden_size, - intermediate_size, + [intermediate_size, intermediate_size], bias=False, - gather_output=False, quant_config=quant_config, - prefix=f"{prefix}.up_proj", + prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( intermediate_size, @@ -186,8 +178,8 @@ def __init__( self.act_fn = nn.SiLU() def forward(self, x): - gate, _ = self.gate_proj(x) - up, _ = self.up_proj(x) + gate_up, _ = self.gate_up_proj(x) + gate, up = gate_up.chunk(2, dim=-1) x = self.act_fn(gate) * up x, _ = self.down_proj(x) return x @@ -929,13 +921,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), (".qkv_proj", ".v_proj", "v"), - # MLP gate/up projections — the DiT uses separate - # ColumnParallelLinear layers (no fused gate_up_proj), but - # these entries are needed so that DiffusionLoRAManager can - # derive the packed→sublayer mapping for LoRA checkpoints - # that store weights under fused gate_up_proj keys. - # The weight loader gracefully falls through to the - # non-stacked path when the fused parameter doesn't exist. + # MLP gate/up projections — fused into MergedColumnParallelLinear. + # HF checkpoints store separate gate_proj / up_proj weights; + # these entries remap them to the fused gate_up_proj parameter. (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index 3e053cbda50..84f177e01a9 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -675,6 +675,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: (".qkv_proj_moe_gen", ".q_proj_moe_gen"), (".qkv_proj_moe_gen", ".k_proj_moe_gen"), (".qkv_proj_moe_gen", ".v_proj_moe_gen"), + (".gate_up_proj", ".gate_proj"), + (".gate_up_proj", ".up_proj"), ] stacked_source_names: set[str] = set() for name in list(allowed):