
[skyrl-train] Megatron LoRA#743

Merged
erictang000 merged 35 commits into NovaSky-AI:main from erictang000:megatron_lora
Dec 28, 2025

Conversation

erictang000 (Collaborator) commented Dec 6, 2025

Enables LoRA training with the Megatron Backend. Currently waiting for NVIDIA-NeMo/Megatron-Bridge#1762 to be merged into main, so we can at least pin a commit rather than a branch for stability.

  • Adds [LoRA](https://docs.nvidia.com/nemo/megatron-bridge/0.2.0/apidocs/bridge/bridge.peft.lora.html) support via Megatron-Bridge
  • Adds custom checkpointing for LoRA model parameters (until LoRA checkpointing logic is upstreamed to Megatron-Bridge).
  • Weight syncing for Megatron + LoRA is handled by merging the LoRA parameters back into the base model before exporting to vLLM. This means that, for Megatron LoRA (for now), LoRA does not need to be configured in vLLM.
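The merge-then-export step above can be sketched as follows. This is a minimal illustration of the standard LoRA merge (W' = W + (alpha/r) * B @ A), not Megatron-Bridge's actual implementation; the function and tensor names here are hypothetical.

```python
import torch

def merge_lora_weight(base_weight: torch.Tensor,
                      lora_a: torch.Tensor,
                      lora_b: torch.Tensor,
                      alpha: float,
                      rank: int) -> torch.Tensor:
    """Fold a LoRA adapter into its base weight: W' = W + (alpha / r) * B @ A."""
    scaling = alpha / rank
    return base_weight + scaling * (lora_b @ lora_a)

# Toy check: LoRA B is zero-initialized, so before any training the
# merged weight equals the base weight exactly (the property the PR's
# forward-equivalence test relies on).
d_out, d_in, r = 8, 16, 4
w = torch.randn(d_out, d_in)
a = torch.randn(r, d_in)       # LoRA A, randomly initialized
b = torch.zeros(d_out, r)      # LoRA B, zero-initialized
merged = merge_lora_weight(w, a, b, alpha=16.0, rank=r)
assert torch.equal(merged, w)
```

Because the merged tensors have the base model's shapes, the exported checkpoint looks like an ordinary dense model to vLLM, which is why no LoRA configuration is needed on the inference side.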

Examples

GSM8K for Qwen3-30B-MoE and Qwen3-0.6B converging:
image

  • Qwen3-30B-A3B previously required 2 H100 nodes for full-parameter fine-tuning; with LoRA we can increase the batch size compared to previous runs on just 1 H100 node!

DAPO Qwen-4B

With TIS, the Megatron dense backend can match or exceed the FSDP backend's performance. TIS is especially important for the current version of LoRA. Canonical LoRA appears to underperform "performant LoRA", or may simply be more sensitive to learning rate.
image

Blockers/TODOs:

  • [x] ~~For Dense models, LoRA results in low grad norm/0 ppo_clip_ratio unless pp > 1. Something in megatron-core or megatron-bridge is broken for dense models.~~ Issue tracked on Megatron-Bridge (NVIDIA-NeMo/Megatron-Bridge#1750), awaiting PR NVIDIA-NeMo/Megatron-Bridge#1762
  • [x] Test out MoE models

Future Work

  • Once Megatron-Bridge supports exporting only the LoRA parameters, we should sync just those to vLLM for lower communication cost.
  • Add support for other LoRA variants from Megatron-Bridge (canonical LoRA, QLoRA, DoRA).
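Once adapter-only export lands, the sync could reduce to filtering the state dict by parameter name before transfer. A rough sketch, assuming adapter tensors carry a "lora" substring in their names (Megatron-Bridge's real naming scheme may differ):

```python
import torch

def lora_only_state_dict(state_dict: dict) -> dict:
    """Keep only adapter tensors; rank-r factors are far smaller than base weights."""
    return {name: t for name, t in state_dict.items() if "lora" in name}

full = {
    "layers.0.attn.weight": torch.randn(8, 8),         # base weight, not synced
    "layers.0.attn.lora_a.weight": torch.randn(2, 8),  # adapter factor A
    "layers.0.attn.lora_b.weight": torch.zeros(8, 2),  # adapter factor B
}
adapters = lora_only_state_dict(full)
assert set(adapters) == {"layers.0.attn.lora_a.weight", "layers.0.attn.lora_b.weight"}
```

For a d x d weight with rank-r adapters, the synced payload drops from d^2 to 2rd elements, which is where the communication savings come from.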

@erictang000 erictang000 marked this pull request as draft December 6, 2025 01:48

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces LoRA (Low-Rank Adaptation) support for Megatron, which is a significant feature for efficient model fine-tuning. The changes span across Docker configuration, project dependencies, and the core Megatron worker implementation. The implementation correctly adds LoRA configuration and applies it via a pre-wrap hook. The tests have also been updated to cover the new LoRA functionality. My review includes a few suggestions to improve the Dockerfile efficiency and minor code style fixes.

Comment on lines +23 to +28
RUN sudo apt-get update && sudo apt-get install -y \
cmake \
ninja-build \
libnccl-dev=2.25.1-1+cuda12.8 \
libnccl2=2.25.1-1+cuda12.8 \
&& sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/*

medium

To optimize the Docker image size and build time, it's a best practice to consolidate apt-get commands into fewer RUN layers. This RUN instruction could be combined with other apt-get install commands in this Dockerfile to reduce the number of layers in the final image.

Comment on lines +41 to +47
RUN pip install --no-cache-dir \
nvidia-mathdx \
pybind11 \
setuptools \
wheel

RUN pip install --no-cache-dir --no-build-isolation "transformer_engine[pytorch]==2.9.0"

medium

To reduce the number of layers in the Docker image and improve build efficiency, it's better to chain the pip install commands within a single RUN instruction using &&.

RUN pip install --no-cache-dir \
    nvidia-mathdx \
    pybind11 \
    setuptools \
    wheel && \
    pip install --no-cache-dir --no-build-isolation "transformer_engine[pytorch]==2.9.0"


return lora_model
self.provider.register_pre_wrap_hook(lora_pre_wrap_hook)


medium

There is trailing whitespace on this line that should be removed to maintain code cleanliness.

Suggested change
self.provider.register_pre_wrap_hook(lora_pre_wrap_hook)

Comment on lines 308 to 309



medium

There is an extra blank line here. Please remove it to improve code readability and maintain a consistent style.

@erictang000 erictang000 changed the title [skyrl-train] WIP Megatron LoRA [skyrl-train] Megatron LoRA Dec 21, 2025
@erictang000 erictang000 marked this pull request as ready for review December 21, 2025 00:48
erictang000 (Collaborator, Author) commented:

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant new functionality by adding LoRA support for Megatron, including new example scripts, core logic changes for distributed training strategies and workers, and updated tests. The implementation for LoRA support appears solid. My review focuses on improving the robustness and clarity of the new example scripts and project configuration. I've identified a critical issue with undefined variables in one script that will cause it to fail, and several medium-severity issues related to configuration consistency, dependency management, and script robustness.

{ url = "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl", marker = "extra == 'sglang' and extra != 'mcore' and extra != 'vllm'" }
]
megatron-bridge = {git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge", rev = "22ef9ff9f9684ba2f2dbea14db974f5c31bbd683"}
megatron-bridge = {git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge", branch = "feature/peft-recompute-hook"}

medium

Pinning a dependency to a git branch can lead to non-reproducible builds and instability, as the branch can be updated or deleted. It's better to pin to a specific commit hash from that branch to ensure stability. The PR description mentions that this is awaiting a merge, which has now happened. Please update this to a specific commit hash or a released version.

Suggested change
megatron-bridge = {git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge", branch = "feature/peft-recompute-hook"}
megatron-bridge = {git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge", rev = "<commit_hash_from_feature_branch>"}

Comment on lines +1 to +15
SCRIPT_PATH=$1
LOG_FILE=$2
state=error
error_string="unexpected system error"
echo "Clearing log file first $LOG_FILE"
> $LOG_FILE
while [ $state == 'error' ]; do
echo "(Re)Running the script $SCRIPT_PATH"
bash $SCRIPT_PATH >> $LOG_FILE 2>&1
# script returned, let's check if there was an error
last_lines=$(tail -n 150 $LOG_FILE)
if [[ ! "$last_lines" =~ "$error_string" ]];then
state=noterror
fi
done

medium

This script is not robust against file paths with spaces. Variables like $SCRIPT_PATH and $LOG_FILE should be quoted when used to prevent word splitting. Also, the while loop condition should have its variable quoted.

For example:

  • SCRIPT_PATH="$1"
  • > "$LOG_FILE"
  • while [ "$state" == "error" ] or while [[ "$state" == "error" ]]
  • bash "$SCRIPT_PATH" >> "$LOG_FILE" 2>&1
Suggested change
SCRIPT_PATH=$1
LOG_FILE=$2
state=error
error_string="unexpected system error"
echo "Clearing log file first $LOG_FILE"
> $LOG_FILE
while [ $state == 'error' ]; do
echo "(Re)Running the script $SCRIPT_PATH"
bash $SCRIPT_PATH >> $LOG_FILE 2>&1
# script returned, let's check if there was an error
last_lines=$(tail -n 150 $LOG_FILE)
if [[ ! "$last_lines" =~ "$error_string" ]];then
state=noterror
fi
done
SCRIPT_PATH="$1"
LOG_FILE="$2"
state=error
error_string="unexpected system error"
echo "Clearing log file first $LOG_FILE"
> "$LOG_FILE"
while [[ "$state" == "error" ]]; do
echo "(Re)Running the script $SCRIPT_PATH"
bash "$SCRIPT_PATH" >> "$LOG_FILE" 2>&1
# script returned, let's check if there was an error
last_lines=$(tail -n 150 "$LOG_FILE")
if [[ ! "$last_lines" =~ "$error_string" ]];then
state=noterror
fi
done

erictang000 (Collaborator, Author) commented:

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces LoRA training support for the Megatron backend, which is a significant enhancement. The changes are comprehensive, touching upon model configuration, distributed strategy, worker implementation, and testing. The custom checkpointing for LoRA adapters is well-implemented, with clear logic for saving and loading adapter weights separately. The approach of merging LoRA parameters into the base model for vLLM compatibility is a practical solution for now.

I've identified a couple of areas for improvement:

  • In skyrl-train/skyrl_train/workers/megatron/megatron_worker.py, there's an opportunity to refactor the configure_lora method to reduce code duplication.
  • The new test test_megatron_lora_forward in skyrl-train/tests/gpu/test_megatron_worker.py seems incomplete as it calculates a difference but doesn't assert its value.

Overall, this is a great contribution. Addressing these minor points will improve the code's maintainability and test coverage.

Comment on lines +414 to +415
max_diff = torch.max(torch.abs(action_log_probs_full - action_log_probs_lora))
print(f"Max diff: {max_diff}")

medium

The docstring for this test states that it checks for numerical equivalence, but there is no assertion to verify this. Since LoRA B weights are initialized to zero, the forward pass of the LoRA model should be identical to the base model before any training. Please add an assertion to confirm that the difference is zero or very close to it.

Suggested change
max_diff = torch.max(torch.abs(action_log_probs_full - action_log_probs_lora))
print(f"Max diff: {max_diff}")
max_diff = torch.max(torch.abs(action_log_probs_full - action_log_probs_lora))
print(f"Max diff: {max_diff}")
assert torch.allclose(action_log_probs_full, action_log_probs_lora, atol=1e-6)

A Member commented:

Not sure why this comment was resolved @erictang000

erictang000 (Collaborator, Author) replied:

max_diff = torch.max(torch.abs(action_log_probs_full - action_log_probs_lora))

fixed it later

@erictang000 erictang000 merged commit 901fc9c into NovaSky-AI:main Dec 28, 2025
3 checks passed
dzorlu pushed a commit to fleet-ai/SkyRL that referenced this pull request Feb 4, 2026
