Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) #30141
Documentation preview: https://vllm--30141.org.readthedocs.build/en/30141/
Hi @eldarkurtic, the pre-commit checks have failed. Please run:

```
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch. For future commits,
Code Review
This pull request introduces support for per-attention-head FP8 KV cache quantization, primarily for use with Flash Attention and llm-compressor. The changes:

- modify CUDA kernels to handle both per-tensor and per-attention-head scaling factors;
- add a new kernel for per-channel static FP8 quantization;
- update the `reshape_and_cache_flash_kernel` to use a `kv_scale_stride` for flexible scale access;
- update the Python `_custom_ops` to support per-channel scales;
- refactor the attention layer initialization logic to correctly set KV cache quantization attributes and query quantization group shapes based on the loaded llm-compressor configuration;
- extend model weight loading utilities to remap `q_scale` and `zero_point` parameters;
- add a new `CompressedTensorsConfig` method to process llm-compressor loaded scales, including reducing `q_scale` for Flash Attention and repeating it for `QuantFP8` operations.

Comprehensive unit tests were added to cover per-attention-head scaling in `reshape_and_cache_flash` and per-channel static FP8 quantization. The documentation for FP8 KV cache has been significantly updated to detail per-attention-head quantization and various scale calibration approaches, and provides an example using llm-compressor. A review comment highlighted brittle logic in `CompressedTensorsConfig.from_config` for filtering attention quantization config groups, suggesting a more robust check for empty target lists and iteration through all targets.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
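To make the `kv_scale_stride` idea from the review above concrete, here is a toy NumPy sketch (illustrative shapes, not the actual CUDA kernel): with stride 0 every head indexes the same scalar scale (per-tensor), while stride 1 gives each attention head its own entry (per-attention-head).

```python
import numpy as np

def scale_kv(kv, scale, kv_scale_stride):
    """Toy model of the kernel's scale lookup.

    kv: (num_tokens, num_heads, head_dim) array.
    scale: length-1 array (per-tensor) or length-num_heads array (per-head).
    kv_scale_stride: 0 collapses every head's index to scale[0];
                     1 lets head h read scale[h].
    """
    num_tokens, num_heads, head_dim = kv.shape
    out = np.empty_like(kv)
    for h in range(num_heads):
        s = scale[h * kv_scale_stride]  # stride 0 -> always scale[0]
        out[:, h, :] = kv[:, h, :] / s
    return out

kv = np.ones((2, 4, 8), dtype=np.float32)
# Per-tensor: one scale shared by all heads.
per_tensor = scale_kv(kv, np.array([2.0]), kv_scale_stride=0)
# Per-attention-head: one scale per head.
per_head = scale_kv(kv, np.array([1.0, 2.0, 4.0, 8.0]), kv_scale_stride=1)
```

The same kernel body thus serves both layouts; only the stride passed in changes.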
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 09ff5ce to 0dbf3be
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: eldarkurtic <8884008+eldarkurtic@users.noreply.github.com>
Force-pushed from cabdcc9 to cf5c848
Adapt the update in vllm-project/vllm#30141

```python
# llm-compressor models need to set cache_dtype to "fp8" manually.
if getattr(quant_config, "kv_cache_scheme", None) is not None:
    kv_cache_dtype = "fp8"
    calculate_kv_scales = False
    if cache_config is not None:
        cache_config.cache_dtype = "fp8"
        cache_config.calculate_kv_scales = False

self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
    kv_cache_dtype, vllm_config.model_config
)
self.kv_cache_dtype = kv_cache_dtype
```

cc @hshen14 @thuang6 @lkk12014402

Signed-off-by: yiliu30 <yi4.liu@intel.com>
TLDR: this PR adds support to load and run `llm-compressor` models with FP8 KV-cache and attention quantization. In addition to the standard "per-tensor" quantization, it adds support for "per-attention-head" quantization.

Summary
- `llm-compressor`
    - 2.1 for query quantization:
        - extended `QuantFP8` to support per-channel static quantization (queries are of shape `num_tokens x hidden_size`, so we expand the per-attention-head scales to accommodate per-channel scaling)
        - extended the `static_scaled_fp8_quant` kernel to work with an array of scales
    - 2.2 for kv-quantization:
        - extended the `reshape_and_cache_flash` kernel by adding support for an array of `k/v_scale`
        - refactored `vllm/attention/layer.py`, as so far it had been hardcoded for the `calculate_kv_scales` pathway only

Tests
To confirm that the existing pathway with `calculate_kv_scales=True/False` isn't affected, I ran GSM8k evals on the following models: `Llama-2-7b-chat-hf` (MHA), `Llama-3.1-8B-Instruct` (GQA), and `Qwen/Qwen3-8B` (a different model family), and got the same results before and after the PR. They are as follows:

meta-llama/Llama-2-7b-chat-hf

Note: The score of 0 with `calculate_kv_scales=True` is already present on `main`; this PR does not introduce it.

meta-llama/Llama-3.1-8B-Instruct

Qwen/Qwen3-8B

And to confirm that the new support for `llm-compressor` models with both `per-tensor` and `per-attention-head` scales is working correctly, I ran the same models from above with both configurations and observed the expected results:

meta-llama/Llama-2-7b-chat-hf

meta-llama/Llama-3.1-8B-Instruct

Qwen/Qwen3-8B

The `llm-compressor` code to produce these models is available in the diffs of `docs/features/quantization/quantized_kvcache.md`. I haven't done any tuning of the calibration parameters for testing purposes, just ran the defaults, so better results are expected with tuning. I've also verified that the changes work with both the LLM class and `vllm serve`.
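As a side note on the query path: since queries arrive flattened as `num_tokens x hidden_size`, a per-attention-head scale has to be broadcast across each head's contiguous slice of channels before a per-channel kernel can consume it. A rough NumPy sketch of that expansion (shapes are illustrative assumptions, not the vLLM implementation):

```python
import numpy as np

num_heads, head_dim = 4, 8  # illustrative sizes
per_head_scales = np.array([1.0, 2.0, 4.0, 8.0], dtype=np.float32)

# Head h owns channels [h * head_dim, (h + 1) * head_dim) of the flattened
# hidden dimension, so each head's scale is repeated head_dim times to form
# a per-channel scale vector of length hidden_size.
per_channel_scales = np.repeat(per_head_scales, head_dim)
```

The resulting vector has one entry per channel, which is the layout a per-channel static FP8 quantization kernel expects.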
Note: some model implementations have remapping of scales guarded by something like `if "scale" in name`, so I had to expand it to `if "scale" in name or "zero_point" in name:` to support loading of zero_points, which are present in `llm-compressor` checkpoints.
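The guard change from the note can be sketched like this (the parameter names below are hypothetical examples, not taken from a real checkpoint):

```python
def needs_remapping(param_name: str) -> bool:
    # Previously only "scale" tensors were caught; llm-compressor checkpoints
    # also ship "zero_point" tensors that need the same remapping treatment.
    return "scale" in param_name or "zero_point" in param_name

names = [
    "model.layers.0.self_attn.k_scale",       # hypothetical parameter names
    "model.layers.0.self_attn.k_zero_point",
    "model.layers.0.self_attn.q_proj.weight",
]
remapped = [n for n in names if needs_remapping(n)]
```

With only the original `"scale" in name` check, the `zero_point` tensor would fall through and fail to load.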