support online FP8 quantization for FA on NPU #2236#2640
lishunyang12
left a comment
Early Review -- WIP Online FP8 Quantization for FA on NPU
Thanks for the PR. This is a useful feature -- online FP8 quantization for flash attention on NPU with selective step/layer skipping. Below is an early review with the items I think should be addressed before this leaves WIP status.
Architecture & Design
- _update_attn_metadata mutates its input in-place but also returns it. The name "update" suggests mutation, but the function also creates a new AttentionMetadata when base is None. This dual behavior is error-prone. Consider either (a) always returning a new copy (safe, no aliasing surprises), or (b) always mutating in-place and requiring a non-None input. The current mixed mode means callers who pass a shared base_attn_metadata across self-attn and cross-attn in the same block will see cross-contamination of attn_kind/attn_mask because the same object is mutated twice. In wan2_2_transformer.py, WanTransformerBlock.forward() calls _update_attn_metadata(base_attn_metadata, ...) for both self-attn and cross-attn. If the first call mutates the shared base_attn_metadata (setting attn_kind="self"), the second call receives that already-mutated object and overwrites attn_kind to "cross". The first attention op has already consumed it, so it works today, but this is fragile. A shallow copy before mutation would be safer: metadata = copy.copy(base) if base is not None else AttentionMetadata()
- _update_attn_metadata is a module-level private function exported in the __all__-equivalent import in wan2_2_transformer.py. Consider making it a proper method on AttentionMetadata (e.g., AttentionMetadata.with_updates(...)) to improve discoverability and encapsulation.
- Lazy resolution pattern in Attention.forward() calls _resolve_kv_cache_dtype() and _resolve_kv_cache_skip_selectors_from_config() on every forward call. After the first resolution, the _resolved flags short-circuit, but the function-call overhead remains. Consider resolving once in a dedicated setup hook (e.g., first forward only, or a post_init method) rather than checking flags in the hot path.
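The copy-on-update option (a) above could look roughly like the sketch below. The AttentionMetadata class here is a hypothetical stand-in that carries only field names mentioned in this review, and update_attn_metadata is an illustrative replacement, not the PR's actual function.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional


@dataclass
class AttentionMetadata:
    # Hypothetical stand-in: the real class in the PR likely has more fields.
    attn_kind: Optional[str] = None
    attn_mask: Any = None
    kv_cache_dtype: Optional[str] = None

    def with_updates(self, **changes: Any) -> "AttentionMetadata":
        # dataclasses.replace builds a new instance; self is left untouched.
        return replace(self, **changes)


def update_attn_metadata(
    base: Optional[AttentionMetadata], **changes: Any
) -> AttentionMetadata:
    # Always return a fresh object, whether or not a base was supplied.
    source = base if base is not None else AttentionMetadata()
    return source.with_updates(**changes)


# One shared base is reused for self-attn and cross-attn without aliasing.
base = AttentionMetadata(kv_cache_dtype="fp8")
self_md = update_attn_metadata(base, attn_kind="self")
cross_md = update_attn_metadata(base, attn_kind="cross")
```

Because every call returns a new object, the shared base keeps attn_kind=None and the two attention calls cannot overwrite each other's fields. The copy is shallow, so a tensor held in attn_mask is shared by reference rather than duplicated.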
Correctness Concerns
- forward_fa_quant_npu calls is_quantized_kv_cache(kv_cache_dtype) after forward_npu already checked kv_cache_dtype is not None. If kv_cache_dtype is a non-None string that is not in _FP8_KV_LABELS (e.g., "int8"), the function logs a warning and falls back to forward_fa_npu. However, _handle_kv_cache_dtype in the base class AttentionImpl already cleared unsupported dtypes to None with its own warning. This means the forward_fa_quant_npu warning path is dead code for any dtype that isn't in _supported_kv_cache_dtypes. Consider removing the redundant check or making the intent clearer.
- Global mutable state in kv_quant_npu.py -- _ROT_MATRIX and _IS_NOT_IMPORTED are module-level globals guarded by a boolean flag but not thread-safe. In async serving with multiple workers, two threads could race on _IS_NOT_IMPORTED and _ROT_MATRIX. Use threading.Lock or functools.lru_cache for the import guard, and be aware that _ROT_MATRIX.to(device) reassigns the global without synchronization.
- _ROT_MATRIX device migration is lossy. if _ROT_MATRIX.device != device: _ROT_MATRIX = _ROT_MATRIX.to(device) replaces the global with the new-device copy. In multi-device scenarios (e.g., tensor parallel across NPU cards), the second device overwrites the global and breaks the first. Consider using a dict[torch.device, torch.Tensor] cache.
- Magic numbers in npu_fused_infer_attention_score_v2 call. pre_tokens=2147483647, next_tokens=2147483647, query_quant_mode=7, etc. are opaque. Please add brief inline comments explaining what these values mean (e.g., "INT32_MAX = no causal masking", "7 = per-block FP8 quantization mode").
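A lock-guarded per-device cache, as suggested for _ROT_MATRIX above, might be sketched as follows. This is a generic illustration using only the standard library: PerDeviceCache and build_rotation_matrix are hypothetical names, device keys are plain strings, and the cached value is a placeholder string. In the real module the key would presumably be a torch.device and the builder would move the rotation matrix to that device.

```python
import threading
from typing import Any, Callable, Dict, Hashable


class PerDeviceCache:
    """Build one value per device key, at most once, under a lock."""

    def __init__(self, build: Callable[[Hashable], Any]) -> None:
        self._build = build
        self._lock = threading.Lock()
        self._entries: Dict[Hashable, Any] = {}

    def get(self, device: Hashable) -> Any:
        # The lock makes check-then-build atomic, so two threads asking for
        # the same device cannot both build, and one device's entry never
        # replaces another's.
        with self._lock:
            if device not in self._entries:
                self._entries[device] = self._build(device)
            return self._entries[device]


build_calls = []


def build_rotation_matrix(device: Hashable) -> str:
    # Placeholder builder: records the call and returns a dummy value.
    build_calls.append(device)
    return f"rot_matrix@{device}"


rot_cache = PerDeviceCache(build_rotation_matrix)

# Eight threads request entries for two devices concurrently.
threads = [
    threading.Thread(target=rot_cache.get, args=(f"npu:{i % 2}",))
    for i in range(8)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Holding the lock while building serializes first-time construction across devices. That is usually acceptable for a one-off setup cost, and a per-key lock could be used if it is not.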
Plumbing / Config
- serve.py adds --kv-cache-skip-steps and --kv-cache-skip-layers but does NOT add --kv-cache-dtype for the serve entrypoint. The offline example adds all three, but the online serve path only adds the skip selectors. This means users who launch via vllm_omni serve cannot enable FP8 KV quantization without a YAML config. Is this intentional? If so, document it; if not, add the missing CLI arg.
- _resolve_stage_configs in async_omni_engine.py uses hasattr checks (not hasattr(cfg.engine_args, "kv_cache_dtype")). Since OmniDiffusionConfig is a dataclass with defaults, hasattr will almost always be True. The guard is likely intended to be cfg.engine_args.kv_cache_dtype is None, which is already the second condition. The hasattr check is misleading dead code -- consider removing it.
- The timestep_scalar variable computed in pipeline_wan2_2_i2v.py line ~508 is unused. It is assigned but never referenced. Remove it or use it.
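One way to keep the serve and offline entrypoints aligned is to register all three options through a single shared helper. The sketch below uses plain argparse; add_kv_cache_args is a hypothetical helper, and the parser is a stand-in rather than the project's actual CLI wiring. Only the flag names come from this review.

```python
import argparse


def add_kv_cache_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Hypothetical shared helper: each entrypoint calls it once, so the
    # three options cannot drift apart between serve and offline paths.
    parser.add_argument("--kv-cache-dtype", default=None,
                        help="KV-cache dtype, e.g. fp8 (unset keeps native FA)")
    parser.add_argument("--kv-cache-skip-steps", default=None,
                        help="steps that fall back to native FA")
    parser.add_argument("--kv-cache-skip-layers", default=None,
                        help="layers that fall back to native FA")
    return parser


# Stand-in for the serve entrypoint's parser.
serve_parser = add_kv_cache_args(argparse.ArgumentParser(prog="serve"))
args = serve_parser.parse_args(
    ["--kv-cache-dtype", "fp8", "--kv-cache-skip-steps", "0"]
)
```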
Style / Cleanup
- Double blank line after forward_fa_quant_npu method and before forward_fa_npu in flash_attn.py -- one method has standard indentation, the other (forward_fa_npu) uses 8-space indentation for its parameter list instead of the 4-space style used everywhere else in this file. Please normalize.
- _parse_selector_indices is a @staticmethod on Attention but has no dependency on the class. It would be cleaner as a module-level utility or on the config dataclass where it is semantically relevant.
- PR description is empty -- Purpose, Test Plan, and Test Result sections are blank. Even for WIP, please add a brief description of the approach and any benchmark numbers you have so far (e.g., memory savings, latency impact on NPU).
Summary
The overall approach is sound: threading KV-cache quantization dtype through attention metadata and gating it per step/layer is a clean design. The main risks are the in-place mutation aliasing in _update_attn_metadata, the thread-unsafe globals in kv_quant_npu.py, and the missing --kv-cache-dtype CLI arg for the serve path. Looking forward to the non-WIP version.
@lishunyang12 Thank you for your careful review. This is very helpful to us. After analyzing the benefits, I will revise and submit the formal PR accordingly.
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
@lishunyang12 @gcanlin @hsliuustc0106
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Gaohan123
left a comment
Please add a UT for it. Thanks.
I have added a UT; please review it again.
vllm 0.20.1.dev0+g88d34c640.d20260511.empty FA-FP8: UT:
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
david6666666
left a comment
Some remaining issues.
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
LGTM now. @gcanlin @lishunyang12 ptal thx
Otherwise, please add doc such as https://docs.vllm.ai/en/latest/features/quantization/quantized_kvcache/ and
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com> Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@lyj-jjj Could you please test this PR again? I did some refactoring, so it would be good to take another look.
…-project#2640) Signed-off-by: lyj-jjj <liuyingjun5@huawei.com> Signed-off-by: gcanlin <canlinguosdu@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: gcanlin <canlinguosdu@gmail.com> Signed-off-by: Jialong Liu <88185941+Galleons2029@users.noreply.github.com>
1. Background and Goal
In generative models, FA accounts for more than 50% of the time when generating 480p videos and more than 70% when generating 720p videos. Online quantization of FA can therefore significantly reduce DiT time, which is essential for supporting video generation models.
This PR builds on PR1413 and extends it with online FP8 quantization of FA on NPU. The goal is to introduce online FP8-quantized FA on NPU, with optional step/layer-level fallback, so that DiT latency can be significantly reduced while preserving generation quality as much as possible.
2. Scope
- kv_cache_dtype=fp8
- kv_cache_skip_steps and kv_cache_skip_layers selectors for fine-grained fallback
3. Core Design
3.1 Config and Parameter Plumbing
Introduce and propagate the following parameters from config/entrypoints down to the attention execution layer:
- kv_cache_dtype (e.g., fp8)
- kv_cache_skip_steps (string index set)
- kv_cache_skip_layers (string index set)
These parameters are carried through the diffusion attention metadata system. At each layer forward, the current step and layer determine whether the FP8 path is enabled.
3.2 Backend Capability Declaration and Safe Fallback
FlashAttentionBackend uses supports_kv_cache_dtype() to declare supported kv-cache dtypes per platform.
On NPU, when kv_cache_dtype in metadata is valid, execution enters the quantized FA path; otherwise it falls back to native FA.
This design guarantees:
3.3 NPU Execution Path Routing
forward_npu() routes by attention type and quantization switch:
- kv_cache_dtype: use forward_fa_quant_npu()
- forward_fa_npu()
In the quantized path, fp8_rotate_quant_fa() calls the NPU quantized fused operator. Necessary tensor transposes are applied according to layout requirements so that existing model tensor formats remain compatible.
3.4 Step/Layer Selective Fallback
In the attention layer, kv_cache_skip_steps and kv_cache_skip_layers are parsed into index sets.
At runtime, if either skip condition is matched, FP8 is disabled for the current layer/step and execution falls back to native dtype FA.
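The parse-then-gate logic described above can be sketched as follows. The selector string format is an assumption made for illustration (comma-separated indices with optional inclusive ranges such as "0-1"), and both function names are hypothetical rather than taken from the PR.

```python
from typing import FrozenSet, Optional


def parse_selector_indices(spec: Optional[str]) -> FrozenSet[int]:
    # Assumed format: "0,3" or "0-1, 39" (ranges are inclusive).
    if not spec:
        return frozenset()
    indices = set()
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            start, end = (int(x) for x in part.split("-", 1))
            indices.update(range(start, end + 1))
        else:
            indices.add(int(part))
    return frozenset(indices)


def use_fp8(
    kv_cache_dtype: Optional[str],
    step: int,
    layer: int,
    skip_steps: FrozenSet[int],
    skip_layers: FrozenSet[int],
) -> bool:
    # FP8 runs only when it is configured and neither selector matches.
    if kv_cache_dtype != "fp8":
        return False
    return step not in skip_steps and layer not in skip_layers


skip_steps = parse_selector_indices("0,3")
skip_layers = parse_selector_indices("0-1, 39")

fp8_mid_block = use_fp8("fp8", 1, 5, skip_steps, skip_layers)   # no selector matches
fp8_first_step = use_fp8("fp8", 0, 5, skip_steps, skip_layers)  # step 0 is skipped
```

Parsing once into frozensets keeps the per-layer check to two set lookups, which matters because the check runs on every attention call.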
Value:
4. Compatibility and Risk Control
When kv_cache_dtype is not configured, runtime behavior is identical to previous versions.
5. Performance and Effect (Current Test)
In the Wan2.2 I2V scenario (1280x720, 61 frames, 4 steps):
① fa-bf16: DiT 14.13s (FA 6.81s, 48.1%)
② fa-fp8: DiT 12.75s (FA 5.43s, 42.5%)
Summary: DiT time is reduced by 1.38s (10.8% improvement); FA time is reduced by 1.38s (25.4% improvement).
①fa-bf16: https://github.com/user-attachments/assets/46daba81-f8c7-4cce-aac9-f09a6f97ed0e
②fa-fp8: https://github.com/user-attachments/assets/64a86c10-818e-4c3a-a920-0397723a0a88
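The summary percentages above match a speedup computed relative to the FP8 time, rather than a reduction relative to the BF16 time. A short check using only the rounded timings reported above:

```python
# Timings in seconds, copied from the results above.
dit_bf16, fa_bf16 = 14.13, 6.81
dit_fp8, fa_fp8 = 12.75, 5.43


def speedup_pct(before: float, after: float) -> float:
    # Time saved, relative to the new (faster) time.
    return round((before - after) / after * 100, 1)


def reduction_pct(before: float, after: float) -> float:
    # Time saved, relative to the original time.
    return round((before - after) / before * 100, 1)


dit_speedup = speedup_pct(dit_bf16, dit_fp8)      # 10.8, as in the summary
fa_speedup = speedup_pct(fa_bf16, fa_fp8)         # 25.4, as in the summary
dit_reduction = reduction_pct(dit_bf16, dit_fp8)  # 9.8
fa_reduction = reduction_pct(fa_bf16, fa_fp8)     # 20.3
```

Read as a reduction in wall-clock time, the same measurements correspond to about 9.8% for DiT and 20.3% for FA.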
6. Test Plan
- kv_cache_dtype=fp8
- kv_cache_skip_steps/layers fallback works as expected when selectors match
7. Usage Example