[fsdp] feat: Merge lora in fsdp training to speed up rollout #5115
ISEEKYAN merged 7 commits into verl-project:main
Conversation
When training qwen2.5_3b on gsm8k data for 1 epoch with training_config.yaml:

before fixing: [training curve figure not recovered]
after fixing: [training curve figure not recovered]
HollowMan6
left a comment
Could you please resolve the merge conflict in this PR?
HollowMan6
left a comment
In addition, maybe it's a good time to reconsider putting all the lora related settings under the .lora key and deprecate any effective lora settings outside .lora at this stage. Right now, .lora is considered to be megatron-only, but since this PR introduces merge for FSDP backend as well, it will become a mix-in which will be generally confusing.
Could you also update the docs at docs/advance/ppo_lora.rst for this feature?
Agree, it is cleaner to put all lora settings under `.lora`. For the doc, yes, will update.
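For illustration, a consolidated layout could look like the following (key names under `.lora` are hypothetical; today `lora_rank`/`lora_alpha` live outside `.lora`, and only the values shown come from this PR's example script):

```yaml
# Hypothetical consolidated config: all LoRA settings under one `.lora` key
actor_rollout_ref:
  model:
    lora:
      rank: 32       # today: actor_rollout_ref.model.lora_rank
      alpha: 64      # today: actor_rollout_ref.model.lora_alpha
      merge: true    # merge adapters into the base model before rollout
```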
c42c704 to
0ebe0e2
Compare
0ebe0e2 to
730c451
Compare
Pull request overview
This PR implements LoRA weight merging during FSDP training to improve rollout throughput and fixes a critical padding alignment bug affecting log probability calculations.
Changes:
- Adds LoRA merge functionality that temporarily merges LoRA adapters into base model weights before syncing to vLLM, eliminating adapter computation overhead during inference
- Fixes padding bug in `no_padding_2_padding` that caused misalignment between ref_log_probs/old_log_probs and current log_probs, resulting in incorrect ratio calculations (~451,728 instead of ~1.0)
- Adds comprehensive test coverage for padding fixes and FSDP LoRA merge operations
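For intuition, here is a toy sketch (not verl's actual code) of why gathering response log-probs by the attention mask stays aligned, whereas a fixed slice like `[:, -max_response_len-1:-1]` can grab padding when prompts are left-padded and responses right-padded:

```python
import torch

# sample 0: prompt len 2 (left pad 1), response len 3
# sample 1: prompt len 3,              response len 2 (right pad 1)
max_prompt_len, max_response_len = 3, 3
attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1],
                               [1, 1, 1, 1, 1, 0]])
log_probs = torch.arange(12, dtype=torch.float32).view(2, 6)

# gather only the real response positions, per sample
response_mask = attention_mask[:, max_prompt_len:].bool()
aligned = [row[mask] for row, mask in zip(log_probs[:, max_prompt_len:], response_mask)]
print([t.tolist() for t in aligned])
# → [[3.0, 4.0, 5.0], [9.0, 10.0]]
```

A fixed right-anchored slice would have returned position 11 (padding) for sample 1, which is the kind of misalignment that blows up the PPO ratio.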
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| verl/workers/utils/padding.py | Complete rewrite of no_padding_2_padding to correctly slice response regions from unpadded model output, fixing log probability alignment bug |
| verl/workers/utils/losses.py | Updated to use refactored no_padding_2_padding function for PPO loss computation |
| verl/workers/rollout/vllm_rollout/vllm_async_server.py | Refactored LoRA configuration logic to support merge mode and improved conditional flow |
| verl/workers/rollout/vllm_rollout/utils.py | Added fallback compatibility for different vLLM versions |
| verl/workers/engine/fsdp/transformer_impl.py | Implemented merged LoRA context for weight synchronization, added support for target_parameters, converted peft_config to dict for vLLM compatibility |
| verl/workers/config/model.py | Added target_parameters field for LoRA adapter on nn.Parameter |
| verl/utils/io_utils.py | New utility file for I/O operations (not used in this PR) |
| verl/utils/fsdp_utils.py | Added functions for LoRA merging/unmerging, parameter name normalization, and backup/restore operations |
| verl/utils/config.py | Reordered lora_rank determination logic for consistency |
| verl/protocol.py | Added print_summary method for DataProto debugging |
| tests/utils/test_padding_on_cpu.py | Comprehensive tests for padding conversion fix |
| tests/utils/test_normalize_peft_param_name*.py | Tests for PEFT parameter name normalization |
| tests/utils/test_fsdp_lora_merge.py | Tests for FSDP LoRA merge context manager |
| tests/special_sanity/check_license.py | Added license header for 2026 Amazon copyright |
| tests/special_e2e/run_ppo_trainer_megatron.sh | Added fully_sharded_loras configuration |
| examples/grpo_trainer/run_qwen3-4b_gsm8k_grpo_lora_merge.sh | Example script demonstrating LoRA merge usage |
| docs/advance/ppo_lora.rst | Documentation for new LoRA merge feature |
Comments suppressed due to low confidence (1)
verl/utils/fsdp_utils.py:699
- This assignment to 'set_reshard_after_forward' is unnecessary as it is redefined before this value is used.
def set_reshard_after_forward(module: FSDPModule, reshard_after_forward: bool, recurse: bool = True) -> None:
verl/utils/io_utils.py (Outdated)
```python
# Copyright 2026 Amazon.com Inc and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for IO"""

import csv
import json
import logging
import os
import pathlib
from typing import Any, Literal

import pandas as pd

logger = logging.getLogger(__name__)


def read_lines(filename: str | pathlib.Path, skip_header: bool = False) -> list[str]:
    """Read lines from the filename into a list and optionally skip the header"""
    with open(filename, encoding="utf-8") as fid:
        if skip_header:
            fid.readline()  # skip the header
        lines = fid.read().splitlines()
    result = [line.strip() for line in lines]
    return [x for x in result if x]


def load_jsonl(filename: str | pathlib.Path) -> list[dict]:
    """Load a jsonl file into a list of dict objects"""
    lines = read_lines(filename)
    return [json.loads(line) for line in lines]


def load_json(filename: str | pathlib.Path) -> dict:
    """Load the text file as a json doc"""
    with open(filename, encoding="utf-8") as fid:
        return json.load(fid)


def save_text(text: str, filename: str | pathlib.Path) -> None:
    """Save a text string into a file"""
    with open(filename, "w", encoding="utf-8") as fid:
        fid.write(text)


def save_jsonl(docs: list[dict], filename: str | pathlib.Path, mode: Literal["a", "w"] = "w") -> None:
    """Write/append a list of json docs into a jsonl file"""
    if mode not in ["a", "w"]:
        raise AttributeError('mode needs to be one of ["a", "w"]')
    with open(filename, mode, encoding="utf-8") as fid:
        for doc in docs:
            line = json.dumps(doc)
            fid.write(line + "\n")


def save_json(doc: dict, filename: str | pathlib.Path) -> None:
    """Save a dict as a json file"""
    with open(filename, "w", encoding="utf-8") as fid:
        json.dump(doc.copy(), fid, indent=2)


def write_list_to_file(src_list: list[Any], filename: pathlib.Path | str) -> None:
    """Write lines into a text file"""
    filename = str(filename)
    with open(filename, "w") as fh:
        for v in src_list:
            fh.write(f"{v}\n")


def write_text_to_file(text: str, filename: pathlib.Path | str) -> None:
    """Write text into a file, creating the parent directory if needed"""
    filename = str(filename)
    create_dir_if_not_exist(filename)
    with open(filename, "w") as fh:
        fh.write(text)


def save_csv(
    contents: list[dict[str, Any]],
    columns: list[str],
    filename: pathlib.Path | str,
) -> None:
    """Save a list of key-value pairs into a csv file"""
    assert len(columns) > 0
    with open(filename, "w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=columns)
        writer.writeheader()
        for row in contents:
            writer.writerow(row)


def save_csv_columns(contents: dict[str, list], filename: pathlib.Path | str, log: bool = False):
    """Save columnar content to a csv file
    contents format:
    {
        "col_a": [1, 2, 3],
        "col_b": ['a', 'b', 'c']
    }
    """
    df = pd.DataFrame(contents)
    logger.info(f"Save data with {len(df)} rows and {len(df.columns)} columns to csv file (unknown)")
    if log:
        logger.info(f"The contents are:\n{df.to_string(index=False)}")
    df.to_csv(filename, index=False)


def create_dir_if_not_exist(filename: str):
    dir_path = os.path.dirname(filename)
    if not os.path.exists(f"{dir_path}"):
        pathlib.Path(dir_path).mkdir(parents=True, exist_ok=True)
```
This file appears to be added but is not imported or used anywhere in the codebase based on searches. If this file is intended for future use or as a general utility, consider removing it from this PR to keep changes focused on the LoRA merge feature. Alternatively, if it is needed for the feature, add the necessary imports and usage.
@amzfang Better to remove this file here if it's not used anywhere in this PR.
730c451 to
2b701fe
Compare
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
amzfang
left a comment
updated based on previous comments @HollowMan6
BTW, it would be appreciated if you could add the figure you mentioned to the docs (`docs/advance/ppo_lora.rst`).
e79309d to
f5914e7
Compare
Thanks! The workflow is fairly linear and can be explained in the text. I updated the description to avoid any confusion.
Pull request overview
Copilot reviewed 16 out of 16 changed files in this pull request and generated no new comments.
HollowMan6
left a comment
I think we are good to go except a minor issue in the training script.
Though we currently don't have a CI test for this feature, it would be appreciated if you could help add one, or you could also let me know if you have manually tested the latest code on your side with merge either on or off. Once all the other CI passes, I think I will have no other concerns.
```shell
actor_rollout_ref.model.path=Qwen/Qwen3-4B \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
+actor_rollout_ref.model.lora.merge=True \
```
I don't think we need `+` here, as `merge` is already an existing field, right?
@HollowMan6 megatron also supports `merge=True` and `merge=False`, right? Do we also need to support the two modes in megatron?
Yes, for megatron, both `merge=True` and `merge=False` are supported.
@amzfang This is not a bug:

```python
import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch_size, max_seq_len, max_response_len = 2, 5, 3
input_ids = torch.tensor(
    [[0, 1, 2, 3, 4],
     [0, 1, 2, 3, 4]],
    dtype=torch.long,
)
attention_mask = torch.tensor(
    [[0, 1, 1, 1, 1],
     [1, 1, 1, 0, 0]],
    dtype=torch.long,
)
old_log_probs = torch.nested.as_nested_tensor(
    [torch.tensor([0.1, 0.2, 0.3, 0.4]),
     torch.tensor([0.5, 0.6, 0.7])],
    layout=torch.jagged,
)
input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
full_values = pad_input(
    hidden_states=old_log_probs.values().unsqueeze(-1),
    indices=indices,
    batch=batch_size,
    seqlen=max_seq_len,
)
full_values.squeeze(-1)
# output:
# tensor([[0.0000, 0.1000, 0.2000, 0.3000, 0.4000],
#         [0.5000, 0.6000, 0.7000, 0.0000, 0.0000]])
```
Can you please check the unit tests against the old implementation? They failed before. I don't have access to my old env anymore and will try it later. Additionally, it would be good to keep the padding implementation the same for ref-log-probs, old-log-probs and current log-probs.
@fangliuwh Yes, I can reproduce the failure before this commit. The problem is that in the unit test, input_ids is left-padded, while in the agent loop, input_ids (= prompts + responses) is left-right padded. Nevertheless, I agree that we should keep the padding implementation the same for ref-log-probs, old-log-probs and current log-probs. I just want to find why the performance metrics are not as expected in your case.
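To make the layout difference concrete, here is a toy illustration in plain torch (hypothetical masks), mimicking how flash_attn's unpad_input keeps only positions where attention_mask == 1 — so the flattened indices, and therefore any later pad_input, depend on whether the sequence is purely left-padded or left-right padded:

```python
import torch

left_pad   = torch.tensor([[0, 0, 1, 1, 1]])  # prompt left-padded only
left_right = torch.tensor([[0, 1, 1, 1, 0]])  # left-pad prompt, right-pad response

# unpad_input derives its gather indices from the nonzero mask positions
idx_left = left_pad.flatten().nonzero(as_tuple=False).flatten()
idx_lr   = left_right.flatten().nonzero(as_tuple=False).flatten()
print(idx_left.tolist())  # → [2, 3, 4]
print(idx_lr.tolist())    # → [1, 2, 3]
```

The same unpadded values thus land in different padded positions under the two layouts, which is why the unit test's assumption matters.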
I guess we should revert the `no_padding_2_padding` implementation because it changes the key of the data structure @wuxibin89 @amzfang
Our implementation assumes left-right padding.
### What does this PR do?

In current lora RL training workflows, engines like vLLM often utilize unmerged LoRA inference, which incurs a latency penalty relative to fully fine-tuned models. While enabling CUDA graphs can dampen this overhead, merging LoRA weights directly into the base model during rollout can further maximize inference throughput, achieving base-model efficiency by eliminating specialized adapter computation overhead. Beyond performance, weight merging serves as a critical workaround for model families that lack native unmerged LoRA support in vLLM, bypassing framework limitations and accelerating development without waiting for upstream software updates.

During feature development, this PR also fixed a bug: a misalignment between ref_log_probs/old_log_probs and current log probs when using the new engine_worker for training (that is, `trainer.use_legacy_worker_impl=disable`). The bug is that `no_padding_2_padding` always slices `[:, -max_response_len-1:-1]` from the padded tensor, which does not pad the prompts on the left side. The correction is to keep the unpadding logic the same for ref-log-probs/old-log-probs and current log-probs. This problem occurs not only for lora training but also for full-parameter training. Refer to `verl/workers/utils/padding.py` for details. See detailed results in the comment below.

### Checklist Before Starting

- [x] Search for similar PRs. 
  Paste at least one query link here: https://github.com/verl-project/verl/issues?q=is%3Aissue%20state%3Aopen%20lora%20slow
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

unit test:
- tests/utils/test_fsdp_lora_merge.py
- tests/utils/test_normalize_peft_param_name.py
- tests/utils/test_normalize_peft_param_name_on_cpu.py
- tests/utils/test_padding_on_cpu.py (for padding fix)

manual test:
- examples/grpo_trainer/run_qwen3-4b_gsm8k_grpo_lora_merge.sh

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

as shown in `examples/grpo_trainer/run_qwen3-4b_gsm8k_grpo_lora_merge.sh`:

```shell
# Add code snippet or script demonstrating how to use this
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    ... \
    +actor_rollout_ref.model.lora.merge=True \
    actor_rollout_ref.model.lora_rank=32 \
    actor_rollout_ref.model.lora_alpha=64 \
    trainer.use_legacy_worker_impl=disable \
    ...
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.



#### Merging lora to speed up rollout

We optimize rollout efficiency for the fsdp training backend through the following workflow:

1. offload the model weights to cpu
2. temporarily merge the Lora adapter weights into the base model before rollout
3. synchronize the merged model weights from fsdp to vllm via NCCL
4. restore the model weights from cpu
5. perform rollout on the merged model in vllm without any additional computation caused by the lora adapter

The key challenge here is how to run these steps as fast and as memory-efficiently as possible, leveraging the GPU's fast computation and communication mechanisms. Specifically, the Lora adapter together with the base model weights is sharded across GPUs during FSDP training, so we need to unshard (i.e. gather) the parameters of the Lora adapter and base model before merging Lora into the base model, preferably on GPUs. Due to the GPU memory limit, we cannot unshard the full model at once for large models (e.g. >40B).

Our solution merges lora weights layer by layer to keep a low memory footprint, similarly to what FSDP does in its forward/backward functions. For each model layer, we first unshard the base model and lora weights, merge lora into the base model, and then reshard the merged model to release the GPU memory required for the full model layer. After merging is done, we can leverage NCCL to sync model weights to vllm.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.

---

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
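The per-layer merge in the workflow above reduces to the standard LoRA identity W' = W + (alpha / r) * B @ A applied to each gathered layer. A minimal sketch under that assumption (the function name and shapes are illustrative, not verl's actual API):

```python
import torch

def merge_lora_layer(base_w: torch.Tensor, lora_a: torch.Tensor,
                     lora_b: torch.Tensor, alpha: float, rank: int) -> torch.Tensor:
    # LoRA adds a low-rank update (alpha / rank) * B @ A to the base weight;
    # in the real workflow this runs on the unsharded per-layer weights,
    # which are then resharded to bound peak GPU memory.
    return base_w + (alpha / rank) * (lora_b @ lora_a)

torch.manual_seed(0)
d_out, d_in, r, alpha = 8, 8, 2, 4.0
base = torch.randn(d_out, d_in)
A = torch.randn(r, d_in)    # lora_A: (rank, in_features)
B = torch.randn(d_out, r)   # lora_B: (out_features, rank)

merged = merge_lora_layer(base, A, B, alpha, r)

# the merged weight reproduces base path + adapter path exactly
x = torch.randn(d_in)
ref = base @ x + (alpha / r) * (B @ (A @ x))
print(torch.allclose(merged @ x, ref, atol=1e-5))
```

After the merge, vLLM sees a plain dense weight, so rollout pays no adapter-computation cost; unmerging (or restoring the CPU backup, as this PR does) recovers the trainable state.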

