[Quantization] Enable FP8 weight storage for Qwen Image VAE and text encoder #1414

Closed
lishunyang12 wants to merge 3 commits into vllm-project:main from lishunyang12:fp8-vae-qwen-image

@lishunyang12
Collaborator

@lishunyang12 lishunyang12 commented Feb 20, 2026

Summary

  • Add FP8 weight-only storage for Linear/Conv2d/Conv3d layers in the Qwen Image VAE and text encoder (Qwen2_5_VLForConditionalGeneration)
  • Weights are stored in float8_e4m3fn with per-tensor scales and dequantized to BF16 before each forward pass — saving ~50% memory for these components with no accuracy loss
  • Applied to all three Qwen Image pipelines: generation, edit, and edit-plus
  • Fixed load_weights() to mark VAE and text_encoder parameters as loaded (they use from_pretrained(), not the weight pipeline)
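
As a rough check on the ~50% figure: BF16 takes 2 bytes per weight and float8_e4m3fn takes 1, while a per-tensor float32 scale adds 4 bytes per layer. The sketch below is plain arithmetic; the layer shape is made up for illustration and is not taken from the Qwen Image checkpoints.

```python
def storage_bytes(num_weights: int, bytes_per_weight: int, scale_bytes: int = 0) -> int:
    """Bytes needed to store one layer's weight tensor plus an optional scale."""
    return num_weights * bytes_per_weight + scale_bytes


# Hypothetical Linear layer with a 4096 x 4096 weight (illustrative shape only).
num_weights = 4096 * 4096

bf16_bytes = storage_bytes(num_weights, bytes_per_weight=2)
fp8_bytes = storage_bytes(num_weights, bytes_per_weight=1, scale_bytes=4)  # float32 scale

saving = 1 - fp8_bytes / bf16_bytes
print(f"BF16: {bf16_bytes} B, FP8 + scale: {fp8_bytes} B, saving: {saving:.4%}")
```

Biases and layers outside Linear/Conv2d/Conv3d are not covered by the utility, so the saving for a whole component is likely to land somewhat below 50%.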

Approach

New utility apply_fp8_weight_storage() in vllm_omni/diffusion/models/utils.py:

  • Walks all nn.Linear, nn.Conv2d, and nn.Conv3d modules (including subclasses like QwenImageCausalConv3d)
  • Computes per-tensor scale, converts weight to FP8, stores FP8 weight + scale as buffers
  • Registers forward pre-hook (dequantize to BF16) and post-hook (re-quantize to FP8)
  • Peak memory per layer: FP8 + BF16 weight only during that layer's forward; all other layers stay in FP8

Together with PR #1338, this covers all three pipeline components when --quantization fp8 is passed:

| Component | Quantization method |
| --- | --- |
| DiT (transformer) | vLLM FP8 linear layers (PR #1338) |
| Text encoder | FP8 weight storage (this PR) |
| VAE | FP8 weight storage (this PR) |

Test plan

  • Run Qwen Image text-to-image with --quantization fp8 and verify output quality matches BF16 baseline
  • Verify memory reduction for VAE and text encoder vs without --quantization fp8
  • Run pre-commit checks (ruff lint + format)

Related

Follow-up to #1338 (FP8 online quantization for DiT transformer)

…encoder

Add FP8 weight-only quantization for Linear/Conv2d/Conv3d layers in the
Qwen Image VAE and text encoder. Weights are stored in float8_e4m3fn and
dequantized to BF16 before each forward pass, saving ~50% memory for
these components with no accuracy loss.

Signed-off-by: lishunyang <lishunyang12@163.com>
@lishunyang12 changed the title from "[Quantization] Enable FP8 weight storage for Qwen Image VAE" to "[Quantization] Enable FP8 weight storage for Qwen Image VAE and text encoder" on Feb 20, 2026

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 9e048e151f

Comment thread vllm_omni/diffusion/models/utils.py Outdated
module._fp8_compute_dtype = compute_dtype

# Replace the parameter data with FP8 to save memory
module.weight.data = fp8_weight

P1: Preserve VAE runtime dtype when storing FP8 weights

Assigning module.weight.data = fp8_weight permanently changes each conv parameter to FP8, and the Qwen Image pipelines later call latents = latents.to(self.vae.dtype) before decode. Because model dtype is parameter-derived, enabling quantization can push VAE inputs to FP8 and break decode paths that rely on BF16/FP16 convolution execution. Keep exposed parameter dtype at the original compute dtype and store FP8 only in auxiliary buffers so self.vae.dtype remains valid.

Collaborator Author

Fixed in the latest push. module.weight.data now stays at compute dtype (BF16/FP16). The FP8 representation is stored in the _fp8_weight buffer and only dequantized into module.weight.data inside the pre-hook. The post-hook writes back mod._fp8_weight.to(mod._fp8_compute_dtype) instead of raw FP8, so model.dtype stays correct at all times.

Also added an idempotency guard (hasattr(module, "_fp8_weight")) to prevent double-application.

Comment thread vllm_omni/diffusion/models/utils.py Outdated

def _pre_hook(mod, args):
# Dequantize: restore BF16/FP16 weight for conv computation
mod.weight.data = mod._fp8_weight.to(mod._fp8_compute_dtype) * mod._fp8_scale

P2: Keep dequantized weights in original compute dtype

The pre-hook dequantization multiplies mod._fp8_weight.to(mod._fp8_compute_dtype) by _fp8_scale stored as float32, which promotes the result to float32 instead of restoring the original BF16/FP16 dtype. Under FP8 mode this increases memory/compute cost and can cause dtype mismatches with BF16 activations or biases during VAE decode. Cast the scale (or final product) back to mod._fp8_compute_dtype before assigning mod.weight.data.

Collaborator Author

Fixed. The dequantized result is now explicitly cast back to compute_dtype after multiplication with the float32 scale:

# Before
mod.weight.data = mod._fp8_weight.to(mod._fp8_compute_dtype) * mod._fp8_scale

# After
mod.weight.data = (
    mod._fp8_weight.to(mod._fp8_compute_dtype) * mod._fp8_scale
).to(mod._fp8_compute_dtype)

@hsliuustc0106
Collaborator

for the weights, i remember there are two versions for both qwen-image and qwen-image-edit.

@hsliuustc0106
Collaborator

@vllm-omni-reviewer

@github-actions

🤖 VLLM-Omni PR Review

Code Review: FP8 Weight Storage for Qwen Image VAE and Text Encoder

1. Overview

This PR implements FP8 weight-only storage for Linear/Conv2d/Conv3d layers in Qwen Image VAE and text encoder components. The approach stores weights in float8_e4m3fn format with per-tensor scales, dequantizing to BF16 during forward passes and re-quantizing afterward. This targets ~50% memory savings for these components.

Overall Assessment: Changes Requested - While the approach is sound and the implementation is mostly correct, there are several issues that should be addressed before merging, including a potential memory leak during exceptions and missing validation.


2. Code Quality

Critical Issues

a) Exception Safety - Memory Leak Risk

vllm_omni/diffusion/models/utils.py:58-64

If an exception occurs during the forward pass, the post_hook never executes, leaving weights in BF16 format. This could cause OOM on subsequent forwards.

def _pre_hook(mod, args):
    mod.weight.data = mod._fp8_weight.to(mod._fp8_compute_dtype) * mod._fp8_scale

def _post_hook(mod, args, output):
    mod.weight.data = mod._fp8_weight

Suggested Fix:

def _pre_hook(mod, args):
    mod.weight.data = mod._fp8_weight.to(mod._fp8_compute_dtype) * mod._fp8_scale

def _post_hook(mod, args, output):
    # Use .data to avoid tracking in autograd graph
    mod.weight.data = mod._fp8_weight
    return output

# Also consider using try/finally pattern or a context manager approach

b) Division by Zero Edge Case

vllm_omni/diffusion/models/utils.py:48-49

If all weights are zero, amax would be 0 without the clamp, giving scale = 0 and NaN values (0/0) when the weights are divided by the scale during quantization.

amax = weight.abs().amax().clamp(min=1e-12)
scale = amax / _FP8_E4M3_MAX

While clamp(min=1e-12) handles this, consider adding an explicit check with a warning for zero weights.

c) Missing Quantization Type Check

pipeline_qwen_image.py:277-279 (and similar in other files)

The code applies FP8 storage whenever quantization_config is not None, but doesn't verify it's actually FP8:

if od_config.quantization_config is not None:
    apply_fp8_weight_storage(self.vae)

Suggested Fix:

if od_config.quantization_config is not None:
    quant_method = getattr(od_config.quantization_config, "quant_method", None)
    if quant_method == "fp8":
        apply_fp8_weight_storage(self.vae)
        apply_fp8_weight_storage(self.text_encoder)

Minor Issues

d) Duplicate Code Across Pipelines

The same 6-line block is duplicated in three pipeline files. Consider extracting to a helper method.
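
The type check from (c) and the shared helper from (d) could be combined along these lines. The helper name is hypothetical, and `quant_method` is the attribute used in the suggested fix above; whether the real config object exposes it that way is an assumption. The configs and component names below are placeholders so the sketch runs without the pipeline.

```python
from types import SimpleNamespace
from typing import Any, Callable, Iterable


def maybe_apply_fp8_storage(
    quantization_config: Any,
    components: Iterable[Any],
    apply_fn: Callable[[Any], None],
) -> bool:
    """Hypothetical shared helper: run `apply_fn` on each component for FP8 configs only."""
    if quantization_config is None:
        return False
    if getattr(quantization_config, "quant_method", None) != "fp8":
        return False
    for component in components:
        apply_fn(component)
    return True


applied = []
fp8_cfg = SimpleNamespace(quant_method="fp8")
other_cfg = SimpleNamespace(quant_method="awq")  # stands in for any non-FP8 method

ran_fp8 = maybe_apply_fp8_storage(fp8_cfg, ["vae", "text_encoder"], applied.append)
ran_other = maybe_apply_fp8_storage(other_cfg, ["vae"], applied.append)
ran_none = maybe_apply_fp8_storage(None, ["vae"], applied.append)
```

In a pipeline, `apply_fn` would be `apply_fp8_weight_storage` and the components would be `self.vae` and `self.text_encoder`.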

e) No Idempotency Check

utils.py:45

If apply_fp8_weight_storage is called twice on the same model, it would create duplicate hooks and buffers.

def apply_fp8_weight_storage(model: nn.Module) -> None:
    count = 0
    for name, module in model.named_modules():
        if not isinstance(module, _FP8_TARGET_LAYERS):
            continue
        
        # Add check for already quantized modules
        if hasattr(module, "_fp8_weight"):
            continue

3. Architecture & Design

Positive Aspects

  • Clean separation of concerns with the utility function
  • Non-invasive approach using hooks rather than modifying model definitions
  • Good use of buffers for storing quantized weights

Concerns

a) Hook-Based Weight Modification is Fragile

Directly modifying module.weight.data in hooks works but is unconventional. Consider alternatives:

  1. Custom Module Wrapper: Wrap Linear/Conv layers with a quantized variant
  2. Functional Approach: Pass dequantized weight explicitly in forward

b) No Cleanup/Removal Mechanism

There's no way to:

  • Remove FP8 storage after application
  • Disable quantization temporarily
  • Inspect quantization state

Suggested Addition:

def remove_fp8_weight_storage(model: nn.Module) -> None:
    """Remove FP8 weight storage and restore original weights."""
    for name, module in model.named_modules():
        if hasattr(module, "_fp8_weight"):
            # Restore original weight
            module.weight.data = module._fp8_weight.to(module._fp8_compute_dtype) * module._fp8_scale
            del module._fp8_weight
            del module._fp8_scale
            del module._fp8_compute_dtype
            # Remove hooks (would need to store handles)

4. Security & Safety

Resource Management

a) Memory Spike During Conversion

utils.py:51-54

The conversion process creates temporary tensors that could cause memory spikes:

fp8_weight = (weight / scale).clamp(min=-_FP8_E4M3_MAX, max=_FP8_E4M3_MAX).to(torch.float8_e4m3fn)

Consider processing layers sequentially with explicit garbage collection for large models.

b) No Validation of Input Model State

The function doesn't check if:

  • Model is on the correct device
  • Model is in eval/train mode (affects batch norm, etc.)
  • Weights are already quantized by another method

5. Testing & Documentation

Test Coverage Gaps

The test plan mentions manual verification but lacks:

  1. Unit tests for apply_fp8_weight_storage:

    • Test with zero weights
    • Test with various dtypes (FP16, BF16, FP32)
    • Test idempotency
    • Test memory usage before/after
  2. Integration tests:

    • Test with --quantization fp8 flag
    • Test with other quantization methods (should be no-op)
    • Test exception recovery
  3. Accuracy tests:

    • Numerical comparison between FP8 and BF16 outputs

Documentation

The docstring is good but could be improved:

def apply_fp8_weight_storage(model: nn.Module) -> None:
    """Apply FP8 weight-only storage to Linear/Conv2d/Conv3d layers.

    Stores weights in float8_e4m3fn with per-tensor scales.
    Dequantizes to the original compute dtype before each forward pass,
    then re-quantizes afterward to free BF16 memory.

    This saves ~50% of memory with no accuracy loss since computation
    still happens in the original dtype.

    Args:
        model: The model whose layers will be quantized.
    
    Note:
        - This modifies modules in-place and registers forward hooks.
        - Calling multiple times on the same model is safe (no-op for 
          already quantized layers).
        - Requires weights to be in BF16 or FP16 format.
    
    Warning:
        If an exception occurs during forward pass, weights may remain
        in dequantized form, potentially causing memory issues.
    """

6. Specific Suggestions

vllm_omni/diffusion/models/utils.py

| Line | Issue | Suggestion |
| --- | --- | --- |
| 45 | Missing idempotency check | Add if hasattr(module, "_fp8_weight"): continue |
| 48 | Edge case handling | Add warning for near-zero amax values |
| 51 | Memory efficiency | Consider with torch.no_grad(): context |
| 55 | Missing attribute | Store compute dtype as buffer for consistency |
| 58-64 | Exception safety | Consider using try/finally or storing hook handles |

pipeline_qwen_image.py (and similar files)

| Line | Issue | Suggestion |
| --- | --- | --- |
| 277 | Missing type check | Verify quant_method == "fp8" |
| 730-731 | Performance | Use set union operator ` |

Code Example for Improved Implementation

def apply_fp8_weight_storage(model: nn.Module) -> None:
    """Apply FP8 weight-only storage to Linear/Conv2d/Conv3d layers."""
    count = 0
    for name, module in model.named_modules():
        if not isinstance(module, _FP8_TARGET_LAYERS):
            continue
        
        # Skip already quantized modules
        if hasattr(module, "_fp8_weight"):
            logger.debug("Skipping already quantized layer: %s", name)
            continue

        with torch.no_grad():
            weight = module.weight.data
            compute_dtype = weight.dtype
            
            # Skip unsupported dtypes
            if compute_dtype not in (torch.bfloat16, torch.float16, torch.float32):
                logger.warning("Skipping layer %s with unsupported dtype %s", name, compute_dtype)
                continue

            # Compute the per-tensor scale in float32 so the 1e-12 floor stays
            # representable even when the weights are FP16
            weight_fp32 = weight.float()
            amax = weight_fp32.abs().amax()
            if amax < 1e-12:
                logger.warning("Layer %s has near-zero weights, using minimum scale", name)
                amax = torch.tensor(1e-12, device=weight.device, dtype=torch.float32)
            scale = amax / _FP8_E4M3_MAX

            # Quantize weight to FP8
            fp8_weight = (weight_fp32 / scale).clamp(
                min=-_FP8_E4M3_MAX, max=_FP8_E4M3_MAX
            ).to(torch.float8_e4m3fn)

            # Store FP8 weight and metadata
            module.register_buffer("_fp8_weight", fp8_weight, persistent=False)
            module.register_buffer("_fp8_scale", scale.to(torch.float32), persistent=False)
            module._fp8_compute_dtype = compute_dtype

            # Replace weight data with FP8
            module.weight.data = fp8_weight

            # Register hooks with handles for potential cleanup
            pre_handle = module.register_forward_pre_hook(_create_pre_hook())
            post_handle = module.register_forward_hook(_create_post_hook())
            module._fp8_hook_handles = (pre_handle, post_handle)
            
            count += 1

    logger.info("Applied FP8 weight storage to %d layers in %s", count, model.__class__.__name__)


def _create_pre_hook():
    """Factory to create pre-hook with proper closure."""
    def _pre_hook(mod, args):
        mod.weight.data = mod._fp8_weight.to(mod._fp8_compute_dtype) * mod._fp8_scale
        return args
    return _pre_hook


def _create_post_hook():
    """Factory to create post-hook with proper closure."""
    def _post_hook(mod, args, output):
        mod.weight.data = mod._fp8_weight
        return output
    return _post_hook

7. Approval Status

Changes Requested

The PR has a solid approach and good documentation, but requires addressing these issues before merging:

Must Fix:

  1. Add quantization type check (quant_method == "fp8") before applying FP8 storage
  2. Add idempotency check to prevent double quantization
  3. Add unit tests for the new utility function

Should Fix:

  1. Add exception safety mechanism for hooks
  2. Add validation for unsupported dtypes
  3. Extract duplicate code into a shared helper method

Nice to Have:

  1. Add cleanup/removal function
  2. Store hook handles for potential removal
  3. Add warning for near-zero weight layers

The memory savings approach is valuable and the implementation is close to being production-ready. Once the critical issues are addressed, this would be a welcome addition to the codebase.


This review was generated automatically by the VLLM-Omni PR Reviewer Bot
using glm-5.

@hsliuustc0106
Collaborator

@vllm-omni-reviewer

@github-actions

🤖 VLLM-Omni PR Review

Code Review: FP8 Weight Storage for Qwen Image VAE and Text Encoder

1. Overview

This PR implements FP8 weight-only storage for the VAE and text encoder components in Qwen Image pipelines. The approach stores weights in float8_e4m3fn format with per-tensor scales, dequantizing to BF16 on-the-fly during forward passes. This achieves ~50% memory savings for these components while maintaining accuracy.

Overall Assessment: Positive with suggestions

The implementation is well-structured and follows a sensible approach. The changes are consistent across all three pipeline files, and the utility function is cleanly separated. However, there are a few areas that could benefit from refinement.


2. Code Quality

Strengths

  • Clean separation of concerns with the utility function
  • Consistent implementation across all three pipeline files
  • Good documentation in the docstring
  • Appropriate use of buffers vs parameters

Potential Issues

Hook closure pattern: The _pre_hook and _post_hook functions are defined inside the loop and capture mod via closure. While this works correctly, it's a pattern that can lead to subtle bugs and makes the code harder to reason about.

Hook accumulation: If apply_fp8_weight_storage is called multiple times on the same model (e.g., in testing scenarios), hooks would accumulate without any cleanup mechanism.

Weight data assignment timing: At utils.py:56, module.weight.data = fp8_weight is set, but this happens before hooks are registered. If any operation accesses the weight between this line and the first forward pass, it would see FP8 data.


3. Architecture & Design

Integration

The integration with existing code is clean:

  • Properly checks od_config.quantization_config before applying
  • Correctly handles the load_weights() method to account for from_pretrained() usage

Design Pattern Concerns

In-place weight modification: The hooks modify module.weight.data in-place. This pattern:

  1. Is not thread-safe
  2. Could cause issues with gradient computation if used in training mode
  3. May interact poorly with model introspection tools

Suggested improvement: Consider using a more robust hook implementation:

class FP8WeightHook:
    def __init__(self, fp8_weight: torch.Tensor, scale: torch.Tensor, compute_dtype: torch.dtype):
        self.fp8_weight = fp8_weight
        self.scale = scale
        self.compute_dtype = compute_dtype
    
    def pre_forward(self, module, args):
        module.weight.data = self.fp8_weight.to(self.compute_dtype) * self.scale
    
    def post_forward(self, module, args, output):
        module.weight.data = self.fp8_weight

4. Security & Safety

No significant security concerns. The code operates on model weights internally.

Resource management consideration: The hooks modify GPU memory during forward passes. Under memory pressure, the allocation of BF16 weights could potentially cause OOM errors that wouldn't occur with the original implementation. Consider adding a note about this in documentation.


5. Testing & Documentation

Test Plan

The test plan in the PR description is appropriate but incomplete:

  • Quality comparison with BF16 baseline
  • Memory reduction verification
  • Pre-commit checks

Missing test considerations:

  • Edge case: What happens with zero-initialized weights?
  • Edge case: What happens if the model is saved and reloaded?
  • Performance benchmark: Is there any latency impact from the dequantization?

Documentation

The docstring is good but could be enhanced with:

  • Warning about thread safety
  • Note about memory behavior during forward pass
  • Example usage

6. Specific Suggestions

vllm_omni/diffusion/models/utils.py

Line 46-47: Consider adding a check for already-quantized modules:

if hasattr(module, '_fp8_weight'):
    logger.debug("Skipping %s, already quantized", name)
    continue

Line 49: The clamp(min=1e-12) is good for numerical stability, but consider logging a warning if the clamp is triggered (indicating near-zero weights):

if amax < 1e-12:
    logger.warning("Layer %s has near-zero weights (amax=%s)", name, amax)

Lines 58-65: Consider storing hook handles for potential cleanup:

pre_handle = module.register_forward_pre_hook(_pre_hook)
post_handle = module.register_forward_hook(_post_hook)
module._fp8_hooks = (pre_handle, post_handle)

Line 56: This assignment could be moved after hook registration for consistency, or add a comment explaining why it's needed before hooks:

# Initialize weight to FP8 - will be dequantized by pre-hook on first forward
module.weight.data = fp8_weight

Pipeline files (pipeline_qwen_image.py, etc.)

Lines 277-278, 254-255, 204-205: The import inside the conditional is fine, but consider moving it to the top of the file for consistency with Python conventions, or add a comment explaining why it's conditional:

# Lazy import to avoid circular dependencies / only needed when quantizing

Lines 730-733, 829-832, 783-786: The load_weights modification is correct but could be more explicit:

# VAE and text_encoder are loaded via from_pretrained() during __init__,
# not through the weight pipeline. Mark them as loaded to prevent warnings.

7. Approval Status

LGTM with suggestions

The PR is fundamentally sound and implements a useful feature. The core approach is correct and the integration is clean. The suggestions above are improvements rather than blocking issues.

Recommendations before merge:

  1. Consider the hook class pattern for better maintainability (optional but recommended)
  2. Add a check for already-quantized modules to prevent double-quantization
  3. Ensure the test plan is completed and results documented

Minor nitpicks that don't block merge:

  • Type hints could be more complete
  • Consider adding a utility function to remove FP8 hooks if needed for testing

This review was generated automatically by the VLLM-Omni PR Reviewer Bot
using glm-5.

- P1: Keep parameter dtype at compute dtype so model.dtype stays correct
- P2: Cast dequantized weight back to compute dtype (avoid float32 promotion)
- P3: Add quant_method == "fp8" guard before applying FP8 storage
- P4: Add idempotency check to prevent double-application of hooks

Signed-off-by: lishunyang <lishunyang12@163.com>
lishunyang12 added a commit to lishunyang12/vllm-omni that referenced this pull request Feb 24, 2026
- Fix wording: "DiT linear layers" -> "DiT" (per SamitHuang)
- Add NPU compatibility row to device compatibility table (per SamitHuang)
- Mark unmerged models as "Planned" with PR links in supported models tables
- Remove Qwen-Image-Edit, Edit-Plus, Wan 2.2 from acceleration table until PRs merge
- Add PR vllm-project#1414 dependency note for apply_fp8_weight_storage in contributing guide
- Add quant_method guard in code example
- Fix reference implementations table accuracy

Signed-off-by: lishunyang <lishunyang12@163.com>
@lishunyang12
Collaborator Author

for the weights, i remember there are two versions for both qwen-image and qwen-image-edit.

@hsliuustc0106 Could you clarify which two versions you mean? Is this about different checkpoint formats (e.g., safetensors vs pytorch_model.bin), different model sizes, or different training stages? I want to make sure we test with both.

Collapse multi-line expression to single line to satisfy ruff formatter.

Signed-off-by: lishunyang <lishunyang12@163.com>
@hsliuustc0106
Collaborator

Hi @lishunyang12 👋

This PR for FP8 weight storage hasn't been updated for 15 days. Just checking if this is still in progress.

Thanks!

@lishunyang12
Collaborator Author

Merged into #1412 — combining Wan 2.2 FP8 + Qwen Image VAE/encoder FP8 into a single PR.
