diff --git a/vllm_omni/diffusion/models/flux/flux_transformer.py b/vllm_omni/diffusion/models/flux/flux_transformer.py index 680b8bfbbed..297c6267515 100644 --- a/vllm_omni/diffusion/models/flux/flux_transformer.py +++ b/vllm_omni/diffusion/models/flux/flux_transformer.py @@ -381,7 +381,9 @@ def __init__( super().__init__() self.mlp_hidden_dim = int(dim * mlp_ratio) - self.norm = AdaLayerNormZeroSingle(dim, quant_config=quant_config, prefix=f"{prefix}.norm") + # Modulation linear kept full precision; shift/scale/gate outputs + # are multiplied into the residual stream every block (see #2728). + self.norm = AdaLayerNormZeroSingle(dim, quant_config=None, prefix=f"{prefix}.norm") self.proj_mlp = ReplicatedLinear( dim, self.mlp_hidden_dim, @@ -563,13 +565,16 @@ def __init__( self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) self.x_embedder = nn.Linear(in_channels, self.inner_dim) + # Dual-stream blocks kept full precision — FP8 on their joint + # attention path causes noise on FLUX (#2728). Single-stream + # blocks (38 vs 19) still get FP8 for memory savings. self.transformer_blocks = nn.ModuleList( [ FluxTransformerBlock( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, - quant_config=quant_config, + quant_config=None, prefix=f"transformer_blocks.{i}", ) for i in range(num_layers) @@ -589,12 +594,13 @@ def __init__( ] ) + # Final modulation feeds proj_out; keep full precision (see #2728). self.norm_out = AdaLayerNormContinuous( self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6, - quant_config=quant_config, + quant_config=None, prefix="norm_out", ) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py index 9f16d8808c8..88a66d7f6b0 100644 --- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -169,12 +169,15 @@ def __init__( self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + # Time embedding MLP is kept full precision (quant_config=None) — + # small layers that feed per-block modulation; precision-sensitive + # (see #2728). self.timestep_embedder.linear_1 = ReplicatedLinear( 256, embedding_dim, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="timestep_embedder.linear_1", ) self.timestep_embedder.linear_2 = ReplicatedLinear( @@ -182,7 +185,7 @@ def __init__( embedding_dim, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="timestep_embedder.linear_2", ) self.use_additional_t_cond = use_additional_t_cond @@ -701,7 +704,10 @@ def __init__( self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim - # Image processing modules + # Image processing modules. + # Modulation linear is kept full precision (quant_config=None) — it + # produces shift/scale/gate values that are precision-sensitive + # (see #2728). self.img_mod = nn.Sequential( nn.SiLU(), ReplicatedLinear( @@ -709,7 +715,7 @@ def __init__( 6 * dim, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="img_mod.1", ), ) @@ -725,7 +731,7 @@ def __init__( self.img_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) self.img_mlp = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix="img_mlp") - # Text processing modules + # Text processing modules. self.txt_mod = nn.Sequential( nn.SiLU(), ReplicatedLinear( @@ -733,7 +739,7 @@ def __init__( 6 * dim, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="txt_mod.1", ), ) @@ -963,12 +969,14 @@ def __init__( self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) + # Entry projections (image/text) are kept full precision — + # small sensitive layers at the network boundary (see #2728). self.img_in = ReplicatedLinear( in_channels, self.inner_dim, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="img_in", ) self.txt_in = ReplicatedLinear( @@ -976,7 +984,7 @@ def __init__( self.inner_dim, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="txt_in", ) @@ -993,13 +1001,16 @@ def __init__( ] ) + # Final modulation and output projection are kept full precision — + # they produce the output latent and are precision-sensitive + # (see #2728). self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) self.norm_out.linear = ReplicatedLinear( self.inner_dim, 2 * self.inner_dim, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="norm_out.linear", ) self.proj_out = ReplicatedLinear( @@ -1007,7 +1018,7 @@ def __init__( patch_size * patch_size * self.out_channels, bias=True, return_bias=False, - quant_config=quant_config, + quant_config=None, prefix="proj_out", ) diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index 3ffad221ba9..c36ea746654 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -214,12 +214,14 @@ def __init__( super().__init__() if mid_size is None: mid_size = out_size + # Time embedding MLP is kept full precision (quant_config=None) — + # small layers that feed adaLN; precision-sensitive (see #2728). self.mlp = nn.Sequential( ReplicatedLinear( frequency_embedding_size, mid_size, bias=True, - quant_config=quant_config, + quant_config=None, return_bias=False, ), nn.SiLU(), @@ -227,7 +229,7 @@ def __init__( mid_size, out_size, bias=True, - quant_config=quant_config, + quant_config=None, return_bias=False, ), ) @@ -426,9 +428,16 @@ def __init__( self.modulation = modulation if modulation: + # Modulation linear is kept at full precision (quant_config=None) + # — it produces scale/gate values that are precision-sensitive + # (see #2728, mirrors OmniGen2 fix). self.adaLN_modulation = nn.Sequential( ReplicatedLinear( - min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True, return_bias=False, quant_config=quant_config + min(dim, ADALN_EMBED_DIM), + 4 * dim, + bias=True, + quant_config=None, + return_bias=False, ), ) @@ -485,14 +494,24 @@ class FinalLayer(nn.Module): def __init__(self, hidden_size, out_channels, quant_config: "QuantizationConfig | None" = None): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + # Final output projection and its modulation are precision-sensitive + # (produce the output latent); keep at full precision (see #2728). self.linear = ReplicatedLinear( - hidden_size, out_channels, bias=True, quant_config=quant_config, return_bias=False + hidden_size, + out_channels, + bias=True, + quant_config=None, + return_bias=False, ) self.adaLN_modulation = nn.Sequential( nn.SiLU(), ReplicatedLinear( - min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True, quant_config=quant_config, return_bias=False + min(hidden_size, ADALN_EMBED_DIM), + hidden_size, + bias=True, + quant_config=None, + return_bias=False, ), ) @@ -673,11 +692,13 @@ def __init__( all_x_embedder = {} all_final_layer = {} for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + # x_embedder (patch embed) is a small precision-sensitive entry + # layer; keep full precision (see #2728). x_embedder = ReplicatedLinear( f_patch_size * patch_size * patch_size * in_channels, dim, bias=True, - quant_config=quant_config, + quant_config=None, return_bias=False, ) all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder @@ -720,9 +741,17 @@ def __init__( ] ) self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024, quant_config=quant_config) + # Caption embedder maps text features -> hidden; keep full precision + # (see #2728). self.cap_embedder = nn.Sequential( RMSNorm(cap_feat_dim, eps=norm_eps), - ReplicatedLinear(cap_feat_dim, dim, bias=True, return_bias=False, quant_config=quant_config), + ReplicatedLinear( + cap_feat_dim, + dim, + bias=True, + quant_config=None, + return_bias=False, + ), ) self.x_pad_token = nn.Parameter(torch.empty((1, dim)))