1 change: 1 addition & 0 deletions docs/user_guide/diffusion/quantization/fp8.md
@@ -62,6 +62,7 @@ The available `ignored_layers` names depend on the model architecture (e.g., `to
|-------|-----------|---------------|------------------|
| Z-Image | `Tongyi-MAI/Z-Image-Turbo` | All layers | None |
| Qwen-Image | `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512` | Skip sensitive layers | `img_mlp` |
| Flux | `black-forest-labs/FLUX.1-dev` | All layers | None |

## Combining with Other Features

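The new Flux row means every transformer layer is quantized, with no `ignored_layers`. That matches the wiring added at the bottom of this PR: the pipeline derives a vLLM quant config from the engine-level quantization config and hands it to the whole transformer. A minimal sketch of that wiring, assuming an already-built `OmniDiffusionConfig` named `od_config` (its construction is out of scope here); the imports and calls are taken from the diffs below:

```python
from vllm_omni.diffusion.models.flux import FluxTransformer2DModel
from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers

# od_config is assumed to exist already; when od_config.quantization_config
# describes FP8, the resulting quant_config reaches every linear layer in the
# Flux transformer (no layers are skipped for this architecture).
quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config)
transformer = FluxTransformer2DModel(od_config=od_config, quant_config=quant_config)
```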
70 changes: 60 additions & 10 deletions vllm_omni/diffusion/models/flux/flux_transformer.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Any
from typing import TYPE_CHECKING, Any

import torch
import torch.nn.functional as F
@@ -17,9 +17,17 @@
from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.layers.rope import RotaryEmbedding
@@ -28,14 +36,23 @@


class ColumnParallelApproxGELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, *, approximate: str, bias: bool = True):
def __init__(
self,
dim_in: int,
dim_out: int,
*,
approximate: str,
bias: bool = True,
quant_config: "QuantizationConfig | None" = None,
):
super().__init__()
self.proj = ColumnParallelLinear(
dim_in,
dim_out,
bias=bias,
gather_output=False,
return_bias=False,
quant_config=quant_config,
)
self.approximate = approximate

@@ -53,6 +70,7 @@ def __init__(
activation_fn: str = "gelu-approximate",
inner_dim: int | None = None,
bias: bool = True,
quant_config: "QuantizationConfig | None" = None,
) -> None:
super().__init__()

@@ -62,13 +80,14 @@ def __init__(
dim_out = dim_out or dim

layers: list[nn.Module] = [
ColumnParallelApproxGELU(dim, inner_dim, approximate="tanh", bias=bias),
ColumnParallelApproxGELU(dim, inner_dim, approximate="tanh", bias=bias, quant_config=quant_config),
nn.Identity(), # placeholder for weight loading
RowParallelLinear(
inner_dim,
dim_out,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
),
]

@@ -95,6 +114,7 @@ def __init__(
out_dim: int = None,
context_pre_only: bool | None = None,
pre_only: bool = False,
quant_config: "QuantizationConfig | None" = None,
):
super().__init__()

@@ -118,6 +138,7 @@ def __init__(
head_size=self.head_dim,
total_num_heads=self.heads,
bias=bias,
quant_config=quant_config,
)

if not self.pre_only:
@@ -129,6 +150,7 @@ def __init__(
bias=out_bias,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
),
nn.Dropout(dropout),
]
@@ -143,6 +165,7 @@ def __init__(
head_size=self.head_dim,
total_num_heads=self.heads,
bias=added_proj_bias,
quant_config=quant_config,
)

self.to_add_out = RowParallelLinear(
@@ -151,6 +174,7 @@ def __init__(
bias=out_bias,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
)

self.rope = RotaryEmbedding(is_neox_style=False)
@@ -233,7 +257,13 @@ def forward(

class FluxTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
qk_norm: str = "rms_norm",
eps: float = 1e-6,
quant_config: "QuantizationConfig | None" = None,
):
super().__init__()

@@ -249,13 +279,14 @@ def __init__(
context_pre_only=False,
bias=True,
eps=eps,
quant_config=quant_config,
)

self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim)
self.ff = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config)

self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim)
self.ff_context = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config)

def forward(
self,
@@ -315,14 +346,25 @@ def forward(


class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float = 4.0,
quant_config: "QuantizationConfig | None" = None,
):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)

self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.proj_mlp = ReplicatedLinear(
dim, self.mlp_hidden_dim, bias=True, return_bias=False, quant_config=quant_config
)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
self.proj_out = ReplicatedLinear(
dim + self.mlp_hidden_dim, dim, bias=True, return_bias=False, quant_config=quant_config
)

self.attn = FluxAttention(
query_dim=dim,
@@ -332,6 +374,7 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
bias=True,
eps=1e-6,
pre_only=True,
quant_config=quant_config,
)

def forward(
@@ -432,6 +475,10 @@ class FluxTransformer2DModel(nn.Module):
# -- typically a transformer layer
# used for torch compile optimizations
_repeated_blocks = ["FluxTransformerBlock"]
packed_modules_mapping = {
"to_qkv": ["to_q", "to_k", "to_v"],
"add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
}

def __init__(
self,
@@ -447,6 +494,7 @@ def __init__(
pooled_projection_dim: int = 768,
guidance_embeds: bool = True,
axes_dims_rope: tuple[int, int, int] = (16, 56, 56),
quant_config: "QuantizationConfig | None" = None,
):
super().__init__()
model_config = od_config.tf_model_config
@@ -474,6 +522,7 @@ def __init__(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
quant_config=quant_config,
)
for _ in range(num_layers)
]
@@ -485,6 +534,7 @@
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
quant_config=quant_config,
)
for _ in range(num_single_layers)
]
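The `packed_modules_mapping` added to `FluxTransformer2DModel` above tells the quantized weight loader that each fused attention projection (`to_qkv`, `add_kv_proj`) is stored in the checkpoint as separate per-projection tensors. A hypothetical sketch of the key remapping this mapping implies; it is illustrative only, not the actual vLLM loader code:

```python
# Translate a per-projection checkpoint key into the fused parameter name plus
# a shard index, mirroring what packed_modules_mapping conveys.
PACKED_MODULES_MAPPING = {
    "to_qkv": ["to_q", "to_k", "to_v"],
    "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
}


def remap_checkpoint_key(name: str) -> tuple[str, int] | None:
    for fused, shards in PACKED_MODULES_MAPPING.items():
        for idx, shard in enumerate(shards):
            token = f".{shard}."
            if token in name:
                return name.replace(token, f".{fused}."), idx
    return None  # not a packed module; the key loads unchanged


# Prints ('transformer_blocks.0.attn.to_qkv.weight', 1), i.e. the K shard of
# the fused QKV projection.
print(remap_checkpoint_key("transformer_blocks.0.attn.to_k.weight"))
```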
4 changes: 3 additions & 1 deletion vllm_omni/diffusion/models/flux/pipeline_flux.py
@@ -27,6 +27,7 @@
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.flux import FluxTransformer2DModel
from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific

@@ -166,7 +167,8 @@ def __init__(
self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
self.device
)
self.transformer = FluxTransformer2DModel(od_config=od_config)
quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config)
self.transformer = FluxTransformer2DModel(od_config=od_config, quant_config=quant_config)

self.tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
self.tokenizer_2 = T5TokenizerFast.from_pretrained(
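One design note on `flux_transformer.py`: `proj_mlp` and `proj_out` in `FluxSingleTransformerBlock` switch from `nn.Linear` to `ReplicatedLinear`. The layers stay unsharded, but as vLLM linear layers they can now take a `quant_config`, so FP8 covers them as well. A rough sanity-check sketch, assuming a vLLM build whose `ReplicatedLinear` accepts `return_bias` (as the diff uses) and can be constructed outside a distributed context; with `quant_config=None` the swap should be a numerical drop-in for `nn.Linear`:

```python
import torch
from torch import nn

from vllm.model_executor.layers.linear import ReplicatedLinear

torch.manual_seed(0)
d_in, d_out = 32, 64
ref = nn.Linear(d_in, d_out, bias=True)
rep = ReplicatedLinear(d_in, d_out, bias=True, return_bias=False, quant_config=None)

# Copy the reference weights into the vLLM layer and compare outputs.
with torch.no_grad():
    rep.weight.copy_(ref.weight)
    rep.bias.copy_(ref.bias)

x = torch.randn(2, d_in)
assert torch.allclose(ref(x), rep(x), atol=1e-6)
```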