[Quantization] Support NVFP4 for Flux.2 #2517
Add NVIDIA FP4 (NVFP4) quantization support for Flux.2 diffusion models, so that FLUX.2-dev-NVFP4 pre-quantized checkpoints can run via ModelOpt.

Key changes:
- Add BFL checkpoint weight name mapping (WeightsMapper) to Flux2Transformer2DModel.load_weights() with auto-detection
- Add transformer_weights_path to OmniDiffusionConfig for loading transformer weights from a separate checkpoint path
- Add NVFP4 auto-detection from safetensors file headers
- Update FluxPipeline and Flux2KleinPipeline to support separate transformer weight paths
- Add unit tests for weight mapping, format detection, and auto-detection

Leverages vLLM's existing ModelOptNvFp4Config and ModelOptNvFp4LinearMethod registered as "modelopt_fp4", so no custom quantization config is needed.

Reference: sgl-project/sglang#20137

Signed-off-by: lishunyang <lishunyang12@163.com>
a31e944 to 56bb20d
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: a31e9441e9
    for name, weight in weights:
        if not is_bfl and (name.startswith("double_blocks.") or name.startswith("single_blocks.")):
            is_bfl = True
        buffered.append((name, weight))
Stop materializing the entire checkpoint for format detection
_is_bfl_format currently consumes the whole incoming weight iterator and stores every tensor in buffered before loading begins. Because load_weights() calls this unconditionally, Flux2 checkpoint loading now keeps the full state dict in memory at once (instead of streaming), which can significantly increase peak host RAM and cause OOM on large NVFP4/Flux.2 checkpoints.
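One possible way to keep loading streamed, sketched here under assumptions: read only as far as the first parameter name that decides the format, then chain the peeked items back in front of the untouched iterator. The helper name `sniff_bfl_format`, the diffusers-side prefixes, and the `max_peek` bound are hypothetical and not taken from the PR.

```python
from itertools import chain
from typing import Iterable, Iterator, Tuple, TypeVar

W = TypeVar("W")  # stands in for torch.Tensor; the sketch never touches the payload

BFL_PREFIXES = ("double_blocks.", "single_blocks.")
# Assumed diffusers-side block prefixes; adjust to the real parameter names.
DIFFUSERS_PREFIXES = ("transformer_blocks.", "single_transformer_blocks.")


def sniff_bfl_format(
    weights: Iterable[Tuple[str, W]], max_peek: int = 64
) -> Tuple[bool, Iterator[Tuple[str, W]]]:
    """Return (is_bfl, iterator that yields every original item in order)."""
    it = iter(weights)
    peeked = []
    is_bfl = False
    for item in it:
        peeked.append(item)
        name = item[0]
        if name.startswith(BFL_PREFIXES):
            is_bfl = True
            break
        if name.startswith(DIFFUSERS_PREFIXES) or len(peeked) >= max_peek:
            break
    # Only the peeked prefix is held in memory; the remainder stays lazy.
    return is_bfl, chain(peeked, it)
```

The trade-off is that a checkpoint whose first block-level name appears after `max_peek` entries would be classified as non-BFL, so the bound would need to be chosen with the real checkpoint ordering in mind.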
    p = Path(path)
    if p.is_file() and p.suffix == ".safetensors":
        return [str(p)]
    if p.is_dir():
        files = sorted(str(f) for f in p.glob("*.safetensors"))
Handle remote model IDs in NVFP4 auto-detection
The auto-detection path in Flux2KleinPipeline passes transformer_weights_path directly to detect_nvfp4_from_safetensors, but _find_safetensors_files only recognizes existing local files/directories. When users provide a Hugging Face repo ID (which is supported later by the loader), this returns no files, so auto-detection never enables modelopt_fp4 and NVFP4 checkpoints can fail unless quantization is set manually.
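A minimal sketch of one way to cover the remote case, assuming the file lookup may take an injectable lister for repo IDs. The function name `find_safetensors_files`, the `list_remote` parameter, and the "exactly one slash means repo ID" heuristic are illustrative assumptions, not the PR's code.

```python
from pathlib import Path
from typing import Callable, List, Optional


def find_safetensors_files(
    path: str,
    list_remote: Optional[Callable[[str], List[str]]] = None,
) -> List[str]:
    """Return local .safetensors paths, or remote filenames for a repo ID."""
    p = Path(path)
    if p.is_file() and p.suffix == ".safetensors":
        return [str(p)]
    if p.is_dir():
        return sorted(str(f) for f in p.glob("*.safetensors"))
    # Not on disk: treat "org/name" as a remote repo ID if a lister is supplied.
    if list_remote is not None and not p.exists() and path.count("/") == 1:
        return sorted(f for f in list_remote(path) if f.endswith(".safetensors"))
    return []
```

`huggingface_hub.list_repo_files` takes a repo ID and returns file paths as strings, so it could plausibly be passed as `list_remote`. Reading the header of a remote file would still need a separate fetch step.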
Closed for now. Will resume in the near future.
I think for any quantization method, it's better to:
These tests should be very general and applicable to any quantization precision.
I will include those tests in this PR and update the related documents for developers who are interested in exploring this feature.
…pport Signed-off-by: lishunyang <lishunyang12@163.com>
…ude, klein-4b smoke target Signed-off-by: lishunyang <lishunyang12@163.com>
…uild Signed-off-by: lishunyang <lishunyang12@163.com>
… init Signed-off-by: lishunyang <lishunyang12@163.com>
…ipeline build Signed-off-by: lishunyang <lishunyang12@163.com>
… for proper NVFP4 post-load Signed-off-by: lishunyang <lishunyang12@163.com>
…nstruct Signed-off-by: lishunyang <lishunyang12@163.com>
…SP/...) Signed-off-by: lishunyang <lishunyang12@163.com>
…d_environment Signed-off-by: lishunyang <lishunyang12@163.com>
…ript + standard loader Signed-off-by: lishunyang <lishunyang12@163.com>
…e quant_config from transformer/config.json Signed-off-by: lishunyang <lishunyang12@163.com>
… for BFL ckps Signed-off-by: lishunyang <lishunyang12@163.com>
…-picklable) Signed-off-by: lishunyang <lishunyang12@163.com>
…match Signed-off-by: lishunyang <lishunyang12@163.com>
…() call Signed-off-by: lishunyang <lishunyang12@163.com>
… swizzle (sgl #22064) Signed-off-by: lishunyang <lishunyang12@163.com>
…er() Signed-off-by: lishunyang <lishunyang12@163.com>
…vocations Signed-off-by: lishunyang <lishunyang12@163.com>
…n to avoid HF cache pollution Signed-off-by: lishunyang <lishunyang12@163.com>
…g #20137 + #22064 Signed-off-by: lishunyang <lishunyang12@163.com>
…_weights_after_loading is called Signed-off-by: lishunyang <lishunyang12@163.com>
…SGLang Signed-off-by: lishunyang <lishunyang12@163.com>
…alpha diagnostic Signed-off-by: lishunyang <lishunyang12@163.com>
…p4_quantize + mm_fp4 with .T Signed-off-by: lishunyang <lishunyang12@163.com>
- Add BFL→diffusers name remapping to Flux2Transformer2DModel.load_weights (was only wired into flux2_klein, not flux2_dev)
- Add scalar scale broadcast for NVFP4 fused projections
- Fix merge script to skip stale .safetensors.index.json

Signed-off-by: lishunyang <lishunyang12@163.com>
NVFP4 quantization was not being applied because quant_config was never passed to the parallel linear layers. This caused Flux2Attention.to_out and other layers to allocate BF16 params while the checkpoint contained packed FP4 weights, producing shape mismatches at load time. Mirrors the pattern already used in flux2_klein. Signed-off-by: lishunyang <lishunyang12@163.com>
vLLM's ModelOpt get_quant_method uses the layer prefix to check against exclude_modules wildcards. Without a correct prefix, is_layer_excluded returns False for every layer and NVFP4 gets applied to modules the checkpoint kept in BF16 (e.g. txt_attn on FLUX.2-dev-NVFP4). Signed-off-by: lishunyang <lishunyang12@163.com>
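The failure mode described in this commit can be illustrated with a small stand-in, shown here only as a sketch: it uses `fnmatch` for the wildcard check and is not vLLM's actual `is_layer_excluded` implementation. The pattern and layer path below are hypothetical.

```python
from fnmatch import fnmatch
from typing import Sequence


def is_excluded(prefix: str, exclude_modules: Sequence[str]) -> bool:
    """True if the layer's dotted path matches any exclusion wildcard."""
    return any(fnmatch(prefix, pattern) for pattern in exclude_modules)


# Hypothetical pattern; the real list lives in the checkpoint's quantization config.
exclude = ["*txt_attn*"]

# With the full path, the layer is recognised as excluded and stays BF16.
print(is_excluded("double_blocks.0.txt_attn.qkv", exclude))
# With an empty prefix nothing can match, so the layer would be quantized.
print(is_excluded("", exclude))
```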
BFL's FLUX.2-dev-NVFP4 release is officially only compatible with comfy-kitchen kernels (see the HF model card discussion and SGLang PR #20137, which admits the generic ModelOpt FP4 path produces wrong output for BFL's checkpoint layout). Instead of fighting vLLM's quant_method dispatch, this follows the approach laid out in Comfy-Org/comfy-kitchen/samples/nvfp4_linear.py:
- New NVFP4Linear(nn.Module) wraps weights as comfy-kitchen QuantizedTensor and dispatches F.linear to scaled_mm_nvfp4.
- Flux2Transformer2DModel gains a use_comfy_nvfp4 flag; when set (at TP=1), every fused projection (QKV, gate+up, out, qkv_mlp) is allocated as a single NVFP4Linear matching the BFL checkpoint layout.
- Flux2Pipeline auto-detects quant_method="modelopt_fp4_flux" from the transformer config and passes the flag through.
- load_weights finalizes the model by walking all NVFP4Linear instances and wrapping their raw uint8 weights as QuantizedTensors.

Requires: pip install comfy-kitchen[cublas] on Blackwell (SM >= 10.0).

Signed-off-by: lishunyang <lishunyang12@163.com>
f5ec6b7 to ab55419
FLUX.2-dev-NVFP4's exclude_modules leaves txt_attn in BF16. Allocate those two layers as plain nn.Linear (not NVFP4Linear) when use_comfy_nvfp4 is set, matching the checkpoint layout. Signed-off-by: lishunyang <lishunyang12@163.com>
lishunyang12 left a comment
Review: [Quantization] Support NVFP4 for Flux.2
Thanks for the detailed PR description. The architecture is well-documented, the test coverage for config/mapping is solid, and the BFL weight-name remapping extraction into bfl_mapping.py is a clean refactor. Here are my findings:
Blocking / Must-Fix
1. Two parallel NVFP4 execution paths with unclear routing
The PR introduces two independent NVFP4 linear implementations:
- comfy_nvfp4_linear.py (NVFP4Linear): used when use_comfy_nvfp4=True, wired into the model constructor
- flux_nvfp4.py (FluxNvFp4LinearMethod): registered as the LinearMethodCls on the ModelOptNvFp4Config returned by _build_modelopt_nvfp4_flux
When use_comfy_nvfp4=True, all quantized layers become NVFP4Linear (bypassing vLLM's quant dispatch entirely), so FluxNvFp4LinearMethod is never actually called for those layers. Meanwhile, the quant_config is still passed to the constructor and forwarded to the non-NVFP4 parallel linears (add_kv_proj, etc.), where is_layer_excluded returns True so they stay BF16 — that part works. But it means flux_nvfp4.py (273 lines of flashinfer-based GEMM logic) is dead code in the current comfy-kitchen path. Either:
- Remove flux_nvfp4.py if comfy-kitchen is the only intended backend, or
- Document clearly which path is active and under what conditions, and add a test that exercises FluxNvFp4LinearMethod.apply so it doesn't bitrot.
2. init_timeout=1200 hardcoded unconditionally in the example script
In text_to_image.py, the timeout is bumped for all models, not just NVFP4. A 21 GB single-file load is NVFP4-specific; other models (e.g., FLUX.2-dev BF16 or GGUF) don't need 20 minutes. This should be conditional on detecting NVFP4, or better yet, passed via a CLI flag (e.g., --init-timeout).
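A sketch of the CLI-flag option, assuming the example script builds its arguments with `argparse` and that 600 s is the prior default, as the PR description implies. The parser below is a stand-in, not the script's real argument list.

```python
import argparse

DEFAULT_INIT_TIMEOUT_S = 600  # assumed prior default, per the PR description


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="text-to-image example (sketch)")
    parser.add_argument("--model", required=True)
    parser.add_argument(
        "--init-timeout",
        type=int,
        default=DEFAULT_INIT_TIMEOUT_S,
        help="Seconds to wait for engine initialization.",
    )
    return parser


# Only NVFP4 users opt in to the longer wait; everyone else keeps the default.
args = build_parser().parse_args(
    ["--model", "./flux2-dev-nvfp4-merged", "--init-timeout", "1200"]
)
print(args.init_timeout)
```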
3. _HANDSHAKE_POLL_TIMEOUT_S doubled globally
Same concern: stage_diffusion_proc.py timeout goes from 600s to 1200s for all models. This masks genuine startup failures for non-NVFP4 models. Consider making this configurable or deriving it from the quant method.
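Deriving the value from the quant method could look roughly like the following. This is a sketch only: the lookup table, the function name `handshake_timeout_s`, and the environment variable `OMNI_HANDSHAKE_TIMEOUT_S` are all hypothetical, and `"modelopt_fp4_flux"` is the quant method string used elsewhere in this PR.

```python
import os
from typing import Optional

BASE_HANDSHAKE_TIMEOUT_S = 600
# Hypothetical table: quant methods known to need a longer first load.
SLOW_LOAD_TIMEOUTS_S = {"modelopt_fp4_flux": 1200}


def handshake_timeout_s(quant_method: Optional[str]) -> int:
    """Pick the handshake timeout; an env var (hypothetical name) wins if set."""
    override = os.environ.get("OMNI_HANDSHAKE_TIMEOUT_S")
    if override:
        return int(override)
    return SLOW_LOAD_TIMEOUTS_S.get(quant_method, BASE_HANDSHAKE_TIMEOUT_S)
```

Non-NVFP4 models keep the 600 s limit this way, so a genuine startup hang is still reported at the original deadline.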
Non-Blocking / Suggestions
4. Significant code duplication in load_weights between flux2_transformer.py and flux2_klein_transformer.py
The BFL detection, scalar-scale broadcast, and error-wrapping logic is copy-pasted across both files. Consider extracting a shared helper (e.g., in bfl_mapping.py or a new weight_loading_utils.py) that both transformers call. The _peek_bfl_format / _apply_bfl_mapping staticmethod aliasing pattern already exists — extend it to cover the load loop body.
5. import logging placement in flux2_transformer.py
The import logging appears after all other imports including the local from vllm_omni.diffusion.utils.bfl_mapping import ... block. Standard style (and PEP 8) places stdlib imports before third-party and local imports. Move it to the stdlib import block.
6. Missing TYPE_CHECKING guard formatting
    import torch
    if TYPE_CHECKING:
        from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
    from diffusers.models.embeddings import (

The if TYPE_CHECKING block runs into the next import without a blank line, making it look like from diffusers... is inside the guard. Add a blank line after the if TYPE_CHECKING block.
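For illustration, the suggested layout might look like this. The stdlib imports are stand-ins for `torch` and `diffusers` so the sketch runs without those packages, and `describe` is a hypothetical function that only exists to use the guarded name in an annotation.

```python
import json  # stand-in for `import torch`
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated by type checkers only; never imported at runtime.
    from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from collections import OrderedDict  # stand-in for the diffusers import that follows


def describe(quant_config: Optional["QuantizationConfig"]) -> str:
    """The quoted annotation keeps the guarded name out of runtime evaluation."""
    return "unquantized" if quant_config is None else type(quant_config).__name__
```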
7. FluxNvFp4LinearMethod.process_weights_after_loading calls .cuda() directly
    padded_scales = padded_scales.contiguous().cuda()

This bypasses vLLM's device management. Called without an argument, .cuda() moves the tensor to the current CUDA device, which under distributed placement is not guaranteed to be the device the layer's weights live on. Use padded_scales.to(layer.weight.device) or the platform's device accessor.
8. _read_hf_header uses raw requests.get without retry
HuggingFace CDN can be flaky. Consider using huggingface_hub's built-in download utilities or adding a simple retry. Also, the timeout=30 may be too short for large headers on slow connections.
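A simple retry could be as small as the sketch below, which wraps any zero-argument fetch callable in exponential backoff. The name `with_retry` and its parameters are illustrative. The default `retry_on=(OSError,)` is chosen because `requests.RequestException` derives from `IOError`, an alias of `OSError`.

```python
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def with_retry(
    fetch: Callable[[], T],
    attempts: int = 3,
    base_delay_s: float = 0.5,
    retry_on: Tuple[Type[BaseException], ...] = (OSError,),
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fetch(), retrying with exponential backoff on transient errors."""
    for attempt in range(attempts):
        try:
            return fetch()
        except retry_on:
            if attempt == attempts - 1:
                raise  # out of attempts: surface the last error
            sleep(base_delay_s * (2 ** attempt))
    raise ValueError("attempts must be >= 1")
```

Passing `sleep` as a parameter keeps the helper testable without real delays.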
9. PR status: output is known-broken
The description states the output is "a black image with a horizontal-line artifact" and lists several suspected root causes (input_scale shape, AdaLN swap, RMSNorm rename). Merging this will leave a non-functional code path in main. Consider gating the entire NVFP4 path behind a clear "experimental" warning log at startup, or marking the smoke test with @pytest.mark.xfail until numerical correctness is confirmed.
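The startup warning could be sketched as follows, assuming the pipeline knows the detected quant method string at init time. The logger name and the function `warn_if_experimental` are hypothetical.

```python
import logging

logger = logging.getLogger("nvfp4_sketch")  # hypothetical logger name
_warned = False


def warn_if_experimental(quant_method):
    """Log one warning per process when the unverified NVFP4 path is selected."""
    global _warned
    if quant_method == "modelopt_fp4_flux" and not _warned:
        logger.warning(
            "NVFP4 for FLUX.2 is experimental: output correctness is not yet verified."
        )
        _warned = True
        return True
    return False
```

For the smoke test, `@pytest.mark.xfail(reason=..., strict=False)` would record the known failure without turning CI red, and would report an unexpected pass once the numerics are fixed.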
10. _TupleLinear is a clever workaround but fragile
Subclassing nn.Linear just to return (output, None) couples the implementation to the current call-site convention. If any caller changes to expect a plain tensor (or a different tuple structure), this will silently produce wrong results. A comment at each call site noting the expected return type would help future maintainers.
Minor Nits
- tools/prepare_flux2_nvfp4.py modifies sys.path twice (lines ~697 and ~800). Consider doing it once at the top.
- resolve_nvfp4_checkpoint_file returns Optional[str] but callers don't always guard for None (e.g., the default in --nvfp4-file processing does check, but it's worth a type annotation on the return).
- The test file tests/diffusion/quantization/test_nvfp4_config.py imports Flux2Transformer2DModel from flux2_klein (line 163 in the diff). This seems intentional (shared interface), but the import path is confusing; a comment explaining why would help.
Overall: solid groundwork for NVFP4 support. The main concern is shipping two parallel NVFP4 backends with one being dead code, and the unconditional timeout increases. Once those are addressed and numerical correctness is verified, this is ready to merge.
Purpose
Add NVFP4 loading/inference support for black-forest-labs/FLUX.2-dev-NVFP4 in vllm-omni.

Per BFL's own HF model card discussion, this checkpoint is officially only compatible with the ComfyUI / comfy-kitchen kernel stack. SGLang's PR #20137 documents that the generic ModelOpt FP4 / flashinfer mm_fp4 path produces semantically wrong output for BFL's weight_scale layout. This PR takes the same comfy-kitchen-backed approach.

What's in this PR
Infrastructure / loading (5 real bugs)
1. tools/prepare_flux2_nvfp4.py skips the stale .safetensors.index.json: the base's sharded-weights index referenced BF16 shards that don't exist in the merged dir, breaking the loader.
2. BFL→diffusers name remapping added to Flux2Transformer2DModel.load_weights: previously only flux2_klein had it, so FLUX.2-dev NVFP4 weights couldn't be found at all.
3. Scalar-scale broadcast for input_scale / weight_scale_2 when the BFL checkpoint stores them as 0-d tensors but vLLM's PerTensorScaleParameter expects shape [N].
4. quant_config threaded through all transformer modules (Flux2FeedForward, Flux2Attention, Flux2ParallelSelfAttention, both block types, top-level model). Klein had this; dev didn't, so parallel linears always allocated BF16 even when NVFP4 was configured.
5. prefix threaded through all transformer modules so vLLM's ModelOpt.is_layer_excluded wildcard matching can see the full layer path.

Timeouts
- init_timeout and _HANDSHAKE_POLL_TIMEOUT_S bumped to 1200 s: the 21 GB single-file NVFP4 load can exceed the 600 s default on first run.

NVFP4 execution path (comfy-kitchen)
- vllm_omni/quantization/comfy_nvfp4_linear.py: NVFP4Linear(nn.Module), adapted from comfy-kitchen's sample. Wraps weights as QuantizedTensor(..., "TensorCoreNVFP4Layout", ...) and lets F.linear dispatch to scaled_mm_nvfp4. wrap_all_nvfp4_weights finalizes every layer after weight load.
- Flux2Transformer2DModel gains a use_comfy_nvfp4 flag. When set (TP=1), fused projections (QKV, gate+up, to_qkv_mlp_proj, out) are allocated as NVFP4Linear matching the BFL checkpoint's pre-fused layout.
- txt_attn (add_kv_proj, to_add_out) stays nn.Linear / BF16, matching the checkpoint's exclude_modules.
- Flux2Pipeline auto-detects quant_method == "modelopt_fp4_flux" from transformer/config.json and flips the flag.

Diagnostics helper
- tools/dump_flux2_nvfp4_metadata.py prints the injected quantization_config block and the embedded _quantization_metadata header, which is useful for verifying the merged model before running inference.

Requirements
- pip install comfy-kitchen[cublas] (for NVFP4 kernels)

Usage
```bash
# Merge the BF16 base with the NVFP4 transformer weights.
python tools/prepare_flux2_nvfp4.py \
    --base black-forest-labs/FLUX.2-dev \
    --nvfp4 black-forest-labs/FLUX.2-dev-NVFP4 \
    --output-dir ./flux2-dev-nvfp4-merged

# Run inference from the merged directory.
python examples/offline_inference/text_to_image/text_to_image.py \
    --model ./flux2-dev-nvfp4-merged \
    --prompt "a photo of an astronaut riding a horse on Mars" \
    --num-inference-steps 20 \
    --enforce-eager \
    --output out.png
```
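Independently of the diagnostics helper, a merged checkpoint file can be sanity-checked by reading its safetensors header directly: the format starts with an 8-byte little-endian length followed by that many bytes of JSON. The sketch below is not the code of tools/dump_flux2_nvfp4_metadata.py, and placing `_quantization_metadata` under the optional `__metadata__` map is an assumption.

```python
import json
import struct
from typing import Any, Dict


def read_safetensors_header(path: str) -> Dict[str, Any]:
    """Parse the JSON header: 8-byte little-endian length, then the JSON bytes."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len))


def has_quantization_metadata(header: Dict[str, Any]) -> bool:
    # Assumption: the key sits in the optional "__metadata__" string map.
    return "_quantization_metadata" in header.get("__metadata__", {})
```

Only the header is read, so the check stays cheap even for a 21 GB file.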
Status
Output is currently a black image with a horizontal-line artifact. Suspected root causes:
- input_scale / weight_scale_2 shape handling when building TensorCoreNVFP4Layout.Params
- [scale, shift] → [shift, scale] swap for norm_out.linear.weight (present in bfl_mapping.apply_bfl_mapping; needs double-check along the comfy-kitchen path)
- RMSNorm .scale → .weight rename interaction with the plain nn.Linear sublayers

The bug is now isolated: architecture matches what BFL and SGLang document as the only working path. Further debugging needs intermediate-tensor comparison against a BF16 reference forward.
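The [scale, shift] → [shift, scale] swap amounts to exchanging the two halves of the weight along dimension 0. A dependency-free sketch on nested lists is shown here. The name `swap_halves` is hypothetical, and real code would do the same on tensors with `torch.chunk` and `torch.cat`.

```python
from typing import List, Sequence


def swap_halves(rows: Sequence[Sequence[float]]) -> List[List[float]]:
    """Reorder a [scale; shift] row stack into [shift; scale] along dim 0."""
    if len(rows) % 2:
        raise ValueError("expected an even number of rows")
    half = len(rows) // 2
    return [list(r) for r in rows[half:]] + [list(r) for r in rows[:half]]
```

Because the swap is its own inverse, applying it twice restores the original order. One thing that may be worth checking is therefore whether it runs exactly once on the comfy-kitchen path.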
Follow-ups
- comfy-kitchen[cublas] install as a hard requirement for NVFP4 on Blackwell.