Skip to content
12 changes: 9 additions & 3 deletions vllm_omni/diffusion/models/flux/flux_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,9 @@ def __init__(
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)

self.norm = AdaLayerNormZeroSingle(dim, quant_config=quant_config, prefix=f"{prefix}.norm")
# Modulation linear kept full precision; shift/scale/gate outputs
# are multiplied into the residual stream every block (see #2728).
self.norm = AdaLayerNormZeroSingle(dim, quant_config=None, prefix=f"{prefix}.norm")
self.proj_mlp = ReplicatedLinear(
dim,
self.mlp_hidden_dim,
Expand Down Expand Up @@ -563,13 +565,16 @@ def __init__(
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = nn.Linear(in_channels, self.inner_dim)

# Dual-stream blocks kept full precision — FP8 on their joint
# attention path causes noise on FLUX (#2728). Single-stream
# blocks (38 vs 19) still get FP8 for memory savings.
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
quant_config=quant_config,
quant_config=None,
prefix=f"transformer_blocks.{i}",
)
for i in range(num_layers)
Expand All @@ -589,12 +594,13 @@ def __init__(
]
)

# Final modulation feeds proj_out; keep full precision (see #2728).
self.norm_out = AdaLayerNormContinuous(
self.inner_dim,
self.inner_dim,
elementwise_affine=False,
eps=1e-6,
quant_config=quant_config,
quant_config=None,
prefix="norm_out",
)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
Expand Down
31 changes: 21 additions & 10 deletions vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,20 +169,23 @@ def __init__(

self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
# Time embedding MLP is kept full precision (quant_config=None) —
# small layers that feed per-block modulation; precision-sensitive
# (see #2728).
self.timestep_embedder.linear_1 = ReplicatedLinear(
256,
embedding_dim,
bias=True,
return_bias=False,
quant_config=quant_config,
quant_config=None,
prefix="timestep_embedder.linear_1",
)
self.timestep_embedder.linear_2 = ReplicatedLinear(
embedding_dim,
embedding_dim,
bias=True,
return_bias=False,
quant_config=quant_config,
quant_config=None,
prefix="timestep_embedder.linear_2",
)
self.use_additional_t_cond = use_additional_t_cond
Expand Down Expand Up @@ -701,15 +704,18 @@ def __init__(
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim

# Image processing modules
# Image processing modules.
# Modulation linear is kept full precision (quant_config=None) — it
# produces shift/scale/gate values that are precision-sensitive
# (see #2728).
self.img_mod = nn.Sequential(
nn.SiLU(),
ReplicatedLinear(
dim,
6 * dim,
bias=True,
return_bias=False,
quant_config=quant_config,
quant_config=None,
prefix="img_mod.1",
),
)
Expand All @@ -725,15 +731,15 @@ def __init__(
self.img_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps)
self.img_mlp = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix="img_mlp")

# Text processing modules
# Text processing modules.
self.txt_mod = nn.Sequential(
nn.SiLU(),
ReplicatedLinear(
dim,
6 * dim,
bias=True,
return_bias=False,
quant_config=quant_config,
quant_config=None,
prefix="txt_mod.1",
),
)
Expand Down Expand Up @@ -963,20 +969,22 @@ def __init__(

self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)

# Entry projections (image/text) are kept full precision —
# small sensitive layers at the network boundary (see #2728).
self.img_in = ReplicatedLinear(
in_channels,
self.inner_dim,
bias=True,
return_bias=False,
quant_config=quant_config,
quant_config=None,
prefix="img_in",
)
self.txt_in = ReplicatedLinear(
joint_attention_dim,
self.inner_dim,
bias=True,
return_bias=False,
quant_config=quant_config,
quant_config=None,
prefix="txt_in",
)

Expand All @@ -993,21 +1001,24 @@ def __init__(
]
)

# Final modulation and output projection are kept full precision —
# they produce the output latent and are precision-sensitive
# (see #2728).
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.norm_out.linear = ReplicatedLinear(
self.inner_dim,
2 * self.inner_dim,
bias=True,
return_bias=False,
quant_config=quant_config,
quant_config=None,
prefix="norm_out.linear",
)
self.proj_out = ReplicatedLinear(
self.inner_dim,
patch_size * patch_size * self.out_channels,
bias=True,
return_bias=False,
quant_config=quant_config,
quant_config=None,
prefix="proj_out",
)

Expand Down
43 changes: 36 additions & 7 deletions vllm_omni/diffusion/models/z_image/z_image_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,20 +214,22 @@ def __init__(
super().__init__()
if mid_size is None:
mid_size = out_size
# Time embedding MLP is kept full precision (quant_config=None) —
# small layers that feed adaLN; precision-sensitive (see #2728).
self.mlp = nn.Sequential(
ReplicatedLinear(
frequency_embedding_size,
mid_size,
bias=True,
quant_config=quant_config,
quant_config=None,
return_bias=False,
),
nn.SiLU(),
ReplicatedLinear(
mid_size,
out_size,
bias=True,
quant_config=quant_config,
quant_config=None,
return_bias=False,
),
)
Expand Down Expand Up @@ -426,9 +428,16 @@ def __init__(

self.modulation = modulation
if modulation:
# Modulation linear is kept at full precision (quant_config=None)
# — it produces scale/gate values that are precision-sensitive
# (see #2728, mirrors OmniGen2 fix).
self.adaLN_modulation = nn.Sequential(
ReplicatedLinear(
min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True, return_bias=False, quant_config=quant_config
min(dim, ADALN_EMBED_DIM),
4 * dim,
bias=True,
quant_config=None,
return_bias=False,
),
)

Expand Down Expand Up @@ -485,14 +494,24 @@ class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, quant_config: "QuantizationConfig | None" = None):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# Final output projection and its modulation are precision-sensitive
# (produce the output latent); keep at full precision (see #2728).
self.linear = ReplicatedLinear(
hidden_size, out_channels, bias=True, quant_config=quant_config, return_bias=False
hidden_size,
out_channels,
bias=True,
quant_config=None,
return_bias=False,
)

self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
ReplicatedLinear(
min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True, quant_config=quant_config, return_bias=False
min(hidden_size, ADALN_EMBED_DIM),
hidden_size,
bias=True,
quant_config=None,
return_bias=False,
),
)

Expand Down Expand Up @@ -673,11 +692,13 @@ def __init__(
all_x_embedder = {}
all_final_layer = {}
for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
# x_embedder (patch embed) is a small precision-sensitive entry
# layer; keep full precision (see #2728).
x_embedder = ReplicatedLinear(
f_patch_size * patch_size * patch_size * in_channels,
dim,
bias=True,
quant_config=quant_config,
quant_config=None,
return_bias=False,
)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
Expand Down Expand Up @@ -720,9 +741,17 @@ def __init__(
]
)
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024, quant_config=quant_config)
# Caption embedder maps text features -> hidden; keep full precision
# (see #2728).
self.cap_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps),
ReplicatedLinear(cap_feat_dim, dim, bias=True, return_bias=False, quant_config=quant_config),
ReplicatedLinear(
cap_feat_dim,
dim,
bias=True,
quant_config=None,
return_bias=False,
),
)

self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
Expand Down
Loading