Merged
1 change: 1 addition & 0 deletions docs/user_guide/diffusion/quantization/fp8.md
@@ -65,6 +65,7 @@ The available `ignored_layers` names depend on the model architecture (e.g., `to
| Flux | `black-forest-labs/FLUX.1-dev` | All layers | None |
| HunyuanImage-3 | `tencent/HunyuanImage3` | All layers | None |
| HunyuanVideo-1.5 | `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v`, `720p_t2v`, `480p_i2v` | All layers | None |
| GLM-Image | `zai-org/GLM-Image` | All layers | None |
| Helios | `BestWishYsh/Helios-Base`, `BestWishYsh/Helios-Mid`, `BestWishYsh/Helios-Distilled` | All layers | None |

## Combining with Other Features
2 changes: 1 addition & 1 deletion vllm_omni/diffusion/data.py
@@ -728,7 +728,7 @@ def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig":

# Backwards-compatibility: map "quantization" to "quantization_config"
# so callers using the old field name still work.
if "quantization" in kwargs and "quantization_config" not in kwargs:
if "quantization" in kwargs and kwargs.get("quantization_config") is None:
kwargs["quantization_config"] = kwargs.pop("quantization")
else:
kwargs.pop("quantization", None)
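The one-line change above matters when a caller forwards `quantization_config=None` explicitly (for example, an unset field passed through verbatim): the old membership test saw the key present and silently dropped the legacy value. A standalone sketch of the mapping, with "fp8" used only as an illustrative value:

```python
# Standalone sketch of the back-compat mapping in OmniDiffusionConfig.from_kwargs,
# extracted from the diff above; "fp8" is just an illustrative value.
def map_legacy_quantization(kwargs: dict) -> dict:
    # New condition: treat an explicit quantization_config=None like a missing key.
    if "quantization" in kwargs and kwargs.get("quantization_config") is None:
        kwargs["quantization_config"] = kwargs.pop("quantization")
    else:
        kwargs.pop("quantization", None)
    return kwargs


print(map_legacy_quantization({"quantization": "fp8", "quantization_config": None}))
# -> {'quantization_config': 'fp8'}  (the old '"quantization_config" not in kwargs'
#    check would have dropped the legacy value here)
print(map_legacy_quantization({"quantization": "fp8"}))
# -> {'quantization_config': 'fp8'}
```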
43 changes: 37 additions & 6 deletions vllm_omni/diffusion/models/glm_image/glm_image_transformer.py
@@ -4,7 +4,7 @@
import math
from collections.abc import Iterable
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING, Any

import torch
import torch.nn as nn
@@ -19,6 +19,9 @@
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.cache.base import CachedTransformer
@@ -465,6 +468,7 @@ def __init__(
parallel_config: DiffusionParallelConfig | None = None,
out_bias: bool = True,
eps: float = 1e-5,
quant_config: "QuantizationConfig | None" = None,
):
super().__init__()
self.dim = dim
@@ -480,6 +484,7 @@ def __init__(
total_num_kv_heads=num_heads,
bias=True,
return_bias=False,
quant_config=quant_config,
)

# QK normalization (LayerNorm, not RMSNorm for GLM-Image)
@@ -495,6 +500,7 @@
bias=out_bias,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
),
nn.Dropout(0.0),
]
@@ -656,6 +662,7 @@ def __init__(
*,
approximate: str = "none",
bias: bool = True,
quant_config: "QuantizationConfig | None" = None,
):
super().__init__()
self.proj = ColumnParallelLinear(
Expand All @@ -664,6 +671,7 @@ def __init__(
bias=bias,
gather_output=False,
return_bias=False,
quant_config=quant_config,
)
self.approximate = approximate

@@ -679,6 +687,7 @@ def __init__(
dim_out: int,
*,
bias: bool = True,
quant_config: "QuantizationConfig | None" = None,
):
super().__init__()
self.proj = ColumnParallelLinear(
Expand All @@ -687,6 +696,7 @@ def __init__(
bias=bias,
gather_output=False,
return_bias=False,
quant_config=quant_config,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -703,34 +713,37 @@ def __init__(
inner_dim: int | None = None,
bias: bool = True,
activation_fn: str = "gelu",
quant_config: "QuantizationConfig | None" = None,
):
super().__init__()
inner_dim = inner_dim or int(dim * mult)
dim_out = dim_out or dim

if activation_fn == "linear-silu":
layers: list[nn.Module] = [
ColumnParallelSiLU(dim, inner_dim, bias=bias),
ColumnParallelSiLU(dim, inner_dim, bias=bias, quant_config=quant_config),
nn.Identity(),
RowParallelLinear(
inner_dim,
dim_out,
bias=bias,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
),
]
else:
approximate = "tanh" if activation_fn == "gelu-approximate" else "none"
layers = [
ColumnParallelGELU(dim, inner_dim, approximate=approximate, bias=bias),
ColumnParallelGELU(dim, inner_dim, approximate=approximate, bias=bias, quant_config=quant_config),
nn.Identity(),
RowParallelLinear(
inner_dim,
dim_out,
bias=bias,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
),
]

@@ -752,6 +765,7 @@ def __init__(
attention_head_dim: int = 40,
time_embed_dim: int = 512,
ffn_hidden_dim: int | None = None,
quant_config: "QuantizationConfig | None" = None,
parallel_config: DiffusionParallelConfig | None = None,
) -> None:
super().__init__()
@@ -762,13 +776,20 @@
dim=dim,
num_heads=num_attention_heads,
head_dim=attention_head_dim,
quant_config=quant_config,
parallel_config=parallel_config,
)

# 2. Feedforward
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.ff = GlmImageFeedForward(dim=dim, dim_out=dim, inner_dim=ffn_hidden_dim, activation_fn="gelu-approximate")
self.ff = GlmImageFeedForward(
dim=dim,
dim_out=dim,
inner_dim=ffn_hidden_dim,
activation_fn="gelu-approximate",
quant_config=quant_config,
)

def forward(
self,
@@ -879,6 +900,7 @@ class GlmImageTransformer2DModel(CachedTransformer):
def __init__(
self,
od_config: OmniDiffusionConfig,
quant_config: "QuantizationConfig | None" = None,
):
super().__init__()

@@ -932,11 +954,19 @@ def __init__(
# 2. Patch & Text-timestep embedding
self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size)
self.glyph_projector = GlmImageFeedForward(
dim=text_embed_dim, dim_out=inner_dim, inner_dim=inner_dim, activation_fn="gelu"
dim=text_embed_dim,
dim_out=inner_dim,
inner_dim=inner_dim,
activation_fn="gelu",
quant_config=quant_config,
)
self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
self.prior_projector = GlmImageFeedForward(
dim=inner_dim, dim_out=inner_dim, inner_dim=inner_dim, activation_fn="linear-silu"
dim=inner_dim,
dim_out=inner_dim,
inner_dim=inner_dim,
activation_fn="linear-silu",
quant_config=quant_config,
)

# Prepare module for SP (encapsulates patch embedding and RoPE for _sp_plan)
@@ -958,6 +988,7 @@ def __init__(
attention_head_dim,
time_embed_dim,
ffn_hidden_dim=ffn_hidden_dim,
quant_config=quant_config,
parallel_config=self.parallel_config,
)
for _ in range(num_layers)
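The transformer changes above follow one pattern: the top-level model accepts a `quant_config` and threads it into every submodule that builds `ColumnParallelLinear`/`RowParallelLinear` layers, while `QuantizationConfig` itself is imported only under `TYPE_CHECKING`, so the annotations stay strings and vllm's quantization package is never a runtime import. A minimal, self-contained sketch of that typing pattern (not the vllm_omni code):

```python
# Minimal sketch of the TYPE_CHECKING pattern used in the diff above; the
# build_layer function is a made-up stand-in, not part of vllm_omni.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers; never imported at runtime.
    from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


def build_layer(quant_config: "QuantizationConfig | None" = None) -> str:
    # The quoted annotation is a forward reference, so the runtime-undefined
    # name QuantizationConfig is never looked up.
    return "quantized" if quant_config is not None else "unquantized"


print(build_layer())  # -> unquantized
```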
@@ -301,7 +301,10 @@ def __init__(

# Load transformer (DiT)
logger.info("Loading GlmImageTransformer2DModel (DiT)...")
self.transformer = GlmImageTransformer2DModel(od_config=od_config)
self.transformer = GlmImageTransformer2DModel(
od_config=od_config,
quant_config=od_config.quantization_config,
)

# Weight sources for DiT loading
self.weights_sources = [
3 changes: 3 additions & 0 deletions vllm_omni/engine/async_omni_engine.py
@@ -1401,7 +1401,10 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st
if lora_scale is not None:
if not hasattr(cfg.engine_args, "lora_scale") or cfg.engine_args.lora_scale is None:
cfg.engine_args.lora_scale = lora_scale
# Prefer explicit quantization_config; fallback to legacy --quantization.
quantization_config = kwargs.get("quantization_config")
if quantization_config is None:
Collaborator:

Do we need to add this logic?

@RuixiangMa (Contributor, Author) replied on Mar 30, 2026:
AsyncOmniEngine._resolve_stage_configs() follows the multi-stage config path, where top-level quantization kwargs are not automatically propagated to each diffusion stage's engine_args.

It is not strictly required if we decide to enforce quantization_config only, but without it the legacy --quantization flag can be silently ineffective in the diffusion stages.

quantization_config = kwargs.get("quantization")
if quantization_config is not None:
if (
not hasattr(cfg.engine_args, "quantization_config")
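As the reply above notes, the stage-config path has to resolve quantization on its own. Below is a standalone sketch of the resolution order added here; the diff is cut off after the hasattr check, so the "only set when the stage has no value yet" behaviour is an assumption modelled on the lora_scale handling a few lines earlier.

```python
# Standalone sketch of the fallback added in _resolve_stage_configs; StageEngineArgs
# is a made-up stand-in for a diffusion stage's engine_args.
from dataclasses import dataclass


@dataclass
class StageEngineArgs:
    quantization_config: str | None = None


def resolve_stage_quantization(kwargs: dict, engine_args: StageEngineArgs) -> None:
    # Prefer explicit quantization_config; fall back to legacy "quantization".
    quantization_config = kwargs.get("quantization_config")
    if quantization_config is None:
        quantization_config = kwargs.get("quantization")
    if quantization_config is not None:
        # Assumption: only fill in the stage value when it is still unset.
        if getattr(engine_args, "quantization_config", None) is None:
            engine_args.quantization_config = quantization_config


args = StageEngineArgs()
resolve_stage_quantization({"quantization": "fp8"}, args)  # legacy flag still reaches the stage
print(args.quantization_config)  # -> fp8
```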