Skip to content

[GPT-OSS] improve FSDP shard merging and documentation for GPT-OSS#3073

Merged
winglian merged 14 commits into
mainfrom
merge-fsdp
Aug 16, 2025
Merged

[GPT-OSS] improve FSDP shard merging and documentation for GPT-OSS#3073
winglian merged 14 commits into
mainfrom
merge-fsdp

Conversation

@winglian
Copy link
Copy Markdown
Collaborator

@winglian winglian commented Aug 15, 2025

Summary by CodeRabbit

  • New Features

    • Automatically merges sharded FSDP weights into the output directory after training, with a fallback merged path when disk space is low.
    • More robust CLI merge workflow that locates latest checkpoints and reports merge results.
  • Bug Fixes

    • Removes FSDP prefixes from saved model architecture entries to ensure correct model identification.
  • Documentation

    • Expanded GPT-OSS README: disk-space guidance, merge workflow, inference tips (SGLang/vLLM notes), dataset and multi-GPU training tips, and resources.
  • Tests

    • Added tests validating checkpoint discovery and auto-resume behavior.
  • Chores

    • Example config now keeps only the last two checkpoints.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Aug 15, 2025

📝 Walkthrough

Walkthrough

Adds checkpoint discovery (determine_last_checkpoint), updates training to auto-merge sharded FSDP weights and clean config.json, enhances the FSDP merge CLI with fallback checkpoint detection and synchronized logging, updates GPT-OSS example docs/config for disk/merge/inference guidance, and removes two GRPO trainers from public exports.

Changes

Cohort / File(s) Summary
Training checkpoint & utils
src/axolotl/train.py, src/axolotl/utils/train.py, tests/utils/test_train.py
Introduces determine_last_checkpoint(cfg, update: bool=True) (moved to new utils module), replaces prior resume logic, adds tests, and updates training flow to call it. Adds logic to auto-merge SHARDED_STATE_DICT FSDP weights post-save and strip "FSDP" prefixes from model config.json.
FSDP merge CLI
src/axolotl/cli/merge_sharded_fsdp_weights.py
Adds fallback to locate the FSDP checkpoint (using determine_last_checkpoint), uses explicit output_dir/merged output path, moves barrier/synchronization to CLI with PartialState, improves logging, and tightens error handling/imports.
Trainer exports
src/axolotl/core/trainers/__init__.py
Removes public exports AxolotlGRPOSequenceParallelTrainer and AxolotlGRPOTrainer.
GPT-OSS examples & docs
examples/gpt-oss/README.md, examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
README additions: disk-space guidance, FSDP prefix errata and sed fix, merge workflow, inference/server guidance, dataset/tooling tips and links. YAML: adds save_total_limit: 2 to limit retained checkpoints.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

ready to merge

Suggested reviewers

  • SalmanMohammadi

Tip

🔌 Remote MCP (Model Context Protocol) integration is now available!

Pro plan users can now connect to remote MCP servers from the Integrations page. Connect with popular remote MCPs such as Notion and Linear to add more context to your reviews and chats.

✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch merge-fsdp

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR/Issue comments)

Type @coderabbitai help to get the list of available commands.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

Status, Documentation and Community

  • Visit our Status Page to check the current availability of CodeRabbit.
  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Aug 15, 2025

📖 Documentation Preview: https://689fcd4a5fa24d1f463adc44--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 352a30d

@codecov
Copy link
Copy Markdown

codecov Bot commented Aug 15, 2025

Codecov Report

❌ Patch coverage is 76.92308% with 12 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/axolotl/cli/merge_sharded_fsdp_weights.py 15.38% 11 Missing ⚠️
src/axolotl/train.py 95.83% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (6)
src/axolotl/cli/merge_sharded_fsdp_weights.py (3)

25-25: Optional: avoid heavy import from training module for a small utility

Importing determine_last_checkpoint from axolotl.train pulls in a large dependency surface. Consider moving determine_last_checkpoint to a small shared utility (e.g., axolotl.utils.checkpoints) to reduce CLI startup overhead and potential circular import risks.


199-207: Improve error message when no last checkpoint is found

If determine_last_checkpoint returns None, the error currently shows ... in None. Make the message explicit about which locations were checked.

-    if not fsdp_dir.exists():
-        raise ValueError(
-            f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}"
-        )
+    if not fsdp_dir.exists():
+        where = checkpoint_dir or parsed_cfg.output_dir
+        raise ValueError(
+            f"Could not find FSDP checkpoint directory 'pytorch_model_fsdp_0' under {where}"
+        )

214-222: Tighten wording and fix subject-verb agreement in logs

Minor grammar/clarity improvement.

-LOG.info(
-    f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
-    main_process_only=True,
-)
-LOG.info(
-    "Merged weights are only the safetensors and doesn't include the model configuration "
-    f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
-    main_process_only=True,
-)
+LOG.info(
+    f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
+    main_process_only=True,
+)
+LOG.info(
+    "Merged weights include only the tensor files and do not include the model configuration or tokenizer. "
+    f"Find those in {parsed_cfg.output_dir}.",
+    main_process_only=True,
+)
examples/gpt-oss/README.md (3)

43-49: Clarify phrasing and fix minor grammar

Use active voice and “into” vs “to”.

-When using SHARDED_STATE_DICT with FSDP, there is an additional post-training step to merge the sharded weights.
-This step will automatically determine the last checkpoint directory and merge the sharded weights to
-`{output_dir}/merged`.
+When using SHARDED_STATE_DICT with FSDP, there is an additional post-training step to merge the sharded weights.
+This step automatically determines the last checkpoint directory and merges the sharded weights into
+`{output_dir}/merged`.

51-56: Polish ERRATA section; avoid bare URL and fix capitalization

-ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config`.
-See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
+ERRATA: Transformers saves the model architecture name prefixed with `FSDP`, which must be manually corrected in `config.json`.
+See [huggingface/transformers#40207](https://github.com/huggingface/transformers/pull/40207) for the status of this issue.

58-69: Rename section and fix typos; avoid bare URLs

  • “Inferencing” -> “Inference”
  • Fix “infomation” -> “information”
  • Convert bare URLs to markdown links for linting and readability.
-### Inferencing your fine-tuned model
+### Inference with your fine-tuned model
@@
-GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
-for more information about using a special vllm-openai docker image for inferencing with vLLM.
+GPT-OSS support in vLLM does not exist in a stable release yet. See
+[this thread](https://x.com/MaziyarPanahi/status/1955741905515323425)
+for more information about using a special vllm-openai Docker image for inference with vLLM.
@@
-SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing
-SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:
+SGLang has zero-day support on main; see
+[sgl-project/sglang#8833](https://github.com/sgl-project/sglang/issues/8833)
+for information on installing SGLang from source. Once you've installed SGLang, run the following command
+to launch an SGLang server:
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 130ef7c and 30bef60.

📒 Files selected for processing (3)
  • examples/gpt-oss/README.md (1 hunks)
  • src/axolotl/cli/merge_sharded_fsdp_weights.py (2 hunks)
  • src/axolotl/train.py (2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
src/axolotl/train.py (4)
src/axolotl/integrations/base.py (2)
  • cfg (352-353)
  • cfg (356-357)
tests/test_exact_deduplication.py (1)
  • cfg (201-216)
tests/e2e/multigpu/test_locking.py (1)
  • cfg (25-27)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/cli/merge_sharded_fsdp_weights.py (1)
src/axolotl/train.py (2)
  • train (537-585)
  • determine_last_checkpoint (127-155)
🪛 Ruff (0.12.2)
src/axolotl/train.py

149-150: Use a single if statement instead of nested if statements

(SIM102)

🪛 LanguageTool
examples/gpt-oss/README.md

[grammar] ~44-~44: There might be a mistake here.
Context: ...rectory and merge the sharded weights to {output_dir}/merged. ```bash axolotl m...

(QB_NEW_EN)


[grammar] ~51-~51: There might be a mistake here.
Context: ...eeds to be manually renamed in config. See https://github.com/huggingface/trans...

(QB_NEW_EN)


[grammar] ~60-~60: There might be a mistake here.
Context: ...MaziyarPanahi/status/1955741905515323425 for more information about using a speci...

(QB_NEW_EN)


[grammar] ~63-~63: Ensure spelling is correct
Context: ....com/sgl-project/sglang/issues/8833 for infomation on installing SGLang from source. Once ...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)


[grammar] ~63-~63: There might be a mistake here.
Context: ...issues/8833 for infomation on installing SGLang from source. Once you've installe...

(QB_NEW_EN)

🪛 markdownlint-cli2 (0.17.2)
examples/gpt-oss/README.md

52-52: Bare URL used

(MD034, no-bare-urls)


60-60: Bare URL used

(MD034, no-bare-urls)


63-63: Bare URL used

(MD034, no-bare-urls)

🔇 Additional comments (1)
src/axolotl/train.py (1)

571-573: LGTM on training flow change

Switching to determine_last_checkpoint(cfg) keeps the previous semantics and centralizes the logic. No concerns.

Comment thread src/axolotl/train.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (2)
examples/gpt-oss/README.md (2)

43-48: Use a JSON-aware fix (jq) instead of sed; add macOS sed variant; minor wording/link polish

  • Replace global sed with a schema-aware update of architectures to avoid accidental edits elsewhere in the file.
  • Add macOS-compatible sed flags if jq is unavailable.
  • Lowercase “architecture” and avoid bare URLs.
-ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
-See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
+ERRATA: Transformers saves the model architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
+See [transformers#40207](https://github.com/huggingface/transformers/pull/40207) for the status of this issue.
 
-```bash
-sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
-```
+Prefer a JSON-aware rewrite (jq):
+```bash
+jq '.architectures = (.architectures | map(if .=="FSDPGptOssForCausalLM" then "GptOssForCausalLM" else . end))' \
+  ./outputs/gpt-oss-out/config.json > ./outputs/gpt-oss-out/config.json.tmp \
+  && mv ./outputs/gpt-oss-out/config.json.tmp ./outputs/gpt-oss-out/config.json
+```
+If jq is unavailable:
+- Linux (GNU sed):
+```bash
+sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
+```
+- macOS (BSD sed):
+```bash
+sed -i '' 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
+```

61-71: Fix typo and wrap bare URLs; tweak section title for clarity; add a small TP note

  • “infomation” → “information”.
  • Avoid bare URLs and make link text self-descriptive.
  • “Inferencing” → “Inference” is more standard.
  • Optional: day-zero phrasing reads better than “0-day.”
-### Inferencing your fine-tuned model
+### Inference with your fine-tuned model
 
-GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
-for more information about using a special vllm-openai docker image for inferencing with vLLM.
+GPT-OSS support in vLLM does not exist in a stable release yet. See
+[this note from @MaziyarPanahi](https://x.com/MaziyarPanahi/status/1955741905515323425)
+for more information about using a special vllm-openai Docker image for inference with vLLM.
 
-SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing
-SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:
+SGLang has day-zero support in main. See
+[sglang/issues/8833](https://github.com/sgl-project/sglang/issues/8833) for information on installing
+SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:
 
 ```bash
 python3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8

+Ensure that --tp matches the number of tensor-parallel shards you intend to use on your inference hardware.


</blockquote></details>

</blockquote></details>

<details>
<summary>📜 Review details</summary>

**Configuration used: .coderabbit.yaml**
**Review profile: CHILL**
**Plan: Pro**

**💡 Knowledge Base configuration:**

- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

<details>
<summary>📥 Commits</summary>

Reviewing files that changed from the base of the PR and between 30bef60fc403d4b6187515e1ceae75bd67cb7df4 and 7707b0973330684da071d9d1322164152dff8699.

</details>

<details>
<summary>📒 Files selected for processing (1)</summary>

* `examples/gpt-oss/README.md` (1 hunks)

</details>

<details>
<summary>🧰 Additional context used</summary>

<details>
<summary>🪛 LanguageTool</summary>

<details>
<summary>examples/gpt-oss/README.md</summary>

[grammar] ~43-~43: There might be a mistake here.
Context: ...to be manually renamed in `config.json`. See https://github.com/huggingface/trans...

(QB_NEW_EN)

---

[grammar] ~52-~52: There might be a mistake here.
Context: ...eckpoint directory and merge the sharded weights to `{output_dir}/merged`.  ```ba...

(QB_NEW_EN)

---

[grammar] ~63-~63: There might be a mistake here.
Context: ...MaziyarPanahi/status/1955741905515323425 for more information about using a speci...

(QB_NEW_EN)

---

[grammar] ~66-~66: Ensure spelling is correct
Context: ....com/sgl-project/sglang/issues/8833 for infomation on installing SGLang from source. Once ...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)

---

[grammar] ~66-~66: There might be a mistake here.
Context: ...issues/8833 for infomation on installing SGLang from source. Once you've installe...

(QB_NEW_EN)

</details>

</details>
<details>
<summary>🪛 markdownlint-cli2 (0.17.2)</summary>

<details>
<summary>examples/gpt-oss/README.md</summary>

44-44: Bare URL used

(MD034, no-bare-urls)

---

63-63: Bare URL used

(MD034, no-bare-urls)

---

66-66: Bare URL used

(MD034, no-bare-urls)

</details>

</details>

</details>

<details>
<summary>⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)</summary>

* GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
* GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
* GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
* GitHub Check: PyTest (3.11, 2.7.1)
* GitHub Check: PyTest (3.11, 2.6.0)
* GitHub Check: PyTest (3.11, 2.7.0)
* GitHub Check: pre-commit
* GitHub Check: preview
* GitHub Check: pre-commit

</details>

</details>

<!-- This is an auto-generated comment by CodeRabbit for review status -->

Comment on lines +50 to +59
When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your
configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to
merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded
weights to `{output_dir}/merged`.

```bash
axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
```

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Clarify disk space needs, fix minor grammar, and use a safer move (rsync) after merge

  • Single-space after periods.
  • Explicitly call out that the merged dir contains only tensor weights (no config/tokenizer).
  • Recommend rsync to preserve permissions and handle dotfiles; optionally clean up the merged dir.
-merge the sharded weights.  This step will automatically determine the last checkpoint directory and merge the sharded
+merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded
 weights to `{output_dir}/merged`.
 
 ```bash
 axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
-mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
+rsync -a ./outputs/gpt-oss-out/merged/ ./outputs/gpt-oss-out/ && find ./outputs/gpt-oss-out/merged -type f -delete
+# Optional: remove now-empty merged dir
+rmdir ./outputs/gpt-oss-out/merged 2>/dev/null || true

+Note:
+- Ensure sufficient free disk space for the merge (typically at least 1.2–1.5× the final model size).
+- {output_dir}/merged contains only weight tensor files (e.g., .safetensors). It does not include config.json or the tokenizer—retain those from output_dir or the last checkpoint.


To confirm the CLI and checkpoint discovery wiring referenced here exist and import correctly, run:

---

🏁 Script executed:

```shell
#!/bin/bash
# Verify CLI entrypoint and import relationships
set -euo pipefail
fd -t f merge_sharded_fsdp_weights.py
rg -n -C3 -i 'determine_last_checkpoint' src/axolotl/cli/merge_sharded_fsdp_weights.py
rg -n -C3 -i 'def determine_last_checkpoint\(|from .* import determine_last_checkpoint' src/axolotl/train.py src/axolotl/cli/merge_sharded_fsdp_weights.py

Length of output: 1955


🏁 Script executed:

#!/bin/bash
set -euo pipefail
echo "Path:"
fd -t f merge_sharded_fsdp_weights.py
echo
echo "File snippet (1-400):"
nl -ba src/axolotl/cli/merge_sharded_fsdp_weights.py | sed -n '1,400p'
echo
echo "Search for write/save keywords:"
rg -n -C3 -i 'save_pretrained|safetensor|safetensors|state_dict|torch.save|save' src/axolotl/cli/merge_sharded_fsdp_weights.py || true

Length of output: 398


🏁 Script executed:

#!/bin/bash
set -euo pipefail
echo "=== merge_sharded_fsdp_weights.py (first 400 lines) ==="
sed -n '1,400p' src/axolotl/cli/merge_sharded_fsdp_weights.py || true
echo
echo "=== grep for relevant keywords across repo (context) ==="
rg -n -C3 -i 'merged|pytorch_model_fsdp|safetensors|.safetensors|config.json|tokenizer|save_pretrained|torch.save|state_dict|merge-sharded-fsdp-weights' || true
echo
echo "=== determine_last_checkpoint in train.py (around def) ==="
sed -n '120,160p' src/axolotl/train.py || true

Length of output: 262936


Docs: clarify FSDP merge instructions & use a safer move

Verified that the CLI and checkpoint-discovery wiring exist:

  • src/axolotl/cli/merge_sharded_fsdp_weights.py is present and imports determine_last_checkpoint from src/axolotl/train.py (def determine_last_checkpoint found in train.py).
  • The CLI writes merged weights to {output_dir}/merged and already logs that merged weights are safetensors only (config/tokenizer are not moved).

Please apply this small README change (examples/gpt-oss/README.md — lines ~50–59):

-When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your
-configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to
-merge the sharded weights.  This step will automatically determine the last checkpoint directory and merge the sharded
-weights to `{output_dir}/merged`.
-
-```bash
-axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
-mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
-```
+When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your
+configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to
+merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded
+weights to `{output_dir}/merged`.
+
+```bash
+axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
+# Copy merged weights into output dir (preserves permissions, handles dotfiles) and cleanup merged files.
+rsync -a ./outputs/gpt-oss-out/merged/ ./outputs/gpt-oss-out/ && find ./outputs/gpt-oss-out/merged -type f -delete
+# Optional: remove now-empty merged dir
+rmdir ./outputs/gpt-oss-out/merged 2>/dev/null || true
+```
+
+Note:
+- Ensure sufficient free disk space for the merge (a conservative guideline is ~1.2–1.5× the final model size).
+- `{output_dir}/merged` contains only the merged weight tensor files and an index (e.g., safetensors shards and index). It does not include `config.json` or tokenizer files — retain those from `output_dir` or the last checkpoint.

Files to review/update:

  • examples/gpt-oss/README.md (lines ~50–59)
  • (informational) src/axolotl/cli/merge_sharded_fsdp_weights.py — confirms behavior and existing log message about safetensors-only output.
🧰 Tools
🪛 LanguageTool

[grammar] ~52-~52: There might be a mistake here.
Context: ...eckpoint directory and merge the sharded weights to {output_dir}/merged. ```ba...

(QB_NEW_EN)

🤖 Prompt for AI Agents
In examples/gpt-oss/README.md around lines 50 to 59, update the example and
guidance for merging FSDP sharded weights: replace the unsafe mv command with a
safer rsync-based sequence that copies merged files into the output dir while
preserving permissions and dotfiles, then removes the merged files and
optionally removes the empty merged directory; also add the note about ensuring
sufficient free disk space (recommend ~1.2–1.5× final model size) and clarify
that `{output_dir}/merged` contains only merged weight tensor files (safetensors
shards and index) and does not include config or tokenizer files which must be
retained from the output dir or last checkpoint.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🔭 Outside diff range comments (1)
src/axolotl/core/builders/rl.py (1)

218-224: Fix: Pass TrainingArguments Positionally for GRPO Trainers to Avoid Duplicate ‘args’

The GRPO trainer’s __init__ signature is:

def __init__(self, model, reward_funcs, args, train_dataset, …)

Passing training_args both in *trainer_cls_args (positional) and as args=training_args causes:

TypeError: Trainer.__init__() got multiple values for argument 'args'

• File needing update:
src/axolotl/core/builders/rl.py (around lines 218–224)

Apply this diff:

-        trainer = trainer_cls(
-            *trainer_cls_args,
-            args=training_args,
-            train_dataset=self.train_dataset,
-            callbacks=self.get_callbacks(),
-            **trainer_kwargs,
-        )
+        if self.cfg.rl is RLType.GRPO:
+            # GRPO trainers expect `args` positionally; passing it as kwarg collides
+            trainer = trainer_cls(
+                *[*trainer_cls_args, training_args],
+                train_dataset=self.train_dataset,
+                callbacks=self.get_callbacks(),
+                **trainer_kwargs,
+            )
+        else:
+            trainer = trainer_cls(
+                *trainer_cls_args,
+                args=training_args,
+                train_dataset=self.train_dataset,
+                callbacks=self.get_callbacks(),
+                **trainer_kwargs,
+            )

This ensures training_args is supplied exactly once when initializing GRPO trainers.

♻️ Duplicate comments (2)
src/axolotl/train.py (2)

140-149: Make checkpoint discovery robust: avoid fragile string-split on full paths

Sorting by int(path.split("-")[-1]) on full path strings is brittle and can raise if the suffix isn’t numeric (e.g., checkpoint-last, or hyphens elsewhere in the path). Prefer Path.name, filter numeric suffixes, and reuse the computed last_checkpoint instead of re-indexing sorted_paths again.

Apply this diff to the body to fix both issues:

-    last_checkpoint = None
-    possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")]
-    if len(possible_checkpoints) > 0:
-        sorted_paths = sorted(
-            possible_checkpoints,
-            key=lambda path: int(path.split("-")[-1]),
-        )
-        if not update:
-            return sorted_paths[-1]
-        last_checkpoint = sorted_paths[-1]
+    last_checkpoint = None
+    checkpoints = sorted(
+        (p for p in Path(cfg.output_dir).glob("checkpoint-*") if p.name.split("-")[-1].isdigit()),
+        key=lambda p: int(p.name.split("-")[-1]),
+    )
+    if checkpoints:
+        last_checkpoint = str(checkpoints[-1])
+        if not update:
+            return last_checkpoint

151-157: Combine nested conditions and reuse computed last_checkpoint

  • SIM102: Combine nested ifs to simplify flow.
  • Avoid referencing sorted_paths again; rely on last_checkpoint already computed.

Apply this diff:

-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
-        if last_checkpoint is not None:
-            cfg.resume_from_checkpoint = sorted_paths[-1]
-            LOG.info(
-                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
-            )
+    if (
+        cfg.resume_from_checkpoint is None
+        and cfg.auto_resume_from_checkpoints
+        and last_checkpoint is not None
+    ):
+        cfg.resume_from_checkpoint = last_checkpoint
+        LOG.info(
+            f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
+        )
🧹 Nitpick comments (3)
src/axolotl/core/builders/rl.py (2)

164-167: Tiny cleanup: simplify blocklist removals

Use pop with a default to avoid the membership check.

-        for blocklist_key in blocklist_args_kwargs:
-            if blocklist_key in training_args_kwargs:
-                del training_args_kwargs[blocklist_key]
+        for blocklist_key in blocklist_args_kwargs:
+            training_args_kwargs.pop(blocklist_key, None)

207-209: Avoid re-invoking the DPO trainer-class factory in the identity check

Calling DPOStrategy.get_trainer_class() again inside the condition is unnecessary. Cache it first; improves readability and avoids surprises if the factory becomes dynamic.

-        if self.cfg.datasets is not None and (
-            trainer_cls is DPOStrategy.get_trainer_class()
-        ):
+        dpo_trainer_cls = DPOStrategy.get_trainer_class()
+        if self.cfg.datasets is not None and trainer_cls is dpo_trainer_cls:
src/axolotl/train.py (1)

7-7: Imports look fine; minor cleanup opportunities

  • Using json + OrderedDict is fine to preserve key order, though in Python 3.7+ dicts preserve insertion order by default. Keeping OrderedDict is acceptable if you want to be explicit.
  • Consider using pathlib.Path consistently for file IO later (see comment on Lines 298-312).

Also applies to: 13-13

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 7707b09 and bf56beb.

📒 Files selected for processing (3)
  • src/axolotl/core/builders/rl.py (2 hunks)
  • src/axolotl/core/trainers/__init__.py (0 hunks)
  • src/axolotl/train.py (5 hunks)
💤 Files with no reviewable changes (1)
  • src/axolotl/core/trainers/init.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
src/axolotl/core/builders/rl.py (1)
src/axolotl/core/trainers/grpo/__init__.py (1)
  • GRPOStrategy (25-204)
src/axolotl/train.py (2)
src/axolotl/core/builders/rl.py (1)
  • HFRLTrainerBuilder (24-234)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
🪛 GitHub Actions: Tests
src/axolotl/core/builders/rl.py

[error] 218-218: GRPO training build failed: TypeError: Trainer.init() got multiple values for argument 'args' during AxolotlGRPOTrainer initialization (likely due to a deprecation wrapper in transformers mapping the deprecated 'args' arg incorrectly).

🪛 Ruff (0.12.2)
src/axolotl/train.py

151-152: Use a single if statement instead of nested if statements

(SIM102)

🔇 Additional comments (4)
src/axolotl/core/builders/rl.py (2)

55-63: Lazy import of GRPOStrategy is a good move

Deferring the GRPO import to the branch reduces import-time overhead and helps avoid circular imports. Usage of GRPOStrategy immediately after the import is correct.


151-156: Consistent lazy import of GRPOStrategy in training-args path

Mirroring the lazy import here keeps the module load light and avoids unnecessary GRPO dependencies when not used. Looks good.

src/axolotl/train.py (2)

51-51: Good: updated TYPE_CHECKING import path to core.builders

This aligns with the module reorg and keeps runtime footprint minimal.


589-589: Usage of determine_last_checkpoint verified
All call sites correctly handle the update flag:

  • In src/axolotl/train.py (line 589), the default update=True is used to mutate cfg.resume_from_checkpoint as intended for auto-resume.
  • In src/axolotl/cli/merge_sharded_fsdp_weights.py (line 200), update=False is explicitly passed to avoid mutation when only a lookup is needed.

No other callers found—no changes required.

Comment thread src/axolotl/train.py Outdated
Comment on lines +298 to +312
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
# cleanup the FSDP prefix in the model config.json
with open(
os.path.join(cfg.output_dir, "config.json"), "r", encoding="utf-8"
) as f:
# read the model config as an OrderedDict
config = json.load(f, object_pairs_hook=OrderedDict)
config["architectures"] = [
name.lstrip("FSDP") for name in config["architectures"]
]
# write the updated model config back
with open(
os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8"
) as f:
json.dump(config, f, indent=2)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Bug: lstrip("FSDP") can corrupt names; use removeprefix/startswith and guard missing key; also prefer Path for IO

  • str.lstrip("FSDP") removes any combination of F/S/D/P from the start, not the literal prefix. Example: "SPegasus" becomes "egasus". This risks silently mangling architecture names.
  • Guard when "architectures" is absent or not a list to avoid KeyError.
  • Minor: use Path for cleaner IO.

Proposed fix:

-        # TODO(wing):see https://github.com/huggingface/transformers/pull/40207
-        # cleanup the FSDP prefix in the model config.json
-        with open(
-            os.path.join(cfg.output_dir, "config.json"), "r", encoding="utf-8"
-        ) as f:
-            # read the model config as an OrderedDict
-            config = json.load(f, object_pairs_hook=OrderedDict)
-            config["architectures"] = [
-                name.lstrip("FSDP") for name in config["architectures"]
-            ]
-        # write the updated model config back
-        with open(
-            os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8"
-        ) as f:
-            json.dump(config, f, indent=2)
+        # TODO(wing): see https://github.com/huggingface/transformers/pull/40207
+        # cleanup the FSDP prefix in the model config.json
+        cfg_path = Path(cfg.output_dir) / "config.json"
+        with cfg_path.open("r", encoding="utf-8") as f:
+            # read the model config as an OrderedDict
+            config = json.load(f, object_pairs_hook=OrderedDict)
+            archs = config.get("architectures")
+            if isinstance(archs, list):
+                # remove literal 'FSDP' prefix only
+                config["architectures"] = [
+                    a[4:] if isinstance(a, str) and a.startswith("FSDP") else a
+                    for a in archs
+                ]
+        # write the updated model config back
+        with cfg_path.open("w", encoding="utf-8") as f:
+            json.dump(config, f, indent=2)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
# cleanup the FSDP prefix in the model config.json
with open(
os.path.join(cfg.output_dir, "config.json"), "r", encoding="utf-8"
) as f:
# read the model config as an OrderedDict
config = json.load(f, object_pairs_hook=OrderedDict)
config["architectures"] = [
name.lstrip("FSDP") for name in config["architectures"]
]
# write the updated model config back
with open(
os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8"
) as f:
json.dump(config, f, indent=2)
# TODO(wing): see https://github.com/huggingface/transformers/pull/40207
# cleanup the FSDP prefix in the model config.json
cfg_path = Path(cfg.output_dir) / "config.json"
with cfg_path.open("r", encoding="utf-8") as f:
# read the model config as an OrderedDict
config = json.load(f, object_pairs_hook=OrderedDict)
archs = config.get("architectures")
if isinstance(archs, list):
# remove literal 'FSDP' prefix only
config["architectures"] = [
a[4:] if isinstance(a, str) and a.startswith("FSDP") else a
for a in archs
]
# write the updated model config back
with cfg_path.open("w", encoding="utf-8") as f:
json.dump(config, f, indent=2)
🤖 Prompt for AI Agents
In src/axolotl/train.py around lines 298 to 312, the code uses
name.lstrip("FSDP") which can corrupt architecture names and doesn't guard if
"architectures" is missing or not a list; replace that logic to: read config
using pathlib.Path for IO, check if "architectures" exists and is a list before
modifying, and for each name use a safe prefix-removal (e.g., if
isinstance(name, str) and name.startswith("FSDP"): name =
name.removeprefix("FSDP")) so only the literal "FSDP" prefix is removed; then
write the JSON back via Path.write_text/json.dump as before.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🔭 Outside diff range comments (1)
src/axolotl/train.py (1)

286-295: Fix potential UnboundLocalError for state_dict_type when FSDP is enabled but cfg.fsdp/fsdp_config are unset

state_dict_type is only assigned inside the inner if (cfg.fsdp_config or cfg.fsdp). If trainer.is_fsdp_enabled is True while both config flags are falsy, the later check if state_dict_type == "SHARDED_STATE_DICT" will reference an uninitialized local.

Initialize state_dict_type at the start of the outer block:

-    if trainer.is_fsdp_enabled or cfg.fsdp_config:
-        if cfg.fsdp_config or cfg.fsdp:
+    if trainer.is_fsdp_enabled or cfg.fsdp_config:
+        state_dict_type: str | None = None
+        if cfg.fsdp_config or cfg.fsdp:
             if cfg.fsdp_config.final_state_dict_type:
                 state_dict_type = cfg.fsdp_config.final_state_dict_type
             else:
                 state_dict_type = cfg.fsdp_config.state_dict_type
             trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
         trainer.save_model(cfg.output_dir)  # only handles FULL_STATE_DICT
-        if state_dict_type == "SHARDED_STATE_DICT":
+        if state_dict_type == "SHARDED_STATE_DICT":
♻️ Duplicate comments (2)
src/axolotl/train.py (2)

141-158: Make last-checkpoint discovery robust; avoid fragile string-split and collapse nested if

Current approach splits the full path on '-' and casts to int. This can break if non-numeric suffixes appear or if parent directories contain dashes. Also, the nested if triggers SIM102. Filter by numeric suffix using Path.name and reuse last_checkpoint directly.

Proposed patch:

-    last_checkpoint = None
-    possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")]
-    if len(possible_checkpoints) > 0:
-        sorted_paths = sorted(
-            possible_checkpoints,
-            key=lambda path: int(path.split("-")[-1]),
-        )
-        if not update:
-            return sorted_paths[-1]
-        last_checkpoint = sorted_paths[-1]
+    last_checkpoint = None
+    checkpoints = sorted(
+        (p for p in Path(cfg.output_dir).glob("checkpoint-*") if p.name.split("-")[-1].isdigit()),
+        key=lambda p: int(p.name.split("-")[-1]),
+    )
+    if checkpoints:
+        last_checkpoint = str(checkpoints[-1])
+        if not update:
+            return last_checkpoint
@@
-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
-        if last_checkpoint is not None:
-            cfg.resume_from_checkpoint = sorted_paths[-1]
-            LOG.info(
-                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
-            )
+    if (
+        cfg.resume_from_checkpoint is None
+        and cfg.auto_resume_from_checkpoints
+        and last_checkpoint is not None
+    ):
+        cfg.resume_from_checkpoint = last_checkpoint
+        LOG.info(
+            f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
+        )

321-334: Bug: lstrip("FSDP") can corrupt architecture names; guard key and use literal prefix removal

str.lstrip("FSDP") removes any combination of F/S/D/P at the start, not the literal "FSDP" prefix. Also guard when "architectures" is absent or not a list, and use Path I/O consistently.

Apply:

-        if trainer.accelerator.is_main_process:
-            with open(
-                Path(cfg.output_dir) / "config.json", "r", encoding="utf-8"
-            ) as config_file_io:
-                # read the model config as an OrderedDict
-                config = json.load(config_file_io, object_pairs_hook=OrderedDict)
-                config["architectures"] = [
-                    name.lstrip("FSDP") for name in config["architectures"]
-                ]
-            # write the updated model config back
-            with open(
-                os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8"
-            ) as config_file_io:
-                json.dump(config, config_file_io, indent=2)
+        if trainer.accelerator.is_main_process:
+            cfg_path = Path(cfg.output_dir) / "config.json"
+            with cfg_path.open("r", encoding="utf-8") as f:
+                # read the model config as an OrderedDict
+                config = json.load(f, object_pairs_hook=OrderedDict)
+                archs = config.get("architectures")
+                if isinstance(archs, list):
+                    # remove literal 'FSDP' prefix only
+                    config["architectures"] = [
+                        a[4:] if isinstance(a, str) and a.startswith("FSDP") else a
+                        for a in archs
+                    ]
+            # write the updated model config back
+            with cfg_path.open("w", encoding="utf-8") as f:
+                json.dump(config, f, indent=2)
🧹 Nitpick comments (1)
src/axolotl/train.py (1)

299-318: Auto-merge of SHARDED_STATE_DICT FSDP weights looks good

  • Gated to main process and only when an index file isn’t present.
  • Uses the existing CLI merge routine, then moves artifacts into output_dir and cleans up the temp dir.

Optional: consider exposing/remove_checkpoint_dir=True to reclaim space in large runs once the merged artifact is produced (or make it configurable).

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 8e6ff02 and 33dc7ae.

📒 Files selected for processing (1)
  • src/axolotl/train.py (5 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/train.py (5)
src/axolotl/core/builders/causal.py (1)
  • HFCausalTrainerBuilder (52-504)
src/axolotl/core/builders/rl.py (1)
  • HFRLTrainerBuilder (24-234)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/utils/distributed.py (1)
  • is_main_process (83-98)
src/axolotl/cli/merge_sharded_fsdp_weights.py (1)
  • merge_fsdp_weights (120-184)
🪛 Ruff (0.12.2)
src/axolotl/train.py

152-153: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: preview
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit
🔇 Additional comments (3)
src/axolotl/train.py (3)

7-7: New imports look appropriate

json/shutil/OrderedDict are used below; no issues.

Also applies to: 9-9, 14-14


52-52: TYPE_CHECKING import path update looks good

Matches the new module layout under axolotl.core.builders.


611-611: All call sites successfully migrated; update=False used in discovery contexts

  • No remaining references to determine_resume_checkpoint found.
  • Verified determine_last_checkpoint(cfg, update=False) is used in non-mutating contexts at:
    • src/axolotl/train.py:299
    • src/axolotl/cli/merge_sharded_fsdp_weights.py:200
  • Default determine_last_checkpoint(cfg) (update=True) on line 611 appropriately updates the config for actual training resume.

@winglian
Copy link
Copy Markdown
Collaborator Author

winglian commented Aug 15, 2025

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🔭 Outside diff range comments (1)
src/axolotl/train.py (1)

256-266: Potential UnboundLocalError: state_dict_type may be undefined

If trainer.is_fsdp_enabled is True but neither cfg.fsdp_config nor cfg.fsdp are set, state_dict_type is not defined before being checked, which can raise at runtime.

Apply:

-    if trainer.is_fsdp_enabled or cfg.fsdp_config:
-        if cfg.fsdp_config or cfg.fsdp:
-            if cfg.fsdp_config.final_state_dict_type:
-                state_dict_type = cfg.fsdp_config.final_state_dict_type
-            else:
-                state_dict_type = cfg.fsdp_config.state_dict_type
-            trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
-        trainer.save_model(cfg.output_dir)  # only handles FULL_STATE_DICT
+    if trainer.is_fsdp_enabled or cfg.fsdp_config:
+        state_dict_type: str | None = None
+        if cfg.fsdp_config or cfg.fsdp:
+            if cfg.fsdp_config.final_state_dict_type:
+                state_dict_type = cfg.fsdp_config.final_state_dict_type
+            else:
+                state_dict_type = cfg.fsdp_config.state_dict_type
+            trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
+        trainer.save_model(cfg.output_dir)  # only handles FULL_STATE_DICT
♻️ Duplicate comments (1)
src/axolotl/train.py (1)

293-305: Use safe prefix removal and guard config fields when cleaning architectures; also prefer Path IO

str.lstrip("FSDP") removes any combination of F/S/D/P and may corrupt names (e.g., "SPegasus" -> "egasus"). Guard missing/invalid architectures and remove the literal prefix only.

-        if trainer.accelerator.is_main_process:
-            with open(
-                Path(cfg.output_dir) / "config.json", "r", encoding="utf-8"
-            ) as config_file_io:
-                # read the model config as an OrderedDict
-                config = json.load(config_file_io, object_pairs_hook=OrderedDict)
-                config["architectures"] = [
-                    name.lstrip("FSDP") for name in config["architectures"]
-                ]
-            # write the updated model config back
-            with open(
-                os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8"
-            ) as config_file_io:
-                json.dump(config, config_file_io, indent=2)
+        if trainer.accelerator.is_main_process:
+            cfg_path = Path(cfg.output_dir) / "config.json"
+            if cfg_path.exists():
+                with cfg_path.open("r", encoding="utf-8") as f:
+                    config = json.load(f, object_pairs_hook=OrderedDict)
+                archs = config.get("architectures")
+                if isinstance(archs, list):
+                    config["architectures"] = [
+                        a[4:] if isinstance(a, str) and a.startswith("FSDP") else a
+                        for a in archs
+                    ]
+                with cfg_path.open("w", encoding="utf-8") as f:
+                    json.dump(config, f, indent=2)
🧹 Nitpick comments (8)
src/axolotl/cli/merge_sharded_fsdp_weights.py (4)

61-61: Ensure merged output directory creation handles missing parents

Use parents=True to avoid failures if an intermediate directory is missing.

-    save_path_.mkdir(exist_ok=True)
+    save_path_.mkdir(parents=True, exist_ok=True)

149-151: Fix typo in error message

Stray backtick at the end of the version requirement message.

-        raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
+        raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0")

152-171: Validate both FSDP subdir paths and parent checkpoint directories

merge_fsdp_weights is called with checkpoint_dir pointing to .../pytorch_model_fsdp_0 (from do_cli), but the current error helper assumes it might be the parent checkpoint dir. As written, checking checkpoint_dir_/pytorch_model_fsdp_0 when checkpoint_dir_ already ends with pytorch_model_fsdp_0 yields a misleading suggestion. This refactor cleanly supports both forms.

-    # Verify that the checkpoint directory exists
-    if not checkpoint_dir_.exists():
-        model_path_exists = (checkpoint_dir_ / "pytorch_model_fsdp_0").exists()
-        optimizer_path_exists = (checkpoint_dir_ / "optimizer_0").exists()
-        err = f"Tried to load from {checkpoint_dir_} but couldn't find a valid metadata file."
-        if model_path_exists and optimizer_path_exists:
-            err += (
-                " However, potential model and optimizer checkpoint directories exist."
-            )
-            err += f"Please pass in either {checkpoint_dir_}/pytorch_model_fsdp_0 or {checkpoint_dir_}/optimizer_0"
-            err += "instead."
-        elif model_path_exists:
-            err += " However, a potential model checkpoint directory exists."
-            err += (
-                f"Please try passing in {checkpoint_dir_}/pytorch_model_fsdp_0 instead."
-            )
-        elif optimizer_path_exists:
-            err += " However, a potential optimizer checkpoint directory exists."
-            err += f"Please try passing in {checkpoint_dir_}/optimizer_0 instead."
-        raise ValueError(err)
+    # Verify that the checkpoint directory exists and is in a supported form.
+    # We accept either:
+    # - a direct FSDP subdir: .../pytorch_model_fsdp_0 or .../optimizer_0
+    # - a parent checkpoint dir containing those subdirs
+    if not checkpoint_dir_.exists():
+        raise ValueError(f"Tried to load from {checkpoint_dir_} but it does not exist.")
+    if checkpoint_dir_.name not in {"pytorch_model_fsdp_0", "optimizer_0"}:
+        model_path_exists = (checkpoint_dir_ / "pytorch_model_fsdp_0").exists()
+        optimizer_path_exists = (checkpoint_dir_ / "optimizer_0").exists()
+        if not (model_path_exists or optimizer_path_exists):
+            raise ValueError(
+                "Could not locate FSDP subdirectories under "
+                f"{checkpoint_dir_}. Expected one of: "
+                f"{checkpoint_dir_ / 'pytorch_model_fsdp_0'} or "
+                f"{checkpoint_dir_ / 'optimizer_0'}."
+            )

198-205: Edge case: clearer error when no checkpoints are found

If determine_last_checkpoint returns None, the error currently references checkpoint_dir=None. Small copy tweak improves the message.

-        if not fsdp_dir.exists():
-            raise ValueError(
-                f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}"
-            )
+        if not fsdp_dir.exists():
+            where = checkpoint_dir or parsed_cfg.output_dir
+            raise ValueError(
+                f"Could not find FSDP checkpoint directory `pytorch_model_fsdp_0` under {where}"
+            )
examples/gpt-oss/README.md (4)

36-38: Tighten wording on disk-space guidance

Minor grammar cleanup improves readability.

-On 8xH100s, make sure you have ~3TB of free disk space. With each checkpoint clocking in at ~720GB, along with the base
-model, and final model output, you may need at least 3TB of free disk space to keep at least 2 checkpoints.
+On 8xH100s, plan for ~3TB of free disk space. With each checkpoint ~720GB, plus the base model and the final merged output, you may need at least 3TB to retain two checkpoints.

44-49: Docs may be outdated: config.json FSDP prefix cleanup is now automatic

The training code now removes the FSDP prefix from architectures in config.json on the main process. This manual sed step should no longer be necessary for current versions.

-ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
-See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
-
-```bash
-sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
-```
+Note: Transformers may save the model architecture prefixed with `FSDP`, but Axolotl now cleans this automatically after training. No manual change to `config.json` is required in recent versions. See [transformers PR #40207](https://github.com/huggingface/transformers/pull/40207) for upstream status.

51-59: Use a safer copy for merged weights and fix extra spacing

mv can miss dotfiles and doesn't preserve permissions. rsync is safer, and we can clean up after.

-When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your
-configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to
-merge the sharded weights.  This step will automatically determine the last checkpoint directory and merge the sharded
-weights to `{output_dir}/merged`.
+When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your
+configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to
+merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded
+weights to `{output_dir}/merged`.
 
 ```bash
 axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
-mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
+rsync -a ./outputs/gpt-oss-out/merged/ ./outputs/gpt-oss-out/ && find ./outputs/gpt-oss-out/merged -type f -delete
+# Optional: remove now-empty dir
+rmdir ./outputs/gpt-oss-out/merged 2>/dev/null || true

---

`64-68`: **Minor grammar and spelling fixes; avoid bare URLs**

Fix typo and use proper link formatting.



```diff
-GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
-for more information about using a special vllm-openai docker image for inferencing with vLLM.
+GPT-OSS support in vLLM does not exist in a stable release yet. See this [post](https://x.com/MaziyarPanahi/status/1955741905515323425)
+for details about using a special vllm-openai Docker image for inference with vLLM.
@@
-SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing
-SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:
+SGLang has 0-day support in main; see [sglang/issues/8833](https://github.com/sgl-project/sglang/issues/8833) for information on installing
+SGLang from source. Once installed, run the following command to launch an SGLang server:
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 33dc7ae and 84fba12.

📒 Files selected for processing (5)
  • examples/gpt-oss/README.md (1 hunks)
  • examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml (1 hunks)
  • src/axolotl/cli/merge_sharded_fsdp_weights.py (3 hunks)
  • src/axolotl/train.py (5 hunks)
  • src/axolotl/utils/train.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (3)
src/axolotl/utils/train.py (2)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
src/axolotl/train.py (3)
src/axolotl/cli/main.py (3)
  • train (99-137)
  • cli (45-49)
  • merge_sharded_fsdp_weights (246-275)
src/axolotl/utils/train.py (1)
  • determine_last_checkpoint (11-39)
src/axolotl/cli/merge_sharded_fsdp_weights.py (1)
  • merge_fsdp_weights (121-183)
src/axolotl/cli/merge_sharded_fsdp_weights.py (2)
src/axolotl/train.py (1)
  • train (547-595)
src/axolotl/utils/train.py (1)
  • determine_last_checkpoint (11-39)
🪛 Ruff (0.12.2)
src/axolotl/utils/train.py

33-34: Use a single if statement instead of nested if statements

(SIM102)

🪛 LanguageTool
examples/gpt-oss/README.md

[grammar] ~36-~36: There might be a mistake here.
Context: ...ocking in at ~720GB, along with the base model, and final model output, you may n...

(QB_NEW_EN)


[grammar] ~44-~44: There might be a mistake here.
Context: ...to be manually renamed in config.json. See https://github.com/huggingface/trans...

(QB_NEW_EN)


[grammar] ~53-~53: There might be a mistake here.
Context: ...eckpoint directory and merge the sharded weights to {output_dir}/merged. ```ba...

(QB_NEW_EN)


[grammar] ~64-~64: There might be a mistake here.
Context: ...MaziyarPanahi/status/1955741905515323425 for more information about using a speci...

(QB_NEW_EN)


[grammar] ~67-~67: Ensure spelling is correct
Context: ....com/sgl-project/sglang/issues/8833 for infomation on installing SGLang from source. Once ...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)


[grammar] ~67-~67: There might be a mistake here.
Context: ...issues/8833 for infomation on installing SGLang from source. Once you've installe...

(QB_NEW_EN)

🪛 markdownlint-cli2 (0.17.2)
examples/gpt-oss/README.md

45-45: Bare URL used

(MD034, no-bare-urls)


64-64: Bare URL used

(MD034, no-bare-urls)


67-67: Bare URL used

(MD034, no-bare-urls)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: preview
🔇 Additional comments (3)
examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml (1)

23-23: Good addition to manage disk usage during massive checkpoints

Adding save_total_limit: 2 aligns with the documented disk-space constraints for 120B and will help prevent storage exhaustion.

src/axolotl/cli/merge_sharded_fsdp_weights.py (1)

213-223: Nice synchronized logging after merge

Barrier then a concise, main-process-only confirmation plus a clear note about safetensors-only output is solid UX.

src/axolotl/train.py (1)

269-290: Solid automatic FSDP SHARDED_STATE_DICT merge with safe-serialization and cleanup

Automatically merging shards into {output_dir}/merged, synchronizing, then moving into output_dir and deleting the temp dir is a great DX improvement.

Comment thread src/axolotl/utils/train.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (5)
tests/utils/test_train.py (5)

13-17: Also test ignoring non-numeric checkpoint directories

Add a non-numeric checkpoint directory (e.g., checkpoint-best) to verify the implementation ignores it. This guards against regressions in the filter logic.

Apply this diff:

     for cpt_idx in [1, 9, 10, 20]:
         os.makedirs(
             os.path.join(cfg.output_dir, f"checkpoint-{cpt_idx}"), exist_ok=True
         )
+
+    # Non-numeric checkpoint name should be ignored by determine_last_checkpoint
+    os.makedirs(os.path.join(cfg.output_dir, "checkpoint-best"), exist_ok=True)

18-20: Assert the function does not mutate cfg when update=False

Add an explicit assertion that resume_from_checkpoint remains None after the first call, documenting the non-mutating contract.

Apply this diff:

     last_checkpoint = determine_last_checkpoint(cfg, update=False)
     assert last_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20")
+
+    # Ensure cfg was not mutated when update=False
+    assert cfg.resume_from_checkpoint is None

21-24: Capture and assert the return value when update=True

The function returns the resolved checkpoint when update=True in this scenario. Capturing and asserting it makes the test verify both side-effect and return value.

Apply this diff:

-    determine_last_checkpoint(cfg, update=True)
-    assert cfg.resume_from_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20")
+    returned = determine_last_checkpoint(cfg, update=True)
+    assert returned == os.path.join(cfg.output_dir, "checkpoint-20")
+    assert cfg.resume_from_checkpoint == returned

5-6: Minor robustness improvement in determine_last_checkpoint (filter only directories)

Currently, Path.glob("checkpoint-*") may include files. Consider filtering to directories in src/axolotl/utils/train.py to avoid false positives.

You can update the generator expression like this:

# in src/axolotl/utils/train.py
checkpoints = sorted(
    (
        p
        for p in Path(cfg.output_dir).glob("checkpoint-*")
        if p.is_dir() and p.name.split("-")[-1].isdigit()
    ),
    key=lambda p: int(p.name.split("-")[-1]),
)

9-24: Add a couple more tests for edge cases

Consider adding tests for:

  • No checkpoints present: returns None and does not update cfg.
  • resume_from_checkpoint already set: not overridden even if auto_resume_from_checkpoints is True.

Example tests you can add to this module:

def test_determine_last_checkpoint_no_checkpoints(temp_dir):
    cfg = DictDefault(output_dir=temp_dir)
    cfg.resume_from_checkpoint = None
    cfg.auto_resume_from_checkpoints = True

    assert determine_last_checkpoint(cfg, update=False) is None
    determine_last_checkpoint(cfg, update=True)
    assert cfg.resume_from_checkpoint is None

def test_determine_last_checkpoint_preserves_existing_resume(temp_dir):
    cfg = DictDefault(output_dir=temp_dir)
    # Create a couple of checkpoints
    os.makedirs(os.path.join(cfg.output_dir, "checkpoint-1"), exist_ok=True)
    os.makedirs(os.path.join(cfg.output_dir, "checkpoint-2"), exist_ok=True)

    cfg.resume_from_checkpoint = os.path.join(cfg.output_dir, "checkpoint-1")
    cfg.auto_resume_from_checkpoints = True

    result = determine_last_checkpoint(cfg, update=True)
    # Should not override the pre-set resume checkpoint
    assert cfg.resume_from_checkpoint == os.path.join(cfg.output_dir, "checkpoint-1")
    # Return value follows function contract (returns cfg.resume_from_checkpoint)
    assert result == cfg.resume_from_checkpoint

If helpful, I can open a follow-up PR with these additions.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 84fba12 and 352a30d.

📒 Files selected for processing (2)
  • src/axolotl/utils/train.py (1 hunks)
  • tests/utils/test_train.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/utils/train.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/utils/test_train.py (2)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/utils/train.py (1)
  • determine_last_checkpoint (11-45)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: pre-commit
  • GitHub Check: preview
  • GitHub Check: pre-commit
🔇 Additional comments (2)
tests/utils/test_train.py (2)

13-20: Good coverage of numeric checkpoint sorting

Creating checkpoint-1, -9, -10, and -20 validates numeric vs. lexicographic ordering. The assertion against checkpoint-20 is correct.


9-12: temp_dir fixture verified – no changes needed

The global temp_dir fixture is defined in tests/conftest.py, so using it in test_determine_last_checkpoint is valid.

  • tests/conftest.py:421–423 → definition of temp_dir fixture

@winglian winglian merged commit ecbe8b2 into main Aug 16, 2025
17 of 18 checks passed
@winglian winglian deleted the merge-fsdp branch August 16, 2025 01:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants