[Bugfix][Perf] Indexer upcast WK to BF16 for fusion #38928
mgoin merged 5 commits into vllm-project:main
Conversation
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
I would certainly want to run the OG ckpt as it is with multi-stream.
Code Review
This pull request introduces a mechanism to handle DeepSeek-V2 checkpoints where 'wk' weights are stored in FP8 with separate scales. By dequantizing these weights to BF16 during the loading process, the model maintains fusion with 'weights_proj'. The implementation includes a new helper function, '_try_load_fp8_indexer_wk', and a buffering system to synchronize weights and scales. Feedback was provided regarding the potential for memory leaks in the buffering logic if the loading process is interrupted or if keys are missing from the checkpoint.
| if "weight" not in entry or "scale" not in entry: | ||
| return True # still waiting for the other param |
They would go out of scope once the weight loading function returns. The references are not stored on self.
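For illustration, here is a minimal sketch of the loading scheme described above (the helper name `_try_load_fp8_indexer_wk`, the weight/scale buffering, and the early return in the quoted snippet come from this thread; the surrounding class, the method signature, and the per-tensor scale handling are assumptions, not the PR's actual implementation):

```python
import torch


class IndexerWkLoader:
    """Sketch: stage FP8 'wk' weight/scale pairs and upcast them to BF16."""

    def __init__(self) -> None:
        # Staging buffers keyed by checkpoint parameter prefix. They are only
        # referenced here, so they go out of scope when weight loading
        # returns; nothing is stored on the model itself.
        self._pending: dict[str, dict[str, torch.Tensor]] = {}

    def _try_load_fp8_indexer_wk(
        self,
        prefix: str,
        name: str,
        tensor: torch.Tensor,
        wk_param: torch.Tensor,
    ) -> bool:
        """Return True if the tensor was consumed by the FP8-wk path."""
        if name not in ("weight", "scale"):
            return False
        entry = self._pending.setdefault(prefix, {})
        entry[name] = tensor
        if "weight" not in entry or "scale" not in entry:
            return True  # still waiting for the other param
        weight_fp8 = entry.pop("weight")
        scale = entry.pop("scale")
        del self._pending[prefix]
        # Dequantize once both pieces have arrived: upcast the FP8 values and
        # apply the stored scale so the resulting BF16 weight can stay fused
        # with weights_proj instead of taking the FP8 multi-stream path.
        wk_param.copy_((weight_fp8.to(torch.float32) * scale).to(torch.bfloat16))
        return True
```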
This is the multi-stream baseline I compared against. It seems like multi-streaming with the FP8 kernels is disabling torch.compile and undoing one of the elementwise op fusions, causing a slowdown. Could you explain what you want to run that I have not compared here, and why?
I am hoping that for the original model ckpt we let it run at its original precision, at least on one platform. Changing the precision could cause unknown problems later on that are hard to trace back.
This pull request has merge conflicts that must be resolved before it can be merged.
I'm of the opinion that upcasting for the fusion is worth the trouble. The perf improvement seems nontrivial, and it does help avoid having to support both multi-stream and fusion for this op. @robertgshaw2-redhat to weigh in.
If they really care about perf they should run the NVFP4 checkpoint. I feel like we want to have an accuracy guarantee before perf on the OG checkpoint.
There are many cases where we still care about perf for non-NVFP4 checkpoints. And as I understand it, upcasting from FP8 to BF16 shouldn't hurt accuracy if done properly.
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Reran on 4xGB200, same perf results and GSM8k is good.

Details:
- 4xGB200, TP4, DSV3.2 FP8, Fused WK: GSM8k; 1k/1k BS1; 1k/1k BS8
- 4xGB200, TP4, DSV3.2 FP8, No-Fused-WK: GSM8k; 1k/1k BS1; 1k/1k BS8
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
I also manually inspected a few of the FP8 weight scales for the WK layers. They're all fairly small factors (e.g. 0.0002), so we should be staying comfortably in the range where BF16 is well-represented.
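As a rough sanity check of that headroom argument (the ~0.0002 scale is the figure quoted above; the rest is a generic calculation, not code from this PR):

```python
import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max   # ~448, largest e4m3 magnitude
bf16_max = torch.finfo(torch.bfloat16).max       # ~3.39e38
bf16_tiny = torch.finfo(torch.bfloat16).tiny     # ~1.18e-38, smallest normal

scale = 2e-4  # order of magnitude of the inspected wk scales

# The largest dequantized weight magnitude is tiny compared to BF16's range,
# so the upcast cannot overflow or underflow.
max_dequant = fp8_max * scale                    # ~0.09
assert bf16_tiny < max_dequant < bf16_max
print(f"max |dequantized wk| ~ {max_dequant:.4f}")
```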
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Avinash Singh <avinashsingh.rcoem@gmail.com>
**Commit range:** `6f786f2`..`d886c26`

1. Fix 'DPMetadata' object has no attribute 'max_tokens_across_dp_cpu' by vllm-project/vllm#39107
2. Fix 'Indexer' object has no attribute 'wk' by vllm-project/vllm#38928
3. Fix 'float' object has no attribute 'language_model' by vllm-project/vllm#39240

- vLLM version: v0.19.0
- vLLM main: vllm-project/vllm@6f786f2

Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Signed-off-by: Meihan-chen <zr010426ztt@outlook.com>
Purpose
Alternative fix to #38870 which maintains the fusion.
Performance:
Testing
My setup is broken, GSM8k is giving 0.00 for me even with #38870. Will try to fix my setup and rerun, but have moderate confidence in this fix. Would be handy if someone else could try running this in the meantime.
I'm okay with merging #38870 if we need a fix ASAP this weekend.