Move online quantization to model.load_weights #26327

vllm-bot merged 1 commit into vllm-project:main

Conversation
Code Review
This pull request refactors the online quantization logic by moving it from default_loader.load_weights into a decorator, @support_quantized_model_reload_from_hp_weights. This decorator is then applied to the load_weights method of several models, making it easier to use online quantization in RL frameworks. The changes are well-structured and align with the goal of the PR. I've found one potential issue in an edge case that could lead to a crash, and I've provided a suggestion to fix it.
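For context, a minimal sketch of how such a decorator could behave. This is purely illustrative: the requantization hook and the `needs_requantization` / `requantize` names are assumptions for the sketch, not vLLM's actual implementation.

```
import functools
from collections.abc import Iterable

import torch


def support_quantized_model_reload_from_hp_weights(load_weights_fn):
    """Wrap a model's load_weights so that reloading high-precision (hp)
    weights into an already-quantized model triggers requantization."""

    @functools.wraps(load_weights_fn)
    def wrapper(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loaded = load_weights_fn(self, weights)
        # Hypothetical hook: requantize any layer whose quantized parameters
        # were just overwritten with hp weights.
        for module in self.modules():
            if getattr(module, "needs_requantization", False):
                module.requantize()
        return loaded

    return wrapper
```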
Documentation preview: https://vllm--26327.org.readthedocs.build/en/26327/
This pull request has merge conflicts that must be resolved before it can be merged.
looks like adding
Summary: Previously we added online quantization to `default_loader.load_weights`, but this is not how it's used in RL frameworks, since they will call `model.load_weights(updated_weights)`. In this PR we refactored the online quantization logic into a decorator, `support_quantized_model_reload_from_hp_weights`, and use that to decorate `model.load_weights`.

Test Plan: pytest tests/quantization/test_torchao.py -k test_reload_weights

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
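For context, the call pattern this targets looks roughly like the following. `trainer_model` and `vllm_model` are stand-in names for the training and rollout copies of the model, not vLLM APIs.

```
import torch

# Hypothetical RL weight-sync step: the trainer produces updated bf16
# (high-precision) weights, which are pushed into the quantized rollout model.
updated_weights = [
    (name, param.detach().to(torch.bfloat16))
    for name, param in trainer_model.named_parameters()
]
# With this PR, load_weights requantizes the incoming high-precision weights
# in place, instead of requiring a fresh default_loader pass.
vllm_model.load_weights(updated_weights)
```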
The failing examples-test is not relevant to this PR; the same test failure with the same error message appears in other CI jobs (https://buildkite.com/vllm/ci/builds/39560/steps/canvas?sid=019a98a9-9269-49b3-bf96-645dbd38acad) and on main (https://buildkite.com/vllm/ci/builds/39560/steps/canvas?jid=019a98a9-93ea-4f17-a36b-d223cefd61a1). Ready to merge, I think.
Nice cleanup putting it in the `load_weights`
**Summary:** Enable FP8 + RL training using TorchAO for 1.33x faster training and 42% less model memory usage:
- We quantize the frozen LoRA weights into fp8 and keep the LoRA adapters in bf16
- We leverage TorchAO's `Float8Tensor`, which calls into fbgemm's fp8 x fp8 rowwise matmul kernel
- For now, we need to do an offline quantization first, because vllm doesn't support on-the-fly quantization for torchao yet (this is in progress: vllm-project/vllm#26327)

**Example usage:**
```
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = True,  # set this to True
)
# the rest is the same as before
model = FastLanguageModel.get_peft_model(...)
```

**Initial results:**
```
# fp8
{'train_runtime': 1725.4337, 'train_samples_per_second': 0.232, 'train_steps_per_second': 0.058, 'train_loss': 0.00015715716748673002, 'epoch': 0.01}
# bf16
{'train_runtime': 2297.8145, 'train_samples_per_second': 0.174, 'train_steps_per_second': 0.044, 'train_loss': 0.00016081033063528594, 'epoch': 0.01}
```

![Training loss curves](https://github.com/user-attachments/assets/b6304afd-89e9-42b1-8064-775807e17b23)

Test script: https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423

**Requires:**
- pytorch/ao#3158 (torchao nightly or 0.15.0+)
- unslothai/unsloth-zoo#351
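As a reference for the offline step mentioned above, torchao's `quantize_` API applies the fp8 rowwise config in place. A minimal sketch, assuming `model` is the bf16 base model already in scope; the config construction matches the one shown in the verl description further down.

```
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# Quantize linear weights to fp8 with rowwise scales; the weights become
# torchao Float8Tensor instances that dispatch to fp8 x fp8 rowwise matmuls.
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
quantize_(model, config)  # `model`: the bf16 base model, assumed in scope
```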
Add 128x128 PerBlock FP8 + RL (#3629)

**Summary:** Following #3440, this PR extends torchao FP8 + RL support to also handle 128x128 PerBlock granularity (in addition to PerRow).

**Example usage:**
```
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = "block",  # or "row" or True
)
```

**Initial results:** TBD

**Note:**
- Requires pytorch/ao#3370
### What does this PR do?

Add support for torchao online quantization for vllm, configuring the type of quantization by serializing a config file. See the vllm PR for how to generate the quantization file (vllm-project/vllm#23014).

Requires vllm changes: vllm-project/vllm#23014 and vllm-project/vllm#26327

### Test

1. Generate the torchao config file `torchao_config.json` (can change to other configs: https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize):

```
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
from torchao.core.config import config_to_dict
import json

config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
with open("torchao_config.json", "w") as f:
    f.write(json.dumps(config_to_dict(config)))
# LLM(..., quantization="torchao", hf_overrides={"quantization_config_file": "torchao_config.json"})
```

This is fp8 dynamic quant:

```
{"_type": "Float8DynamicActivationFloat8WeightConfig", "_version": 2, "_data": {"activation_dtype": {"_type": "torch.dtype", "_data": "float8_e4m3fn"}, "weight_dtype": {"_type": "torch.dtype", "_data": "float8_e4m3fn"}, "granularity": [{"_type": "PerRow", "_version": 1, "_data": {}}, {"_type": "PerRow", "_version": 1, "_data": {}}], "mm_config": {"_type": "Float8MMConfig", "_version": 1, "_data": {"emulate": false, "use_fast_accum": true, "pad_inner_dim": false}}, "activation_value_lb": null, "activation_value_ub": null, "kernel_preference": {"_type": "KernelPreference", "_data": "AUTO"}, "set_inductor_config": true}}
```

2. Add the following to `examples/ppo_trainer/run_deepseek7b_llm.sh`:

```
actor_rollout_ref.rollout.quantization=torchao \
actor_rollout_ref.rollout.quantization_config_file=torchao_config.json \
```

3. Run the test:

```
VLLM_DISABLE_COMPILE_CACHE=1 sh examples/ppo_trainer/run_deepseek7b_llm.sh
```

```
# baseline
(TaskRunner pid=539843) ("Initial validation metrics: {'val-aux/openai/gsm8k/reward/mean@1': "
(TaskRunner pid=539843) "0.6717210007581501, 'val-core/openai/gsm8k/acc/mean@1': 0.6717210007581501, "
(TaskRunner pid=539843) "'val-aux/num_turns/min': 2, 'val-aux/num_turns/max': 2, "
(TaskRunner pid=539843) "'val-aux/num_turns/mean': 2.0}")
(TaskRunner pid=539843) step:105 - val-aux/openai/gsm8k/reward/mean@1:0.6717210007581501 - val-core/openai/gsm8k/acc/mean@1:0.6717210007581501 - val-aux/num_turns/min:2 - val-aux/num_turns/max:2 - val-aux/num_turns/mean:2.0

# fp8
(TaskRunner pid=3763210) validation generation end
(TaskRunner pid=3763210) ("Initial validation metrics: {'val-aux/openai/gsm8k/reward/mean@1': "
(TaskRunner pid=3763210) "0.6739954510993177, 'val-core/openai/gsm8k/acc/mean@1': 0.6739954510993177, "
(TaskRunner pid=3763210) "'val-aux/num_turns/min': 2, 'val-aux/num_turns/max': 2, "
(TaskRunner pid=3763210) "'val-aux/num_turns/mean': 2.0}")
(TaskRunner pid=3763210) step:105 - val-aux/openai/gsm8k/reward/mean@1:0.6739954510993177 - val-core/openai/gsm8k/acc/mean@1:0.6739954510993177 - val-aux/num_turns/min:2 - val-aux/num_turns/max:2 - val-aux/num_turns/mean:2.0
```

Docs: no docs added yet, since I didn't find a place to add quantized rollout docs in https://github.com/volcengine/verl/blob/main/docs/workers/fsdp_workers.rst; happy to add later when there are more docs. We can add simple string options (e.g. fp8_tensorwise, fp8_rowwise, fp8_blockwise) in the future if needed.
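Tying this back to vLLM itself, the commented-out line in the config-generation snippet above corresponds to usage like the following. A hedged sketch: the model name is only a placeholder, not part of the tested setup.

```
from vllm import LLM

# Online torchao quantization: vLLM reads the serialized config file and
# quantizes the bf16 checkpoint at load time (this PR plus vllm#23014).
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder; any supported bf16 model
    quantization="torchao",
    hf_overrides={"quantization_config_file": "torchao_config.json"},
)
```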
Summary:
Previously we added online quantization to `default_loader.load_weights`, but this is not how it's used in RL frameworks, since they will call `model.load_weights(updated_weights)`. In this PR we refactored the online quantization logic into a decorator, `support_quantized_model_reload_from_hp_weights`, and use that to decorate `AutoWeightLoader.load_weights`.

Test Plan:
pytest tests/quantization/test_torchao.py -k test_reload_weights