diff --git a/docs/user_guide/diffusion/quantization/fp8.md b/docs/user_guide/diffusion/quantization/fp8.md
index 2c065546bf..95a62b05d7 100644
--- a/docs/user_guide/diffusion/quantization/fp8.md
+++ b/docs/user_guide/diffusion/quantization/fp8.md
@@ -62,6 +62,7 @@ The available `ignored_layers` names depend on the model architecture (e.g., `to
 |-------|-----------|---------------|------------------|
 | Z-Image | `Tongyi-MAI/Z-Image-Turbo` | All layers | None |
 | Qwen-Image | `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512` | Skip sensitive layers | `img_mlp` |
+| Flux | `black-forest-labs/FLUX.1-dev` | All layers | None |
 
 ## Combining with Other Features
 
diff --git a/vllm_omni/diffusion/models/flux/flux_transformer.py b/vllm_omni/diffusion/models/flux/flux_transformer.py
index faf6d08d3a..db6f0d34ec 100644
--- a/vllm_omni/diffusion/models/flux/flux_transformer.py
+++ b/vllm_omni/diffusion/models/flux/flux_transformer.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import torch
 import torch.nn.functional as F
@@ -17,9 +17,17 @@
 from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+
 from vllm_omni.diffusion.attention.layer import Attention
 from vllm_omni.diffusion.data import OmniDiffusionConfig
 from vllm_omni.diffusion.layers.rope import RotaryEmbedding
@@ -28,7 +36,15 @@
 
 
 class ColumnParallelApproxGELU(nn.Module):
-    def __init__(self, dim_in: int, dim_out: int, *, approximate: str, bias: bool = True):
+    def __init__(
+        self,
+        dim_in: int,
+        dim_out: int,
+        *,
+        approximate: str,
+        bias: bool = True,
+        quant_config: "QuantizationConfig | None" = None,
+    ):
         super().__init__()
         self.proj = ColumnParallelLinear(
             dim_in,
@@ -36,6 +52,7 @@ def __init__(self, dim_in: int, dim_out: int, *, approximate: str, bias: bool =
             bias=bias,
             gather_output=False,
             return_bias=False,
+            quant_config=quant_config,
         )
         self.approximate = approximate
 
@@ -53,6 +70,7 @@ def __init__(
         activation_fn: str = "gelu-approximate",
         inner_dim: int | None = None,
         bias: bool = True,
+        quant_config: "QuantizationConfig | None" = None,
     ) -> None:
         super().__init__()
 
@@ -62,13 +80,14 @@ def __init__(
         dim_out = dim_out or dim
 
         layers: list[nn.Module] = [
-            ColumnParallelApproxGELU(dim, inner_dim, approximate="tanh", bias=bias),
+            ColumnParallelApproxGELU(dim, inner_dim, approximate="tanh", bias=bias, quant_config=quant_config),
             nn.Identity(),  # placeholder for weight loading
             RowParallelLinear(
                 inner_dim,
                 dim_out,
                 input_is_parallel=True,
                 return_bias=False,
+                quant_config=quant_config,
             ),
         ]
 
@@ -95,6 +114,7 @@ def __init__(
         out_dim: int = None,
         context_pre_only: bool | None = None,
         pre_only: bool = False,
+        quant_config: "QuantizationConfig | None" = None,
     ):
         super().__init__()
 
@@ -118,6 +138,7 @@ def __init__(
             head_size=self.head_dim,
             total_num_heads=self.heads,
             bias=bias,
+            quant_config=quant_config,
         )
 
         if not self.pre_only:
@@ -129,6 +150,7 @@ def __init__(
                     bias=out_bias,
                     input_is_parallel=True,
                     return_bias=False,
+                    quant_config=quant_config,
                 ),
                 nn.Dropout(dropout),
             ]
@@ -143,6 +165,7 @@ def __init__(
                 head_size=self.head_dim,
                 total_num_heads=self.heads,
                 bias=added_proj_bias,
+                quant_config=quant_config,
             )
 
             self.to_add_out = RowParallelLinear(
@@ -151,6 +174,7 @@ def __init__(
                 bias=out_bias,
                 input_is_parallel=True,
                 return_bias=False,
+                quant_config=quant_config,
             )
 
         self.rope = RotaryEmbedding(is_neox_style=False)
@@ -233,7 +257,13 @@ def forward(
 
 
 class FluxTransformerBlock(nn.Module):
     def __init__(
-        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        qk_norm: str = "rms_norm",
+        eps: float = 1e-6,
+        quant_config: "QuantizationConfig | None" = None,
     ):
         super().__init__()
 
@@ -249,13 +279,14 @@ def __init__(
             context_pre_only=False,
             bias=True,
             eps=eps,
+            quant_config=quant_config,
         )
 
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
-        self.ff = FeedForward(dim=dim, dim_out=dim)
+        self.ff = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config)
 
         self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
-        self.ff_context = FeedForward(dim=dim, dim_out=dim)
+        self.ff_context = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config)
 
     def forward(
         self,
@@ -315,14 +346,25 @@ def forward(
 
 
 class FluxSingleTransformerBlock(nn.Module):
-    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float = 4.0,
+        quant_config: "QuantizationConfig | None" = None,
+    ):
         super().__init__()
 
         self.mlp_hidden_dim = int(dim * mlp_ratio)
         self.norm = AdaLayerNormZeroSingle(dim)
-        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+        self.proj_mlp = ReplicatedLinear(
+            dim, self.mlp_hidden_dim, bias=True, return_bias=False, quant_config=quant_config
+        )
         self.act_mlp = nn.GELU(approximate="tanh")
-        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+        self.proj_out = ReplicatedLinear(
+            dim + self.mlp_hidden_dim, dim, bias=True, return_bias=False, quant_config=quant_config
+        )
 
         self.attn = FluxAttention(
             query_dim=dim,
@@ -332,6 +374,7 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
             bias=True,
             eps=1e-6,
             pre_only=True,
+            quant_config=quant_config,
         )
 
     def forward(
@@ -432,6 +475,10 @@ class FluxTransformer2DModel(nn.Module):
     # -- typically a transformer layer
     # used for torch compile optimizations
     _repeated_blocks = ["FluxTransformerBlock"]
+    packed_modules_mapping = {
+        "to_qkv": ["to_q", "to_k", "to_v"],
+        "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
+    }
 
     def __init__(
         self,
@@ -447,6 +494,7 @@
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = True,
         axes_dims_rope: tuple[int, int, int] = (16, 56, 56),
+        quant_config: "QuantizationConfig | None" = None,
     ):
         super().__init__()
         model_config = od_config.tf_model_config
@@ -474,6 +522,7 @@ def __init__(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
+                    quant_config=quant_config,
                 )
                 for _ in range(num_layers)
             ]
@@ -485,6 +534,7 @@ def __init__(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
+                    quant_config=quant_config,
                 )
                 for _ in range(num_single_layers)
             ]
diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux.py b/vllm_omni/diffusion/models/flux/pipeline_flux.py
index b67c8eb874..d4793d34ec 100644
--- a/vllm_omni/diffusion/models/flux/pipeline_flux.py
+++ b/vllm_omni/diffusion/models/flux/pipeline_flux.py
@@ -27,6 +27,7 @@
 from vllm_omni.diffusion.distributed.utils import get_local_device
 from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
 from vllm_omni.diffusion.models.flux import FluxTransformer2DModel
+from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers
 from vllm_omni.diffusion.request import OmniDiffusionRequest
 from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
 
@@ -166,7 +167,8 @@ def __init__(
         self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
             self.device
         )
-        self.transformer = FluxTransformer2DModel(od_config=od_config)
+        quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config)
+        self.transformer = FluxTransformer2DModel(od_config=od_config, quant_config=quant_config)
         self.tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
         self.tokenizer_2 = T5TokenizerFast.from_pretrained(
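
Note for reviewers: the patch follows one pattern throughout. A single optional `quant_config` is resolved once in the pipeline (via `get_vllm_quant_config_for_layers`) and threaded down through every block constructor, so each parallel linear layer decides at construction time whether to build quantized weights; `packed_modules_mapping` additionally tells the quantized weight loader that the checkpoint's separate `to_q`/`to_k`/`to_v` (and `add_*_proj`) shards feed one fused projection. The sketch below is a minimal, self-contained illustration of that plumbing in plain PyTorch; `ToyQuantConfig`, `ToyBlock`, and `load_fused_qkv` are hypothetical names invented for illustration, not vLLM or vllm_omni APIs.

# Toy illustration only -- hypothetical names, not the vllm/vllm_omni API.
import torch
import torch.nn as nn


class ToyQuantConfig:
    """Stand-in for vLLM's QuantizationConfig: decides how linears are built."""

    def build_linear(self, in_features: int, out_features: int) -> nn.Module:
        # A real config would return an FP8-quantized linear; here we just
        # tag a plain layer so the effect of threading the config is visible.
        layer = nn.Linear(in_features, out_features)
        layer.is_quantized = True
        return layer


class ToyBlock(nn.Module):
    """Mirrors the diff's pattern: quant_config reaches every linear."""

    def __init__(self, dim: int, quant_config: "ToyQuantConfig | None" = None):
        super().__init__()
        build = quant_config.build_linear if quant_config else nn.Linear
        # The model owns one fused QKV projection...
        self.to_qkv = build(dim, 3 * dim)

    def load_fused_qkv(self, shards: dict[str, torch.Tensor]) -> None:
        # ...while the checkpoint stores q/k/v separately; the loader
        # concatenates them, which is what packed_modules_mapping declares.
        fused = torch.cat([shards["to_q"], shards["to_k"], shards["to_v"]], dim=0)
        self.to_qkv.weight.data.copy_(fused)


dim = 8
block = ToyBlock(dim, quant_config=ToyQuantConfig())
shards = {name: torch.randn(dim, dim) for name in ("to_q", "to_k", "to_v")}
block.load_fused_qkv(shards)
print(getattr(block.to_qkv, "is_quantized", False))  # -> True

The design choice mirrored here is that quantization stays a constructor-time concern: forward passes are untouched, which is why the diff only edits `__init__` signatures and call sites.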