[Quantization] Enable FP8 weight storage for Qwen Image VAE and text encoder#1414
[Quantization] Enable FP8 weight storage for Qwen Image VAE and text encoder#1414lishunyang12 wants to merge 3 commits into
Conversation
…encoder Add FP8 weight-only quantization for Linear/Conv2d/Conv3d layers in the Qwen Image VAE and text encoder. Weights are stored in float8_e4m3fn and dequantized to BF16 before each forward pass, saving ~50% memory for these components with no accuracy loss. Signed-off-by: lishunyang <lishunyang12@163.com>
9e048e1 to
85806be
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 9e048e151f
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| module._fp8_compute_dtype = compute_dtype | ||
|
|
||
| # Replace the parameter data with FP8 to save memory | ||
| module.weight.data = fp8_weight |
There was a problem hiding this comment.
Preserve VAE runtime dtype when storing FP8 weights
Assigning module.weight.data = fp8_weight permanently changes each conv parameter to FP8, and the Qwen Image pipelines later call latents = latents.to(self.vae.dtype) before decode. Because model dtype is parameter-derived, enabling quantization can push VAE inputs to FP8 and break decode paths that rely on BF16/FP16 convolution execution. Keep exposed parameter dtype at the original compute dtype and store FP8 only in auxiliary buffers so self.vae.dtype remains valid.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Fixed in the latest push. module.weight.data now stays at compute dtype (BF16/FP16). The FP8 representation is stored in the _fp8_weight buffer and only dequantized into module.weight.data inside the pre-hook. The post-hook writes back mod._fp8_weight.to(mod._fp8_compute_dtype) instead of raw FP8, so model.dtype stays correct at all times.
Also added an idempotency guard (hasattr(module, "_fp8_weight")) to prevent double-application.
|
|
||
| def _pre_hook(mod, args): | ||
| # Dequantize: restore BF16/FP16 weight for conv computation | ||
| mod.weight.data = mod._fp8_weight.to(mod._fp8_compute_dtype) * mod._fp8_scale |
There was a problem hiding this comment.
Keep dequantized weights in original compute dtype
The pre-hook dequantization multiplies mod._fp8_weight.to(mod._fp8_compute_dtype) by _fp8_scale stored as float32, which promotes the result to float32 instead of restoring the original BF16/FP16 dtype. Under FP8 mode this increases memory/compute cost and can cause dtype mismatches with BF16 activations or biases during VAE decode. Cast the scale (or final product) back to mod._fp8_compute_dtype before assigning mod.weight.data.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Fixed. The dequantized result is now explicitly cast back to compute_dtype after multiplication with the float32 scale:
# Before
mod.weight.data = mod._fp8_weight.to(mod._fp8_compute_dtype) * mod._fp8_scale
# After
mod.weight.data = (
mod._fp8_weight.to(mod._fp8_compute_dtype) * mod._fp8_scale
).to(mod._fp8_compute_dtype)|
for the weights, i remember there are two versions for both qwen-image and qwen-image-edit. |
|
@vllm-omni-reviewer |
🤖 VLLM-Omni PR ReviewCode Review: FP8 Weight Storage for Qwen Image VAE and Text Encoder1. OverviewThis PR implements FP8 weight-only storage for Linear/Conv2d/Conv3d layers in Qwen Image VAE and text encoder components. The approach stores weights in Overall Assessment: Changes Requested - While the approach is sound and the implementation is mostly correct, there are several issues that should be addressed before merging, including a potential memory leak during exceptions and missing validation. 2. Code QualityCritical Issuesa) Exception Safety - Memory Leak Risk
If an exception occurs during the forward pass, the def _pre_hook(mod, args):
mod.weight.data = mod._fp8_weight.to(mod._fp8_compute_dtype) * mod._fp8_scale
def _post_hook(mod, args, output):
mod.weight.data = mod._fp8_weightSuggested Fix: def _pre_hook(mod, args):
mod.weight.data = mod._fp8_weight.to(mod._fp8_compute_dtype) * mod._fp8_scale
def _post_hook(mod, args, output):
# Use .data to avoid tracking in autograd graph
mod.weight.data = mod._fp8_weight
return output
# Also consider using try/finally pattern or a context manager approachb) Division by Zero Edge Case
If all weights are zero, amax = weight.abs().amax().clamp(min=1e-12)
scale = amax / _FP8_E4M3_MAXWhile c) Missing Quantization Type Check
The code applies FP8 storage whenever if od_config.quantization_config is not None:
apply_fp8_weight_storage(self.vae)Suggested Fix: if od_config.quantization_config is not None:
quant_method = getattr(od_config.quantization_config, "quant_method", None)
if quant_method == "fp8":
apply_fp8_weight_storage(self.vae)
apply_fp8_weight_storage(self.text_encoder)Minor Issuesd) Duplicate Code Across Pipelines The same 6-line block is duplicated in three pipeline files. Consider extracting to a helper method. e) No Idempotency Check
If def apply_fp8_weight_storage(model: nn.Module) -> None:
count = 0
for name, module in model.named_modules():
if not isinstance(module, _FP8_TARGET_LAYERS):
continue
# Add check for already quantized modules
if hasattr(module, "_fp8_weight"):
continue3. Architecture & DesignPositive Aspects
Concernsa) Hook-Based Weight Modification is Fragile Directly modifying
b) No Cleanup/Removal Mechanism There's no way to:
Suggested Addition: def remove_fp8_weight_storage(model: nn.Module) -> None:
"""Remove FP8 weight storage and restore original weights."""
for name, module in model.named_modules():
if hasattr(module, "_fp8_weight"):
# Restore original weight
module.weight.data = module._fp8_weight.to(module._fp8_compute_dtype) * module._fp8_scale
del module._fp8_weight
del module._fp8_scale
del module._fp8_compute_dtype
# Remove hooks (would need to store handles)4. Security & SafetyResource Managementa) Memory Spike During Conversion
The conversion process creates temporary tensors that could cause memory spikes: fp8_weight = (weight / scale).clamp(min=-_FP8_E4M3_MAX, max=_FP8_E4M3_MAX).to(torch.float8_e4m3fn)Consider processing layers sequentially with explicit garbage collection for large models. b) No Validation of Input Model State The function doesn't check if:
5. Testing & DocumentationTest Coverage GapsThe test plan mentions manual verification but lacks:
DocumentationThe docstring is good but could be improved: def apply_fp8_weight_storage(model: nn.Module) -> None:
"""Apply FP8 weight-only storage to Linear/Conv2d/Conv3d layers.
Stores weights in float8_e4m3fn with per-tensor scales.
Dequantizes to the original compute dtype before each forward pass,
then re-quantizes afterward to free BF16 memory.
This saves ~50% of memory with no accuracy loss since computation
still happens in the original dtype.
Args:
model: The model whose layers will be quantized.
Note:
- This modifies modules in-place and registers forward hooks.
- Calling multiple times on the same model is safe (no-op for
already quantized layers).
- Requires weights to be in BF16 or FP16 format.
Warning:
If an exception occurs during forward pass, weights may remain
in dequantized form, potentially causing memory issues.
"""6. Specific Suggestions
|
| Line | Issue | Suggestion |
|---|---|---|
| 45 | Missing idempotency check | Add if hasattr(module, "_fp8_weight"): continue |
| 48 | Edge case handling | Add warning for near-zero amax values |
| 51 | Memory efficiency | Consider with torch.no_grad(): context |
| 55 | Missing attribute | Store compute dtype as buffer for consistency |
| 58-64 | Exception safety | Consider using try/finally or storing hook handles |
pipeline_qwen_image.py (and similar files)
| Line | Issue | Suggestion |
|---|---|---|
| 277 | Missing type check | Verify quant_method == "fp8" |
| 730-731 | Performance | Use set union operator ` |
Code Example for Improved Implementation
def apply_fp8_weight_storage(model: nn.Module) -> None:
"""Apply FP8 weight-only storage to Linear/Conv2d/Conv3d layers."""
count = 0
for name, module in model.named_modules():
if not isinstance(module, _FP8_TARGET_LAYERS):
continue
# Skip already quantized modules
if hasattr(module, "_fp8_weight"):
logger.debug("Skipping already quantized layer: %s", name)
continue
with torch.no_grad():
weight = module.weight.data
compute_dtype = weight.dtype
# Skip unsupported dtypes
if compute_dtype not in (torch.bfloat16, torch.float16, torch.float32):
logger.warning("Skipping layer %s with unsupported dtype %s", name, compute_dtype)
continue
# Compute per-tensor scale
amax = weight.abs().amax()
if amax < 1e-12:
logger.warning("Layer %s has near-zero weights, using minimum scale", name)
amax = torch.tensor(1e-12, device=weight.device, dtype=weight.dtype)
scale = amax / _FP8_E4M3_MAX
# Quantize weight to FP8
fp8_weight = (weight / scale).clamp(
min=-_FP8_E4M3_MAX, max=_FP8_E4M3_MAX
).to(torch.float8_e4m3fn)
# Store FP8 weight and metadata
module.register_buffer("_fp8_weight", fp8_weight, persistent=False)
module.register_buffer("_fp8_scale", scale.to(torch.float32), persistent=False)
module._fp8_compute_dtype = compute_dtype
# Replace weight data with FP8
module.weight.data = fp8_weight
# Register hooks with handles for potential cleanup
pre_handle = module.register_forward_pre_hook(_create_pre_hook())
post_handle = module.register_forward_hook(_create_post_hook())
module._fp8_hook_handles = (pre_handle, post_handle)
count += 1
logger.info("Applied FP8 weight storage to %d layers in %s", count, model.__class__.__name__)
def _create_pre_hook():
"""Factory to create pre-hook with proper closure."""
def _pre_hook(mod, args):
mod.weight.data = mod._fp8_weight.to(mod._fp8_compute_dtype) * mod._fp8_scale
return args
return _pre_hook
def _create_post_hook():
"""Factory to create post-hook with proper closure."""
def _post_hook(mod, args, output):
mod.weight.data = mod._fp8_weight
return output
return _post_hook7. Approval Status
Changes Requested
The PR has a solid approach and good documentation, but requires addressing these issues before merging:
Must Fix:
- Add quantization type check (
quant_method == "fp8") before applying FP8 storage - Add idempotency check to prevent double quantization
- Add unit tests for the new utility function
Should Fix:
- Add exception safety mechanism for hooks
- Add validation for unsupported dtypes
- Extract duplicate code into a shared helper method
Nice to Have:
- Add cleanup/removal function
- Store hook handles for potential removal
- Add warning for near-zero weight layers
The memory savings approach is valuable and the implementation is close to being production-ready. Once the critical issues are addressed, this would be a welcome addition to the codebase.
This review was generated automatically by the VLLM-Omni PR Reviewer Bot
using glm-5.
|
@vllm-omni-reviewer |
🤖 VLLM-Omni PR ReviewCode Review: FP8 Weight Storage for Qwen Image VAE and Text Encoder1. OverviewThis PR implements FP8 weight-only storage for the VAE and text encoder components in Qwen Image pipelines. The approach stores weights in Overall Assessment: Positive with suggestions The implementation is well-structured and follows a sensible approach. The changes are consistent across all three pipeline files, and the utility function is cleanly separated. However, there are a few areas that could benefit from refinement. 2. Code QualityStrengths
Potential IssuesHook closure pattern: The Hook accumulation: If Weight data assignment timing: At 3. Architecture & DesignIntegrationThe integration with existing code is clean:
Design Pattern ConcernsIn-place weight modification: The hooks modify
Suggested improvement: Consider using a more robust hook implementation: class FP8WeightHook:
def __init__(self, fp8_weight: torch.Tensor, scale: torch.Tensor, compute_dtype: torch.dtype):
self.fp8_weight = fp8_weight
self.scale = scale
self.compute_dtype = compute_dtype
def pre_forward(self, module, args):
module.weight.data = self.fp8_weight.to(self.compute_dtype) * self.scale
def post_forward(self, module, args, output):
module.weight.data = self.fp8_weight4. Security & SafetyNo significant security concerns. The code operates on model weights internally. Resource management consideration: The hooks modify GPU memory during forward passes. Under memory pressure, the allocation of BF16 weights could potentially cause OOM errors that wouldn't occur with the original implementation. Consider adding a note about this in documentation. 5. Testing & DocumentationTest PlanThe test plan in the PR description is appropriate but incomplete:
Missing test considerations:
DocumentationThe docstring is good but could be enhanced with:
6. Specific Suggestions
|
- P1: Keep parameter dtype at compute dtype so model.dtype stays correct - P2: Cast dequantized weight back to compute dtype (avoid float32 promotion) - P3: Add quant_method == "fp8" guard before applying FP8 storage - P4: Add idempotency check to prevent double-application of hooks Signed-off-by: lishunyang <lishunyang12@163.com>
- Fix wording: "DiT linear layers" -> "DiT" (per SamitHuang) - Add NPU compatibility row to device compatibility table (per SamitHuang) - Mark unmerged models as "Planned" with PR links in supported models tables - Remove Qwen-Image-Edit, Edit-Plus, Wan 2.2 from acceleration table until PRs merge - Add PR vllm-project#1414 dependency note for apply_fp8_weight_storage in contributing guide - Add quant_method guard in code example - Fix reference implementations table accuracy Signed-off-by: lishunyang <lishunyang12@163.com>
@hsliuustc0106 Could you clarify which two versions you mean? Is this about different checkpoint formats (e.g., safetensors vs pytorch_model.bin), different model sizes, or different training stages? I want to make sure we test with both. |
Collapse multi-line expression to single line to satisfy ruff formatter. Signed-off-by: lishunyang <lishunyang12@163.com>
|
Hi @lishunyang12 👋 This PR for FP8 weight storage hasn't been updated for 15 days. Just checking if this is still in progress. Thanks! |
|
Merged into #1412 — combining Wan 2.2 FP8 + Qwen Image VAE/encoder FP8 into a single PR. |
Summary
Qwen2_5_VLForConditionalGeneration)float8_e4m3fnwith per-tensor scales and dequantized to BF16 before each forward pass — saving ~50% memory for these components with no accuracy lossload_weights()to mark VAE and text_encoder parameters as loaded (they usefrom_pretrained(), not the weight pipeline)Approach
New utility
apply_fp8_weight_storage()invllm_omni/diffusion/models/utils.py:nn.Linear,nn.Conv2d, andnn.Conv3dmodules (including subclasses likeQwenImageCausalConv3d)This covers all three non-DiT components when
--quantization fp8is passed:Test plan
--quantization fp8and verify output quality matches BF16 baseline--quantization fp8Related
Follow-up to #1338 (FP8 online quantization for DiT transformer)