[Quantization] Support NVFP4 for Flux.2#2517

Closed
lishunyang12 wants to merge 43 commits into vllm-project:main from lishunyang12:feat/nvfp4-flux2

Collaborator

@lishunyang12 lishunyang12 commented Apr 6, 2026

Purpose

Add NVFP4 loading/inference support for black-forest-labs/FLUX.2-dev-NVFP4 in vllm-omni.

Per BFL's own HF model card discussion, this checkpoint is officially only compatible with the ComfyUI / comfy-kitchen kernel stack. SGLang's PR #20137 documents that the generic ModelOpt FP4 / flashinfer mm_fp4 path produces semantically wrong output for BFL's weight_scale layout. This PR takes the same comfy-kitchen-backed approach.

What's in this PR

Infrastructure / loading (5 real bugs)

  1. tools/prepare_flux2_nvfp4.py skips the stale .safetensors.index.json — the base's sharded-weights index referenced BF16 shards that don't exist in the merged dir, breaking the loader.
  2. BFL → diffusers name remapping wired into Flux2Transformer2DModel.load_weights — previously only flux2_klein had it, so FLUX.2-dev NVFP4 weights couldn't be found at all.
  3. Scalar scale broadcast for input_scale / weight_scale_2 when the BFL checkpoint stores them as 0-d tensors but vLLM's PerTensorScaleParameter expects shape [N].
  4. quant_config threaded through all transformer modules (Flux2FeedForward, Flux2Attention, Flux2ParallelSelfAttention, both block types, top-level model). Klein had this; dev didn't, so parallel linears always allocated BF16 even when NVFP4 was configured.
  5. prefix threaded through all transformer modules so vLLM's ModelOpt.is_layer_excluded wildcard matching can see the full layer path.
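Item 1 above can be illustrated with a minimal sketch. `copy_base_without_stale_index` is a hypothetical helper written for this note, not the code in tools/prepare_flux2_nvfp4.py; it only shows the idea of leaving the sharded-weights index out of the merged directory.

```python
import shutil
from pathlib import Path
from typing import List


def copy_base_without_stale_index(src: Path, dst: Path) -> List[str]:
    """Copy top-level files from `src` to `dst`, skipping shard indexes.

    Hypothetical helper: a copied index would still name BF16 shards
    that the merged directory does not contain.
    """
    dst.mkdir(parents=True, exist_ok=True)
    copied = []
    for item in sorted(src.iterdir()):
        if not item.is_file():
            continue
        if item.name.endswith(".safetensors.index.json"):
            continue  # stale: references shards absent from the merged dir
        shutil.copy2(item, dst / item.name)
        copied.append(item.name)
    return copied
```

With the index absent, a loader that falls back to globbing `*.safetensors` files would see only the single merged NVFP4 file; whether the actual loader behaves this way is an assumption here.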

Timeouts

  1. init_timeout and _HANDSHAKE_POLL_TIMEOUT_S bumped to 1200 s — the 21 GB single-file NVFP4 load can exceed the 600 s default on first run.

NVFP4 execution path (comfy-kitchen)

  1. New vllm_omni/quantization/comfy_nvfp4_linear.py:
    • NVFP4Linear (nn.Module) — adapted from comfy-kitchen's sample. Wraps weights as QuantizedTensor(..., "TensorCoreNVFP4Layout", ...) and lets F.linear dispatch to scaled_mm_nvfp4.
    • wrap_all_nvfp4_weights finalizes every layer after weight load.
  2. Flux2Transformer2DModel gains a use_comfy_nvfp4 flag. When set (TP=1), fused projections (QKV, gate+up, to_qkv_mlp_proj, out) are allocated as NVFP4Linear matching the BFL checkpoint's pre-fused layout. txt_attn (add_kv_proj, to_add_out) stays nn.Linear / BF16, matching the checkpoint's exclude_modules.
  3. Flux2Pipeline auto-detects quant_method == "modelopt_fp4_flux" from transformer/config.json and flips the flag.
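The auto-detection in item 3 could look roughly like the sketch below. The function name `should_use_comfy_nvfp4` and the assumption that `quant_method` sits under a top-level `quantization_config` key are illustrative; only the `modelopt_fp4_flux` value and the transformer/config.json location come from the PR text.

```python
import json
from pathlib import Path

COMFY_NVFP4_QUANT_METHOD = "modelopt_fp4_flux"  # value quoted in the PR text


def should_use_comfy_nvfp4(model_dir: str) -> bool:
    """Return True when transformer/config.json declares the BFL NVFP4 method."""
    config_path = Path(model_dir) / "transformer" / "config.json"
    if not config_path.is_file():
        return False
    config = json.loads(config_path.read_text(encoding="utf-8"))
    # Assumed layout: {"quantization_config": {"quant_method": "..."}}
    quant_cfg = config.get("quantization_config") or {}
    return quant_cfg.get("quant_method") == COMFY_NVFP4_QUANT_METHOD
```

Returning False for a missing file or key keeps non-quantized checkpoints on the default path without extra branching in the caller.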

Diagnostics helper

  1. tools/dump_flux2_nvfp4_metadata.py prints the injected quantization_config block and the embedded _quantization_metadata header — useful for verifying the merged model before running inference.
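For readers without the helper at hand, the header of a .safetensors file can be inspected with the standard library alone: the file starts with an 8-byte little-endian length followed by a JSON header. The sketch below is not the tool from this PR, and it assumes that `_quantization_metadata` is stored as a JSON string inside the header's `__metadata__` map.

```python
import json
import struct


def read_safetensors_header(path: str) -> dict:
    """Read only the JSON header of a .safetensors file (no tensor data)."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian u64
        return json.loads(f.read(header_len).decode("utf-8"))


def quantization_metadata(path: str):
    """Return the parsed `_quantization_metadata` entry, or None if absent.

    Assumption: the entry lives under `__metadata__` as a JSON string.
    """
    metadata = read_safetensors_header(path).get("__metadata__", {})
    raw = metadata.get("_quantization_metadata")
    return json.loads(raw) if raw is not None else None
```

Because only the header is read, this stays cheap even for a 21 GB single-file checkpoint.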

Requirements

  • Blackwell (SM ≥ 10.0) — B100 / B200 / RTX 5090
  • pip install comfy-kitchen[cublas] (for NVFP4 kernels)

Usage

```bash
# Merge the BF16 base repo with the NVFP4 transformer checkpoint.
python tools/prepare_flux2_nvfp4.py \
  --base black-forest-labs/FLUX.2-dev \
  --nvfp4 black-forest-labs/FLUX.2-dev-NVFP4 \
  --output-dir ./flux2-dev-nvfp4-merged

# Run text-to-image inference against the merged directory.
python examples/offline_inference/text_to_image/text_to_image.py \
  --model ./flux2-dev-nvfp4-merged \
  --prompt "a photo of an astronaut riding a horse on Mars" \
  --num-inference-steps 20 \
  --enforce-eager \
  --output out.png
```

Status

  • Model loads on B200; every layer's allocated param shape matches the corresponding checkpoint tensor.
  • Forward runs end-to-end, 20 steps complete without crash.
  • Output is not yet a valid image — current result is a black image with a horizontal-line artifact. Likely remaining numerical issues:
    • input_scale / weight_scale_2 shape handling when building TensorCoreNVFP4Layout.Params
    • AdaLN [scale, shift] ↔ [shift, scale] swap for norm_out.linear.weight (present in bfl_mapping.apply_bfl_mapping; needs double-check along the comfy-kitchen path)
    • RMSNorm .scale.weight rename interaction with the plain nn.Linear sublayers
    • VAE / text-encoder dtype interaction with the comfy-kitchen-quantized transformer output

The bug is now isolated: architecture matches what BFL and SGLang document as the only working path. Further debugging needs intermediate-tensor comparison against a BF16 reference forward.

Follow-ups

  • Verify numerical output then add sample generation to the PR.
  • Benchmark memory + latency vs BF16 on B200.
  • Document comfy-kitchen[cublas] install as a hard requirement for NVFP4 on Blackwell.

@lishunyang12 lishunyang12 changed the title from "[Diffusion] Support NVFP4 quantization for Flux.2" to "[Quantization] Support NVFP4 for Flux.2" on Apr 6, 2026
Add NVIDIA FP4 (NVFP4) quantization support for Flux.2 diffusion models,
enabling running FLUX.2-dev-NVFP4 pre-quantized checkpoints via ModelOpt.

Key changes:
- Add BFL checkpoint weight name mapping (WeightsMapper) to
  Flux2Transformer2DModel.load_weights() with auto-detection
- Add transformer_weights_path to OmniDiffusionConfig for loading
  transformer weights from a separate checkpoint path
- Add NVFP4 auto-detection from safetensors file headers
- Update FluxPipeline and Flux2KleinPipeline to support separate
  transformer weight paths
- Add unit tests for weight mapping, format detection, and auto-detection

Leverages vLLM's existing ModelOptNvFp4Config and ModelOptNvFp4LinearMethod
registered as "modelopt_fp4" — no custom quantization config needed.

Reference: sgl-project/sglang#20137

Signed-off-by: lishunyang <lishunyang12@163.com>

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a31e9441e9

Comment on lines +1021 to +1024
    for name, weight in weights:
        if not is_bfl and (name.startswith("double_blocks.") or name.startswith("single_blocks.")):
            is_bfl = True
        buffered.append((name, weight))
P1 Badge Stop materializing the entire checkpoint for format detection

_is_bfl_format currently consumes the whole incoming weight iterator and stores every tensor in buffered before loading begins. Because load_weights() calls this unconditionally, Flux2 checkpoint loading now keeps the full state dict in memory at once (instead of streaming), which can significantly increase peak host RAM and cause OOM on large NVFP4/Flux.2 checkpoints.
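One way to address this is to peek at a bounded number of entries and then hand back an iterator that replays them before continuing lazily. The sketch below is an illustration, not the PR's code: `peek_bfl_format` and `max_peek` are hypothetical names, and it assumes a BFL-style key shows up within the first `max_peek` entries.

```python
from itertools import chain
from typing import Iterable, Iterator, Tuple

BFL_PREFIXES = ("double_blocks.", "single_blocks.")


def peek_bfl_format(
    weights: Iterable[Tuple[str, object]], max_peek: int = 64
) -> Tuple[bool, Iterator[Tuple[str, object]]]:
    """Detect BFL naming from at most `max_peek` entries, then keep streaming."""
    it = iter(weights)
    peeked = []
    is_bfl = False
    for name, weight in it:
        peeked.append((name, weight))
        if name.startswith(BFL_PREFIXES):
            is_bfl = True
            break
        if len(peeked) >= max_peek:
            break
    # Replay the few peeked entries, then continue lazily with the rest.
    return is_bfl, chain(peeked, it)
```

Peak host memory is then bounded by `max_peek` tensors rather than the whole state dict. The trade-off is that a checkpoint whose first `max_peek` names carry no BFL prefix would be classified as non-BFL.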

Comment on lines +59 to +63
    p = Path(path)
    if p.is_file() and p.suffix == ".safetensors":
        return [str(p)]
    if p.is_dir():
        files = sorted(str(f) for f in p.glob("*.safetensors"))
P1 Badge Handle remote model IDs in NVFP4 auto-detection

The auto-detection path in Flux2KleinPipeline passes transformer_weights_path directly to detect_nvfp4_from_safetensors, but _find_safetensors_files only recognizes existing local files/directories. When users provide a Hugging Face repo ID (which is supported later by the loader), this returns no files, so auto-detection never enables modelopt_fp4 and NVFP4 checkpoints can fail unless quantization is set manually.
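A possible first step is to classify the location before looking for files, so that a repo ID is not silently treated as "no files found". The helper below is hypothetical, and the "org/name" test is a heuristic assumption rather than how the Hub validates repo IDs.

```python
from pathlib import Path


def classify_weights_source(path: str) -> str:
    """Classify a weights location as 'file', 'dir', 'remote', or 'unknown'."""
    p = Path(path)
    if p.is_file() and p.suffix == ".safetensors":
        return "file"
    if p.is_dir():
        return "dir"
    # Heuristic assumption: a two-part "org/name" string that does not
    # exist locally is treated as a Hugging Face repo ID.
    parts = path.split("/")
    if len(parts) == 2 and all(parts) and not path.startswith("."):
        return "remote"
    return "unknown"
```

For the `'remote'` case, the caller could resolve a local snapshot first (for example with `huggingface_hub.snapshot_download`) and then run the existing local detection on the result.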

@lishunyang12 lishunyang12 marked this pull request as draft April 6, 2026 09:33
@lishunyang12
Collaborator Author

Closed for now. Will resume in the near future.

@hsliuustc0106
Collaborator

hsliuustc0106 commented Apr 11, 2026

I think for any quantization method, it's better to:

  1. add a few pixel-level CI/CD tests that compare accuracy against the BF16 baseline.
  2. add tests that report memory (BF16 → NVFP4), speedup, and LPIPS metrics.

These tests should be general enough to apply to any quantization precision.
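As a rough illustration of a precision-agnostic pixel-level check, the sketch below computes PSNR between a baseline image and a quantized-model image, both given as flattened 8-bit pixel sequences. PSNR is swapped in here as a simple stand-in because LPIPS needs a learned model; the function names and the 20 dB threshold are placeholders, not values proposed in this thread.

```python
import math
from typing import Sequence


def psnr(reference: Sequence[int], candidate: Sequence[int], max_value: int = 255) -> float:
    """Peak signal-to-noise ratio between two flattened 8-bit images."""
    if len(reference) != len(candidate) or not reference:
        raise ValueError("images must be non-empty and the same size")
    mse = sum((r - c) ** 2 for r, c in zip(reference, candidate)) / len(reference)
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(max_value**2 / mse)


def assert_close_to_baseline(reference, candidate, min_psnr_db: float = 20.0) -> None:
    """Fail when the quantized output drifts too far from the baseline.

    The default threshold is a placeholder to be tuned per model.
    """
    score = psnr(reference, candidate)
    assert score >= min_psnr_db, f"PSNR {score:.2f} dB below {min_psnr_db} dB"
```

A check of this shape would flag an all-black output immediately, since its PSNR against any reasonably bright baseline is only a few dB.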

@lishunyang12
Collaborator Author

lishunyang12 commented Apr 11, 2026

> I think for any quantization method, it's better to:
>
>   1. add a few pixel-level CI/CD tests that compare accuracy against the BF16 baseline.
>   2. add tests that report memory (BF16 → NVFP4), speedup, and LPIPS metrics.
>
> These tests should be general enough to apply to any quantization precision.

I will include those tests in this PR and update the related documentation for developers who are interested in exploring this feature.

…pport

Signed-off-by: lishunyang <lishunyang12@163.com>
…ude, klein-4b smoke target

Signed-off-by: lishunyang <lishunyang12@163.com>
…uild

Signed-off-by: lishunyang <lishunyang12@163.com>
… init

Signed-off-by: lishunyang <lishunyang12@163.com>
…ipeline build

Signed-off-by: lishunyang <lishunyang12@163.com>
Signed-off-by: lishunyang <lishunyang12@163.com>
Signed-off-by: lishunyang <lishunyang12@163.com>
… for proper NVFP4 post-load

Signed-off-by: lishunyang <lishunyang12@163.com>
Signed-off-by: lishunyang <lishunyang12@163.com>
…nstruct

Signed-off-by: lishunyang <lishunyang12@163.com>
Signed-off-by: lishunyang <lishunyang12@163.com>
…SP/...)

Signed-off-by: lishunyang <lishunyang12@163.com>
…d_environment

Signed-off-by: lishunyang <lishunyang12@163.com>
Signed-off-by: lishunyang <lishunyang12@163.com>
…ript + standard loader

Signed-off-by: lishunyang <lishunyang12@163.com>
…e quant_config from transformer/config.json

Signed-off-by: lishunyang <lishunyang12@163.com>
… for BFL ckps

Signed-off-by: lishunyang <lishunyang12@163.com>
…-picklable)

Signed-off-by: lishunyang <lishunyang12@163.com>
…match

Signed-off-by: lishunyang <lishunyang12@163.com>
…() call

Signed-off-by: lishunyang <lishunyang12@163.com>
… swizzle (sgl #22064)

Signed-off-by: lishunyang <lishunyang12@163.com>
…er()

Signed-off-by: lishunyang <lishunyang12@163.com>
…vocations

Signed-off-by: lishunyang <lishunyang12@163.com>
Signed-off-by: lishunyang <lishunyang12@163.com>
…n to avoid HF cache pollution

Signed-off-by: lishunyang <lishunyang12@163.com>
…g #20137 + #22064

Signed-off-by: lishunyang <lishunyang12@163.com>
…_weights_after_loading is called

Signed-off-by: lishunyang <lishunyang12@163.com>
…SGLang

Signed-off-by: lishunyang <lishunyang12@163.com>
Signed-off-by: lishunyang <lishunyang12@163.com>
…alpha diagnostic

Signed-off-by: lishunyang <lishunyang12@163.com>
…p4_quantize + mm_fp4 with .T

Signed-off-by: lishunyang <lishunyang12@163.com>
Signed-off-by: lishunyang <lishunyang12@163.com>
Signed-off-by: lishunyang <lishunyang12@163.com>
- Add BFL→diffusers name remapping to Flux2Transformer2DModel.load_weights
  (was only wired into flux2_klein, not flux2_dev)
- Add scalar scale broadcast for NVFP4 fused projections
- Fix merge script to skip stale .safetensors.index.json

Signed-off-by: lishunyang <lishunyang12@163.com>
Signed-off-by: lishunyang <lishunyang12@163.com>
NVFP4 quantization was not being applied because quant_config was never
passed to the parallel linear layers. This caused Flux2Attention.to_out
and other layers to allocate BF16 params while the checkpoint contained
packed FP4 weights, producing shape mismatches at load time.

Mirrors the pattern already used in flux2_klein.
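The shape mismatch described above follows from simple arithmetic. The sketch below assumes the usual NVFP4 layout (two 4-bit values packed per uint8 along the input dimension, one block scale per 16 elements) and ignores any padding or swizzling a kernel may apply to the scales; `nvfp4_param_shapes` is a hypothetical helper.

```python
NVFP4_BLOCK_SIZE = 16  # elements sharing one block scale (assumed layout)
FP4_PER_BYTE = 2       # two 4-bit values packed into each uint8


def nvfp4_param_shapes(out_features: int, in_features: int) -> dict:
    """Compare the BF16 parameter shape with the packed NVFP4 shapes."""
    if in_features % NVFP4_BLOCK_SIZE != 0:
        raise ValueError("in_features must be a multiple of the block size")
    return {
        "bf16_weight": (out_features, in_features),
        "packed_weight": (out_features, in_features // FP4_PER_BYTE),
        "block_scale": (out_features, in_features // NVFP4_BLOCK_SIZE),
    }
```

A layer allocated as BF16 therefore expects twice as many columns as the packed checkpoint tensor provides, which is the kind of mismatch a loader reports when quant_config never reaches the layer.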

Signed-off-by: lishunyang <lishunyang12@163.com>
Signed-off-by: lishunyang <lishunyang12@163.com>
vLLM's ModelOpt get_quant_method uses the layer prefix to check against
exclude_modules wildcards. Without a correct prefix, is_layer_excluded
returns False for every layer and NVFP4 gets applied to modules the
checkpoint kept in BF16 (e.g. txt_attn on FLUX.2-dev-NVFP4).
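The effect of a missing prefix can be shown with plain wildcard matching. This is not vLLM's is_layer_excluded implementation; `is_excluded`, the pattern, and the layer paths below are hypothetical and only demonstrate why the full path matters.

```python
from fnmatch import fnmatch
from typing import Sequence


def is_excluded(prefix: str, exclude_patterns: Sequence[str]) -> bool:
    """Return True when the layer path matches any exclusion wildcard."""
    return any(fnmatch(prefix, pattern) for pattern in exclude_patterns)


patterns = ["*txt_attn*"]  # hypothetical exclude_modules entry

# Full path threaded through the constructors: the wildcard can match.
full_path_excluded = is_excluded("double_blocks.0.txt_attn.qkv", patterns)

# Empty or leaf-only prefix: nothing for the wildcard to match against.
empty_prefix_excluded = is_excluded("", patterns)
leaf_only_excluded = is_excluded("qkv", patterns)
```

Only the first call reports the layer as excluded, so a layer built without its full path would be quantized even though the checkpoint stores it in BF16.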

Signed-off-by: lishunyang <lishunyang12@163.com>
BFL's FLUX.2-dev-NVFP4 release is officially only compatible with
comfy-kitchen kernels (see the HF model card discussion and SGLang
PR #20137, which admits the generic ModelOpt FP4 path produces wrong
output for BFL's checkpoint layout).

Instead of fighting vLLM's quant_method dispatch, this follows the
approach laid out in Comfy-Org/comfy-kitchen/samples/nvfp4_linear.py:

 - New NVFP4Linear(nn.Module) wraps weights as comfy-kitchen
   QuantizedTensor and dispatches F.linear to scaled_mm_nvfp4.
 - Flux2Transformer2DModel gains a use_comfy_nvfp4 flag; when set
   (at TP=1), every fused projection (QKV, gate+up, out, qkv_mlp)
   is allocated as a single NVFP4Linear matching the BFL ckp layout.
 - Flux2Pipeline auto-detects quant_method="modelopt_fp4_flux" from
   the transformer config and passes the flag through.
 - load_weights finalizes the model by walking all NVFP4Linear
   instances and wrapping their raw uint8 weights as QuantizedTensors.

Requires: pip install comfy-kitchen[cublas] on Blackwell (SM >= 10.0).

Signed-off-by: lishunyang <lishunyang12@163.com>
FLUX.2-dev-NVFP4's exclude_modules leaves txt_attn in BF16. Allocate
those two layers as plain nn.Linear (not NVFP4Linear) when use_comfy_nvfp4
is set, matching the checkpoint layout.

Signed-off-by: lishunyang <lishunyang12@163.com>
Collaborator Author

@lishunyang12 lishunyang12 left a comment

Review: [Quantization] Support NVFP4 for Flux.2

Thanks for the detailed PR description. The architecture is well-documented, the test coverage for config/mapping is solid, and the BFL weight-name remapping extraction into bfl_mapping.py is a clean refactor. Here are my findings:


Blocking / Must-Fix

1. Two parallel NVFP4 execution paths with unclear routing

The PR introduces two independent NVFP4 linear implementations:

  • comfy_nvfp4_linear.py (NVFP4Linear) — used when use_comfy_nvfp4=True, wired into the model constructor
  • flux_nvfp4.py (FluxNvFp4LinearMethod) — registered as the LinearMethodCls on the ModelOptNvFp4Config returned by _build_modelopt_nvfp4_flux

When use_comfy_nvfp4=True, all quantized layers become NVFP4Linear (bypassing vLLM's quant dispatch entirely), so FluxNvFp4LinearMethod is never actually called for those layers. Meanwhile, the quant_config is still passed to the constructor and forwarded to the non-NVFP4 parallel linears (add_kv_proj, etc.), where is_layer_excluded returns True so they stay BF16 — that part works. But it means flux_nvfp4.py (273 lines of flashinfer-based GEMM logic) is dead code in the current comfy-kitchen path. Either:

  • Remove flux_nvfp4.py if comfy-kitchen is the only intended backend, or
  • Document clearly which path is active and under what conditions, and add a test that exercises FluxNvFp4LinearMethod.apply so it doesn't bitrot.

2. init_timeout=1200 hardcoded unconditionally in the example script

In text_to_image.py, the timeout is bumped for all models, not just NVFP4. A 21 GB single-file load is NVFP4-specific; other models (e.g., FLUX.2-dev BF16 or GGUF) don't need 20 minutes. This should be conditional on detecting NVFP4, or better yet, passed via a CLI flag (e.g., --init-timeout).
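A minimal sketch of the suggested flag, assuming the 600 s default and the 1200 s NVFP4 value mentioned in the PR description. `resolve_init_timeout` and the `is_nvfp4` argument are hypothetical; how the script would detect NVFP4 is left to the caller.

```python
import argparse
from typing import Optional, Sequence

DEFAULT_INIT_TIMEOUT_S = 600   # existing default cited in the PR
NVFP4_INIT_TIMEOUT_S = 1200    # value the PR uses for the 21 GB load


def resolve_init_timeout(argv: Optional[Sequence[str]] = None, is_nvfp4: bool = False) -> int:
    """Pick the init timeout: explicit flag first, then a per-format default."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--init-timeout",
        type=int,
        default=None,
        help="Engine init timeout in seconds (defaults depend on the format).",
    )
    # parse_known_args lets the real script keep its other flags.
    args, _ = parser.parse_known_args(argv)
    if args.init_timeout is not None:
        return args.init_timeout
    return NVFP4_INIT_TIMEOUT_S if is_nvfp4 else DEFAULT_INIT_TIMEOUT_S
```

This keeps the shorter default for BF16 and GGUF runs, so a genuine startup hang still surfaces after 10 minutes instead of 20.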

3. _HANDSHAKE_POLL_TIMEOUT_S doubled globally

Same concern: stage_diffusion_proc.py timeout goes from 600s to 1200s for all models. This masks genuine startup failures for non-NVFP4 models. Consider making this configurable or deriving it from the quant method.


Non-Blocking / Suggestions

4. Significant code duplication in load_weights between flux2_transformer.py and flux2_klein_transformer.py

The BFL detection, scalar-scale broadcast, and error-wrapping logic is copy-pasted across both files. Consider extracting a shared helper (e.g., in bfl_mapping.py or a new weight_loading_utils.py) that both transformers call. The _peek_bfl_format / _apply_bfl_mapping staticmethod aliasing pattern already exists — extend it to cover the load loop body.

5. import logging placement in flux2_transformer.py

The import logging appears after all other imports including the local from vllm_omni.diffusion.utils.bfl_mapping import ... block. Standard style (and PEP 8) places stdlib imports before third-party and local imports. Move it to the stdlib import block.

6. Missing TYPE_CHECKING guard formatting

import torch

if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from diffusers.models.embeddings import (

The if TYPE_CHECKING block runs into the next import without a blank line, making it look like from diffusers... is inside the guard. Add a blank line after the if TYPE_CHECKING block.

7. FluxNvFp4LinearMethod.process_weights_after_loading calls .cuda() directly

padded_scales = padded_scales.contiguous().cuda()

This bypasses vLLM's device management. If the model is loaded on a specific GPU via CUDA_VISIBLE_DEVICES or distributed placement, .cuda() may target GPU 0 instead of the correct device. Use padded_scales.to(layer.weight.device) or the platform's device accessor.

8. _read_hf_header uses raw requests.get without retry

HuggingFace CDN can be flaky. Consider using huggingface_hub's built-in download utilities or adding a simple retry. Also, the timeout=30 may be too short for large headers on slow connections.
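If a hand-rolled retry is preferred over the hub utilities, a small wrapper with exponential backoff is enough. The sketch below is generic: `with_retries` is a hypothetical helper, and retrying on `OSError` is an assumption that happens to cover `requests` failures, since `requests.exceptions.RequestException` derives from `IOError`.

```python
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def with_retries(
    fn: Callable[[], T],
    attempts: int = 3,
    base_delay_s: float = 0.5,
    retry_on: Tuple[Type[BaseException], ...] = (OSError,),
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call `fn`, retrying on `retry_on` with exponential backoff."""
    for attempt in range(attempts):
        try:
            return fn()
        except retry_on:
            if attempt == attempts - 1:
                raise  # out of attempts: surface the last error
            sleep(base_delay_s * (2 ** attempt))
    raise ValueError("attempts must be at least 1")
```

The header fetch would then be wrapped as `with_retries(lambda: fetch_header(url))`, where `fetch_header` stands for whatever function performs the request.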

9. PR status: output is known-broken

The description states the output is "a black image with a horizontal-line artifact" and lists several suspected root causes (input_scale shape, AdaLN swap, RMSNorm rename). Merging this will leave a non-functional code path in main. Consider gating the entire NVFP4 path behind a clear "experimental" warning log at startup, or marking the smoke test with @pytest.mark.xfail until numerical correctness is confirmed.

10. _TupleLinear is a clever workaround but fragile

Subclassing nn.Linear just to return (output, None) couples the implementation to the current call-site convention. If any caller changes to expect a plain tensor (or a different tuple structure), this will silently produce wrong results. A comment at each call site noting the expected return type would help future maintainers.


Minor Nits

  • tools/prepare_flux2_nvfp4.py modifies sys.path twice (lines ~697 and ~800). Consider doing it once at the top.
  • resolve_nvfp4_checkpoint_file returns Optional[str] but callers don't always guard for None (e.g., the default in --nvfp4-file processing does check, but it's worth a type annotation on the return).
  • The test file tests/diffusion/quantization/test_nvfp4_config.py imports Flux2Transformer2DModel from flux2_klein (line 163 in the diff). This seems intentional (shared interface), but the import path is confusing — a comment explaining why would help.

Overall: solid groundwork for NVFP4 support. The main concern is shipping two parallel NVFP4 backends with one being dead code, and the unconditional timeout increases. Once those are addressed and numerical correctness is verified, this is ready to merge.
