Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
432 changes: 432 additions & 0 deletions examples/quantization/quantize_wan2_2_modelopt_fp8.py

Large diffs are not rendered by default.

34 changes: 34 additions & 0 deletions tests/diffusion/model_loader/test_diffusers_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import torch
import torch.nn as nn
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config

from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.model_loader.gguf_adapters import get_gguf_adapter
Expand Down Expand Up @@ -93,3 +94,36 @@ def test_qwen_model_class_selects_qwen_gguf_adapter():
adapter = get_gguf_adapter("dummy.gguf", object(), source, od_config)

assert adapter.__class__.__name__ == "QwenImageGGUFAdapter"


def test_loader_auto_detects_quant_config_from_transformer_config():
    """The loader's auto-detection must adopt the transformer config's quant config."""
    # Quantization config attached to the stand-in transformer config.
    detected_quant_config = ModelOptFp8Config.from_config(
        {
            "quant_method": "modelopt",
            "quant_algo": "FP8",
            "ignore": [],
        }
    )

    def _set_tf_model_config(self, tf_model_config):
        # Stand-in setter: copy the transformer's quant config onto the config object.
        setattr(self, "quantization_config", tf_model_config.quant_config)

    transformer_config_cls = type(
        "TransformerConfig",
        (),
        {
            "quant_config": detected_quant_config,
            "quant_method": "modelopt",
        },
    )
    config_cls = type(
        "Config",
        (),
        {
            # Nothing is configured explicitly, so detection has to fill this in.
            "quantization_config": None,
            "tf_model_config": transformer_config_cls(),
            "set_tf_model_config": _set_tf_model_config,
        },
    )
    od_config = config_cls()

    DiffusersPipelineLoader._auto_detect_quant_config(od_config)

    assert od_config.quantization_config is od_config.tf_model_config.quant_config
94 changes: 94 additions & 0 deletions tests/diffusion/model_loader/test_modelopt_fp8_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace

import pytest
import torch
import torch.nn as nn

from vllm_omni.diffusion.model_loader.checkpoint_adapters import (
ModelOptFp8CheckpointAdapter,
)

pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]


class _PackedModelOptModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.transformer = nn.Module()
self.transformer.block = nn.Module()
self.transformer.block.to_qkv = nn.Linear(2, 2, bias=False)


class _QuantizedPackedModelOptModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.transformer = nn.Module()
self.transformer.block = nn.Module()
self.transformer.block.to_qkv = nn.Module()
self.transformer.block.to_qkv.register_parameter(
"weight",
nn.Parameter(torch.empty(2, 2, dtype=torch.float8_e4m3fn), requires_grad=False),
)
self.transformer.block.to_qkv.register_parameter(
"weight_scale",
nn.Parameter(torch.empty(1), requires_grad=False),
)
self.transformer.block.to_qkv.register_parameter(
"input_scale",
nn.Parameter(torch.empty(1), requires_grad=False),
)


def _make_source() -> SimpleNamespace:
return SimpleNamespace(
subfolder="transformer",
prefix="transformer.",
)


def test_modelopt_adapter_dequantizes_fp8_weight_for_full_precision_target():
    """An FP8 checkpoint weight is emitted dequantized when the target layer is full precision."""
    model = _PackedModelOptModel()
    adapter = ModelOptFp8CheckpointAdapter(model, _make_source())
    fp8_weight = torch.tensor([[2.0, -4.0], [1.0, 3.0]], dtype=torch.float32).to(torch.float8_e4m3fn)
    scale = torch.tensor([0.5], dtype=torch.float32)

    # The scale tensors are fed before the weight they belong to.
    checkpoint_entries = [
        ("transformer.block.to_q.weight_scale", scale),
        ("transformer.block.to_q.input_scale", torch.tensor([1.0])),
        ("transformer.block.to_q.weight", fp8_weight),
    ]
    adapted = list(adapter.adapt(iter(checkpoint_entries)))

    # Only the weight is expected to come out; the scale entries are not emitted.
    emitted_names = [name for name, _ in adapted]
    assert emitted_names == ["transformer.block.to_q.weight"]

    emitted_weight = adapted[0][1]
    assert emitted_weight.dtype == model.transformer.block.to_qkv.weight.dtype
    expected_weight = fp8_weight.to(torch.float32) * scale
    assert torch.allclose(emitted_weight, expected_weight)


def test_modelopt_adapter_keeps_scale_tensors_for_quantized_target():
    """Scale tensors are passed through when the target layer is itself quantized."""
    model = _QuantizedPackedModelOptModel()
    adapter = ModelOptFp8CheckpointAdapter(model, _make_source())
    scale = torch.tensor([0.5], dtype=torch.float32)

    checkpoint_entries = [
        ("transformer.block.to_q.weight_scale", scale),
        ("transformer.block.to_q.input_scale", torch.tensor([1.0])),
    ]
    adapted = list(adapter.adapt(iter(checkpoint_entries)))

    # Both scale entries must come out, in the order they were fed in.
    emitted_names = [name for name, _ in adapted]
    expected_names = [
        "transformer.block.to_q.weight_scale",
        "transformer.block.to_q.input_scale",
    ]
    assert emitted_names == expected_names
41 changes: 40 additions & 1 deletion tests/diffusion/quantization/test_fp8_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,25 @@ def test_build_quant_config_dict_not_mutated():
assert original == copy


def test_build_quant_config_modelopt_fp8_config_json():
    """A ModelOpt FP8 ``quant_method`` dict must build a ``ModelOptFp8Config``."""
    from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config

    from vllm_omni.quantization import build_quant_config

    # Checkpoint-style dict keyed by "quant_method", including a "producer" entry.
    checkpoint_quant_dict = {
        "quant_method": "modelopt",
        "quant_algo": "FP8",
        "ignore": ["proj_out"],
        "producer": {"name": "modelopt"},
    }
    config = build_quant_config(checkpoint_quant_dict)

    assert isinstance(config, ModelOptFp8Config)
    assert config.get_name() == "modelopt"
    assert config.is_checkpoint_fp8_serialized


def test_build_quant_config_per_component():
from vllm_omni.quantization import ComponentQuantizationConfig, build_quant_config

Expand Down Expand Up @@ -91,7 +110,7 @@ def test_flat_dict_not_misdetected_as_per_component():
as a per-component dict — it should raise ValueError for missing 'method'."""
from vllm_omni.quantization import build_quant_config

with pytest.raises(ValueError, match="must have a 'method' key"):
with pytest.raises(ValueError, match="must have a 'method' or 'quant_method' key"):
build_quant_config({"activation_scheme": "static"})


Expand Down Expand Up @@ -194,6 +213,26 @@ def test_integration_per_component():
assert config.quantization_config.component_configs["vae"] is None


def test_transformer_config_auto_detects_modelopt_fp8():
    """``TransformerConfig.from_dict`` must pick up an embedded ModelOpt FP8 quantization config."""
    from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config

    from vllm_omni.diffusion.data import TransformerConfig

    # Transformer config dict carrying an embedded "quantization_config" section.
    transformer_config_dict = {
        "_class_name": "FluxTransformer2DModel",
        "quantization_config": {
            "quant_method": "modelopt",
            "quant_algo": "FP8",
            "ignore": ["proj_out"],
        },
    }
    config = TransformerConfig.from_dict(transformer_config_dict)

    assert isinstance(config.quant_config, ModelOptFp8Config)
    assert config.quant_method == "modelopt"


def test_supported_methods_includes_vllm():
from vllm_omni.quantization import SUPPORTED_QUANTIZATION_METHODS

Expand Down
30 changes: 29 additions & 1 deletion tests/diffusion/test_diffusion_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch
from pytest_mock import MockerFixture

from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker
from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker, _make_diffusion_vllm_model_config

pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]

Expand Down Expand Up @@ -78,6 +78,34 @@ def test_load_weights_empty_iterable(self, mocker: MockerFixture, mock_gpu_worke
assert result == set()


def test_diffusion_vllm_model_config_supplies_dtype_for_quant_methods():
    """The derived vLLM model config must expose the dtype and quantization of the diffusion config."""
    from types import SimpleNamespace

    from vllm_omni.quantization import build_quant_config

    quant_config = build_quant_config(
        {
            "quant_method": "modelopt",
            "quant_algo": "FP8",
            "ignore": [],
        }
    )
    # Lightweight stand-in for the diffusion config, carrying only these fields.
    od_config = SimpleNamespace(
        model="dummy",
        dtype=torch.bfloat16,
        quantization_config=quant_config,
        tf_model_config=SimpleNamespace(),
        enforce_eager=True,
        is_moe=False,
    )

    model_config = _make_diffusion_vllm_model_config(od_config)

    assert model_config.dtype is torch.bfloat16
    assert model_config.quantization == "modelopt"
    assert model_config.quantization_config is od_config.quantization_config
    assert model_config.is_quantized()


class TestDiffusionWorkerSleep:
"""Test DiffusionWorker.sleep method."""

Expand Down
52 changes: 33 additions & 19 deletions vllm_omni/diffusion/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,11 @@ def from_dict(cls, data: dict[str, Any]) -> "TransformerConfig":
quant_method: str | None = None
quant_config: QuantizationConfig | None = None
disk_qc = params.get("quantization_config")
if isinstance(disk_qc, dict) and "quant_method" in disk_qc:
quant_method = disk_qc["quant_method"]
kwargs = {k: v for k, v in disk_qc.items() if k != "quant_method"}
quant_config = build_quant_config(quant_method, **kwargs)
if isinstance(disk_qc, dict):
raw_quant_method = disk_qc.get("quant_method", disk_qc.get("method"))
quant_config = build_quant_config(disk_qc)
if quant_config is not None:
quant_method = raw_quant_method if raw_quant_method is not None else quant_config.get_name()

return cls(params=params, quant_method=quant_method, quant_config=quant_config)

Expand Down Expand Up @@ -616,14 +617,9 @@ def __post_init__(self):

# Auto-detect quantization from TransformerConfig if not explicitly set.
# This covers the case where tf_model_config is passed at construction
# time. For late (post-construction) assignment, callers should use
# time. For late (post-construction) assignment, callers should use
# set_tf_model_config() which propagates quant_config automatically.
if self.quantization_config is None and self.tf_model_config.quant_config is not None:
self.quantization_config = self.tf_model_config.quant_config
logger.info(
"Auto-detected quantization '%s' from model config",
self.tf_model_config.quant_method,
)
self._propagate_quantization_from_tf_config(self.tf_model_config)

# Resolve quantization_config: str/dict -> QuantizationConfig via build_quant_config.
if self.quantization_config is not None:
Expand All @@ -644,6 +640,29 @@ def __post_init__(self):
elif self.max_cpu_loras < 1:
raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA")

def _propagate_quantization_from_tf_config(self, tf_config: "TransformerConfig") -> None:
if tf_config.quant_config is None:
return

is_checkpoint_fp8 = bool(getattr(tf_config.quant_config, "is_checkpoint_fp8_serialized", False))
should_use_checkpoint_config = self.quantization_config is None or (
is_checkpoint_fp8 and self._is_generic_fp8_quant_config(self.quantization_config)
)
if should_use_checkpoint_config:
self.quantization_config = tf_config.quant_config
logger.info(
"Auto-detected quantization '%s' from model config",
tf_config.quant_method,
)

@staticmethod
def _is_generic_fp8_quant_config(quant_config: object) -> bool:
if isinstance(quant_config, str):
return quant_config.lower() == "fp8"
if hasattr(quant_config, "get_name"):
return quant_config.get_name() == "fp8"
return False

def set_tf_model_config(self, tf_config: "TransformerConfig") -> None:
"""Assign `tf_model_config` and propagate quantization if detected.

Expand All @@ -659,12 +678,7 @@ def set_tf_model_config(self, tf_config: "TransformerConfig") -> None:
`TransformerConfig.from_dict`.
"""
self.tf_model_config = tf_config
if self.quantization_config is None and tf_config.quant_config is not None:
self.quantization_config = tf_config.quant_config
logger.info(
"Auto-detected quantization '%s' from model config",
tf_config.quant_method,
)
self._propagate_quantization_from_tf_config(tf_config)

def update_multimodal_support(self) -> None:
# Resolve serving-visible multimodal behavior from shared metadata
Expand All @@ -690,15 +704,15 @@ def enrich_config(self) -> None:
self.update_multimodal_support()

tf_config_dict = get_hf_file_to_dict("transformer/config.json", self.model)
self.tf_model_config = TransformerConfig.from_dict(tf_config_dict)
self.set_tf_model_config(TransformerConfig.from_dict(tf_config_dict))
else:
raise FileNotFoundError("model_index.json not found")
except (AttributeError, OSError, ValueError, FileNotFoundError):
cfg = get_hf_file_to_dict("config.json", self.model)
if cfg is None:
raise ValueError(f"Could not find config.json or model_index.json for model {self.model}")

self.tf_model_config = TransformerConfig.from_dict(cfg)
self.set_tf_model_config(TransformerConfig.from_dict(cfg))
model_type = cfg.get("model_type")
architectures = cfg.get("architectures") or []

Expand Down
20 changes: 18 additions & 2 deletions vllm_omni/diffusion/layers/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@ def apply_rotary_emb_mindiesd(
return rotary_position_embedding(x, cos, sin, rotated_mode="rotated_half", head_first=False, fused=True)


def _ensure_batch_dim(x: torch.Tensor) -> tuple[torch.Tensor, bool]:
if x.dim() == 3:
return x.unsqueeze(0), True
return x, False


def _restore_batch_dim(x: torch.Tensor, squeezed: bool) -> torch.Tensor:
if squeezed:
return x.squeeze(0)
return x


class RotaryEmbedding(CustomOp):
"""
rotary positional embedding.
Expand Down Expand Up @@ -98,12 +110,14 @@ def forward_cuda(
cos = cos[0]
sin = sin[0]

return apply_rotary_emb(
x, squeezed = _ensure_batch_dim(x)
output = apply_rotary_emb(
x,
cos,
sin,
interleaved=self.interleaved,
)
return _restore_batch_dim(output, squeezed)

def forward_hip(
self,
Expand All @@ -119,12 +133,14 @@ def forward_hip(
cos = cos[0]
sin = sin[0]

return self.apply_rotary_emb_flash_attn(
x, squeezed = _ensure_batch_dim(x)
output = self.apply_rotary_emb_flash_attn(
x,
cos,
sin,
interleaved=self.interleaved,
)
return _restore_batch_dim(output, squeezed)

def forward_npu(
self,
Expand Down
Loading
Loading