Fix noisy warning for uncalibrated q_scale/p_scale #17414
tlrmchlsmth merged 1 commit into vllm-project:main
Conversation
Signed-off-by: mgoin <mgoin64@gmail.com>
- f"Using Q scale {q_scale} and prob scale {prob_scale} "
- "with fp8 attention. This may cause accuracy issues. "
- "Please make sure Q/prob scaling factors are "
+ f"Using uncalibrated q_scale {q_scale} and/or prob_scale "
Why "uncalibrated"? I only see the warning once when I run it, so it doesn't seem very noisy. Depending on which lands first, #17331 could also check the VLLM_ROCM_USE_FP8_SCALES flag, since the warning won't be necessary if VLLM_ROCM_USE_FP8_SCALES=0.
I even see this when running INT4 models; it is triggered for most quantization methods:
INFO 05-07 03:01:24 [gpu_model_runner.py:1360] Starting to load model RedHatAI/Qwen3-30B-A3B-quantized.w4a16...
INFO 05-07 03:01:34 [loader.py:459] Loading weights took 9.47 seconds
WARNING 05-07 03:01:34 [kv_cache.py:128] Using Q scale 1.0 and prob scale 1.0 with fp8 attention. This may cause accuracy issues. Please make sure Q/prob scaling factors are available in the fp8 checkpoint.
The scales are uncalibrated because they are 1.0.
tlrmchlsmth left a comment
Thanks, this has been annoying me too
When loading any quantized model that supports a quantized KV cache, this warning appears even when the quantized KV cache isn't enabled (introduced by #15734).
This change moves the warning so it only fires when the KV cache is actually quantized. We could further restrict it to ROCm platforms, since that is the only consumer of this at the moment.
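The guard described above can be sketched roughly as follows. This is a minimal illustration, not the actual vLLM code: the function name, the `kv_cache_dtype` string check, and the parameter names are assumptions for the example.

```python
# Hypothetical sketch: only warn about default (uncalibrated) scales when
# the KV cache is actually FP8-quantized. Names are illustrative, not the
# real vLLM internals.
import logging

logger = logging.getLogger(__name__)


def maybe_warn_uncalibrated_scales(kv_cache_dtype: str,
                                   q_scale: float,
                                   prob_scale: float) -> bool:
    """Return True if a warning was emitted.

    Scales of exactly 1.0 are the defaults, i.e. no calibrated scaling
    factors were found in the checkpoint.
    """
    # Skip entirely unless the KV cache is quantized to fp8 (e.g. "fp8_e4m3").
    if not kv_cache_dtype.startswith("fp8"):
        return False
    if q_scale == 1.0 or prob_scale == 1.0:
        logger.warning(
            "Using uncalibrated q_scale %s and/or prob_scale %s with fp8 "
            "attention. This may cause accuracy issues.",
            q_scale, prob_scale)
        return True
    return False
```

With a guard like this, loading an INT4 (or any non-fp8-KV-cache) model never reaches the scale check, which is the behavior the PR is after.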