Do online fp8 quantization while loading weights instead of in process_weights_after_loading, reducing memory overhead #17945
Conversation
Summary of Changes

Hello @fxmarty-amd, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a significant optimization for FP8 quantization by implementing an "online" quantization strategy. Instead of loading all model weights in a higher precision (bf16) and then quantizing them in a separate post-loading step, the weights are now quantized directly as they are loaded. This change reduces peak GPU memory usage, prevents potential out-of-memory errors during loading, and improves the efficiency of loading large language models.
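To make the memory argument concrete, here is a hedged back-of-envelope sketch (not code from this PR): with post-loading quantization the whole model is resident in bf16 before any fp8 copy exists, while online quantization only ever holds one in-flight bf16 shard next to the fp8 weights.

```python
def peak_load_gib(n_params: float, shard_params: float) -> tuple[float, float]:
    """Rough peak GPU memory (GiB) while loading a bf16 checkpoint into fp8.

    Back-of-envelope only: 2 bytes/param for bf16, 1 byte/param for fp8;
    ignores activations, KV cache, and allocator fragmentation.
    """
    gib = 1024 ** 3
    post_loading = 2 * n_params / gib             # whole model alive in bf16 first
    online = (n_params + 2 * shard_params) / gib  # fp8 model + one bf16 shard
    return post_loading, online

# e.g. a 7B-parameter model streamed in 100M-parameter shards
print(peak_load_gib(7e9, 1e8))  # ≈ (13.0, 6.7)
```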
Code Review
This pull request refactors the fp8 online quantization logic to perform quantization during weight loading, which is a great memory optimization. The implementation looks mostly correct, but I've found a couple of issues. There's a redundant line in the new weight loader that should be removed for clarity. More importantly, I've identified a potential bug in `process_weights_after_loading` where the weight tensor is transposed, which seems to lead to a shape mismatch in the subsequent matrix multiplication. Please see my detailed comments.
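For readers following the transpose concern, a toy PyTorch illustration (not the PR's code) of how an extra transpose on a linear weight surfaces as a shape error in the matmul:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 16)   # activations: [num_tokens, in_features]
w = torch.randn(32, 16)  # F.linear expects weight as [out_features, in_features]

y = F.linear(x, w)       # OK: computes x @ w.T -> shape [4, 32]
# F.linear(x, w.t())     # RuntimeError: a [16, 32] weight advertises
#                        # in_features=32, which no longer matches x's 16
```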
```python
if not self.quant_config.is_checkpoint_fp8_serialized and _use_hip_int4:
    raise NotImplementedError(
        f"Online MOE FP8 quantization (is_checkpoint_fp8_serialized={self.quant_config.is_checkpoint_fp8_serialized}) along SGLANG_INT4_WEIGHT=1 is not supported at the moment. Please open an issue."
    )
```
Not supported on main branch either.
cc @HaiShaw can you have a look?
Hi @kkHuang-amd @HaiShaw what do you think? Happy to address comments and fix conflicts accordingly.
As per title.
The current implementation of `fp8` online quantization first initializes and loads all weights in bf16, and only afterwards quantizes them in `process_weights_after_loading`. This is inefficient in terms of GPU memory and may lead to OOM during loading, even though the quantized FP8 model itself would fit in memory. This PR moves online quantization into the weight loader, similar to #7392.
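As a rough illustration of the approach, here is a minimal sketch, not this PR's actual code; the `weight_loader` attribute convention is borrowed from how sglang/vLLM-style model loaders typically dispatch, and the `weight_scale` attribute name is an assumption for illustration:

```python
import torch

def quantizing_weight_loader(param: torch.nn.Parameter,
                             loaded_weight: torch.Tensor) -> None:
    """Quantize an incoming bf16 shard to fp8 and store it directly.

    The bf16 copy only lives for the duration of this call, so the full
    bf16 model is never materialized.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = loaded_weight.abs().max().clamp(min=1e-12) / finfo.max
    qweight = (loaded_weight / scale).clamp(finfo.min, finfo.max)
    param.data.copy_(qweight.to(torch.float8_e4m3fn))
    param.weight_scale = scale.float()  # illustrative; real code keeps this on the layer

# The parameter is allocated in fp8 up front, instead of bf16.
weight = torch.nn.Parameter(
    torch.empty(4096, 4096, dtype=torch.float8_e4m3fn), requires_grad=False
)
weight.weight_loader = quantizing_weight_loader  # picked up by the model loader

quantizing_weight_loader(weight, torch.randn(4096, 4096, dtype=torch.bfloat16))
```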
Fixes #2895
Fixes #8337
Left to do before merge:
- `Fp8MoEMethod`
- … (main), for the test Qwen/Qwen2.5-1.5B-Instruct model (and also Qwen3 8B) => is this expected, or a bug?