feat: Implement safetensors checkpointing format support using nemo-automodel #1023

chtruong814 merged 14 commits into main
Conversation
Actionable comments posted: 0
🧹 Nitpick comments (7)
nemo_rl/utils/automodel_checkpoint.py (7)
41-57: Make checkpoint root inference path-robust

Use basename checks (and normalize) instead of a string-suffix check to avoid false matches and to handle trailing slashes.

```diff
-def _infer_checkpoint_root(weights_path: str) -> str:
-    weights_dir = os.path.dirname(weights_path)
-    if weights_dir.endswith("weights"):
-        return os.path.dirname(weights_dir)
-    return weights_dir
+def _infer_checkpoint_root(weights_path: str) -> str:
+    weights_dir = os.path.dirname(os.path.normpath(weights_path))
+    if os.path.basename(weights_dir) == "weights":
+        return os.path.dirname(weights_dir)
+    return weights_dir
```
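To see why the basename check is safer than `endswith("weights")`, here is a runnable sketch; the helper mirrors `_infer_checkpoint_root` and the paths are purely illustrative:

```python
import os

def infer_checkpoint_root(weights_path: str) -> str:
    # normpath strips trailing slashes so the basename check stays reliable
    weights_dir = os.path.dirname(os.path.normpath(weights_path))
    if os.path.basename(weights_dir) == "weights":
        return os.path.dirname(weights_dir)
    return weights_dir

# endswith("weights") would also match ".../shared_weights/model";
# the basename check does not.
print(infer_checkpoint_root("/ckpt/policy/weights/model"))         # /ckpt/policy
print(infer_checkpoint_root("/ckpt/policy/weights/model/"))        # /ckpt/policy
print(infer_checkpoint_root("/ckpt/policy/shared_weights/model"))  # /ckpt/policy/shared_weights
```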
59-92: Tighten walk loop and silence Ruff B007

Rename the unused loop variables, and consider an early exit once both signals are found to avoid needless traversal in large directories.

```diff
-    for root, dirs, files in os.walk(weights_path):
-        all_files.extend(files)
+    for _root, _dirs, files in os.walk(weights_path):
+        all_files.extend(files)
+        # micro-optimization: break once both detections are certain
+        if any(f.endswith(".safetensors") for f in all_files) and any(
+            "adapter" in f.lower() for f in all_files
+        ):
+            break
```
85-87: Broaden PEFT detection heuristics

Adapters may also be signaled by LoRA/PEFT config files; widen the check slightly to reduce false negatives.

```diff
-    if not is_peft:
-        is_peft = any("adapter" in f.lower() for f in all_files)
+    if not is_peft:
+        lower = [f.lower() for f in all_files]
+        is_peft = any(
+            s in fname
+            for fname in lower
+            for s in ("adapter", "lora", "peft_config.json", "adapter_config.json")
+        )
```
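The walk-plus-heuristic approach can be exercised end to end with a throwaway directory tree. In this sketch, `detect_format` is a stand-in for the module's `detect_checkpoint_format` and the file names are made up:

```python
import os
import tempfile

def detect_format(weights_path):
    # Collect every file name under the checkpoint directory.
    all_files = []
    for _root, _dirs, files in os.walk(weights_path):
        all_files.extend(files)
    is_safetensors = any(f.endswith(".safetensors") for f in all_files)
    # Broadened PEFT heuristic: adapter/LoRA/PEFT markers anywhere in a name.
    lower = [f.lower() for f in all_files]
    is_peft = any(
        marker in fname
        for fname in lower
        for marker in ("adapter", "lora", "peft_config.json")
    )
    return is_safetensors, is_peft

with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, "model.safetensors"), "w").close()
    open(os.path.join(d, "adapter_config.json"), "w").close()
    print(detect_format(d))  # (True, True)
```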
131-137: Map formats explicitly instead of relying on Enum member names

Avoid coupling to Enum naming by using an explicit map; this reduces risk if upstream renames members.

```diff
 valid_formats = {"safetensors", "torch_save"}
 if model_save_format not in valid_formats:
     raise ValueError(
         f"Unsupported model_save_format='{model_save_format}'. "
         f"Expected one of {sorted(valid_formats)}."
     )
```

And in load (line 205):

```diff
-    format_enum = SerializationFormat[model_save_format.upper()]
+    fmt_map = {
+        "safetensors": SerializationFormat.SAFETENSORS,
+        "torch_save": SerializationFormat.TORCH_SAVE,
+    }
+    format_enum = fmt_map[model_save_format]
```

Please confirm these SerializationFormat members exist in the pinned nemo-automodel version.
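A self-contained sketch of the explicit-map pattern, using a stand-in Enum (the real `SerializationFormat` comes from nemo-automodel, and its member names here are an assumption to be checked against the pinned version):

```python
from enum import Enum, auto

class SerializationFormat(Enum):  # stand-in for the nemo-automodel enum
    SAFETENSORS = auto()
    TORCH_SAVE = auto()

# Explicit config-string -> enum map; survives upstream member renames
# with a single local edit instead of breaking name-based lookup.
_FORMAT_MAP = {
    "safetensors": SerializationFormat.SAFETENSORS,
    "torch_save": SerializationFormat.TORCH_SAVE,
}

def resolve_format(name: str) -> SerializationFormat:
    try:
        return _FORMAT_MAP[name]
    except KeyError:
        raise ValueError(
            f"Unsupported model_save_format={name!r}; "
            f"expected one of {sorted(_FORMAT_MAP)}"
        ) from None

print(resolve_format("safetensors"))  # SerializationFormat.SAFETENSORS
```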
179-181: Prefer logging over print

Use a module-level logger (e.g. `logger = logging.getLogger(__name__)`) for consistency with the rest of the repo and controllable verbosity.

```diff
-    print(f"Saving tokenizer (or processor) to {tokenizer_path}")
+    logger.info("Saving tokenizer (or processor) to %s", tokenizer_path)
```
200-201: Replace prints with structured logging

Swap the prints for `logger.info` calls on a module-level logger so checkpoint messages integrate with training logs.

```diff
-    print(f"Loading weights from {weights_path}")
+    logger.info("Loading weights from %s", weights_path)
@@
-    print(f"Loading optimizer from {optimizer_path}")
+    logger.info("Loading optimizer from %s", optimizer_path)
```

Also applies to: 234-234
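A minimal sketch of the module-level logger pattern the print-to-logging swap relies on; the logger name and message wording here are illustrative, not repo code:

```python
import logging

# One logger per module; in real code this would be logging.getLogger(__name__).
logger = logging.getLogger("automodel_checkpoint_demo")

def load_weights(weights_path: str) -> None:
    # %-style lazy formatting defers interpolation until a handler emits the record
    logger.info("Loading weights from %s", weights_path)

logging.basicConfig(level=logging.INFO)
load_weights("/ckpt/policy/weights/model")
```

Unlike `print`, the message now carries a level and a logger name, so the training harness can filter or redirect it.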
209-211: Cross-platform and robust "/model" suffix check

Avoid string path checks; use basename to handle Windows separators and trailing slashes.

```diff
-    if not weights_path.endswith("/model"):
-        weights_path = os.path.join(weights_path, "model")
+    if os.path.basename(os.path.normpath(weights_path)) != "model":
+        weights_path = os.path.join(weights_path, "model")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
nemo_rl/utils/automodel_checkpoint.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
nemo_rl/utils/automodel_checkpoint.py (3)
nemo_rl/utils/checkpoint.py (1)
- CheckpointingConfig (35-64)

nemo_rl/models/policy/dtensor_policy_worker_v2.py (2)
- save_checkpoint (1442-1481)
- load_checkpoint (1483-1495)

nemo_rl/models/policy/lm_policy.py (1)
- save_checkpoint (594-627)
🪛 Ruff (0.12.2)
nemo_rl/utils/automodel_checkpoint.py
75-75: Loop control variable root not used within loop body
Rename unused root to _root
(B007)
75-75: Loop control variable dirs not used within loop body
Rename unused dirs to _dirs
(B007)
133-136: Avoid specifying long messages outside the exception class
(TRY003)
168-170: Avoid specifying long messages outside the exception class
(TRY003)
231-233: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Lint check
🔇 Additional comments (3)
nemo_rl/utils/automodel_checkpoint.py (3)
138-144: LGTM: proactive dir creation. Creating the weights/optimizer/tokenizer directories upfront prevents late failures; good defensive practice.

157-176: LGTM: tokenizer handling split is correct. Passing the tokenizer to save_model when tokenizer_path is None, and saving it separately otherwise, matches the documented layouts.

125-155: Good: config construction and key capture. Validation, the key-capture fallback, and root inference look sane. Ensure DTensor v2 passes model_state_dict_keys to avoid relying on the fallback under parallelism.
Signed-off-by: Felipe Vieira Frujeri <ffrujeri@nvidia.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com> Signed-off-by: Felipe Vieira Frujeri <ffrujeri@gmail.com>
❌ Submodule Fast-Forward Check Failed

Check based on commit: 5470a29 (PR #1023)

Submodules that need attention:
- Automodel: ❌ Commits have DIVERGED from a common ancestor

Please ensure all submodule commits are fast-forwards of the main branch before merging.
Actionable comments posted: 0
♻️ Duplicate comments (5)
tests/unit/L0_Unit_Tests_Other.sh (2)
23-29: Guard pytest --collect-only under set -e, handle non-zero codes, and fix ignore paths (mcore)

With `set -e`, pytest exit code 5 (no tests collected) can abort the script before the branch runs; the ignore paths should also be rooted under tests/unit/ for this direct pytest call. Handle rc==0 vs rc==5 explicitly and fail on any other rc. Add `--no-sync` to keep CI deterministic.

```diff
-# Check and run mcore tests
-exit_code=$(pytest tests/unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --collect-only --hf-gated --mcore-only -q >/dev/null 2>&1; echo $?)
-if [[ $exit_code -eq 5 ]]; then
-  echo "No mcore tests to run"
-else
-  uv run --extra mcore bash -x ./tests/run_unit.sh unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json --hf-gated --mcore-only
-fi
+# Check and run mcore tests
+set +e
+pytest tests/unit/ \
+  --ignore=tests/unit/models/generation/ \
+  --ignore=tests/unit/models/policy/ \
+  --collect-only --hf-gated --mcore-only -q >/dev/null 2>&1
+collect_rc=$?
+set -e
+if [[ $collect_rc -eq 5 ]]; then
+  echo "No mcore tests to run"
+elif [[ $collect_rc -eq 0 ]]; then
+  uv run --no-sync --extra mcore bash -x ./tests/run_unit.sh unit/ \
+    --ignore=unit/models/generation/ --ignore=unit/models/policy/ \
+    --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json \
+    --hf-gated --mcore-only
+else
+  echo "mcore test collection failed (exit $collect_rc)"
+  exit "$collect_rc"
+fi
```
31-37: Apply the same safety and path fixes to the automodel block

Mirror the set -e guard, correct the ignore roots for the direct pytest call, handle the return code, and add `--no-sync` on the uv run.

```diff
-# Check and run automodel tests
-exit_code=$(pytest tests/unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --collect-only --hf-gated --automodel-only -q >/dev/null 2>&1; echo $?)
-if [[ $exit_code -eq 5 ]]; then
-  echo "No automodel tests to run"
-else
-  uv run --extra automodel bash -x ./tests/run_unit.sh unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json --hf-gated --automodel-only
-fi
+# Check and run automodel tests
+set +e
+pytest tests/unit/ \
+  --ignore=tests/unit/models/generation/ \
+  --ignore=tests/unit/models/policy/ \
+  --collect-only --hf-gated --automodel-only -q >/dev/null 2>&1
+collect_rc=$?
+set -e
+if [[ $collect_rc -eq 5 ]]; then
+  echo "No automodel tests to run"
+elif [[ $collect_rc -eq 0 ]]; then
+  uv run --no-sync --extra automodel bash -x ./tests/run_unit.sh unit/ \
+    --ignore=unit/models/generation/ --ignore=unit/models/policy/ \
+    --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json \
+    --hf-gated --automodel-only
+else
+  echo "automodel test collection failed (exit $collect_rc)"
+  exit "$collect_rc"
+fi
```

tests/unit/utils/test_automodel_checkpoint.py (1)
343-347: Accept multiple torch_save artifact extensions

DCP may emit .distcp while other backends use .bin/.pt/.pth; broaden the check.

```diff
-    assert any(f.endswith(".distcp") for f in files)
+    assert any(
+        f.endswith(ext) for f in files for ext in (".distcp", ".bin", ".pt", ".pth")
+    )
```

nemo_rl/models/policy/dtensor_policy_worker_v2.py (2)
242-250: Same torch_dtype fix for the from_config path

```diff
 self.model = model_class.from_config(
     model_config,
     attn_implementation="flash_attention_2" if self.enable_seq_packing else None,
     use_liger_kernel=False,
     trust_remote_code=True,
-    torch_dtype=str(model_config.torch_dtype),
+    torch_dtype=model_config.torch_dtype,
 )
```
222-227: Pass a dtype object, not a string, to torch_dtype

Transformers expects a torch.dtype or "auto"; `str(torch.float32)` may break. The same fix applies in the from_config path.

```diff
 model = model_class.from_pretrained(
     model_name,
     device_map="cpu",  # load weights onto CPU initially
     trust_remote_code=True,
     config=model_config,
     use_liger_kernel=False,
-    torch_dtype=str(model_config.torch_dtype),
+    torch_dtype=model_config.torch_dtype,
 )
```
🧹 Nitpick comments (2)
nemo_rl/utils/automodel_checkpoint.py (2)
41-56: Make checkpoint root inference robust to path shapes

`endswith("weights")` can misfire; use basename logic and handle both "…/weights" and "…/weights/model".

```diff
+from pathlib import Path
@@
-def _infer_checkpoint_root(weights_path: str) -> str:
-    """Infer checkpoint root directory from weights path.
-
-    When weights_path ends with "…/weights/model", we need the parent of
-    the weights directory (the checkpoint root), not the weights directory itself.
-
-    Args:
-        weights_path: Path to model weights (e.g., "/path/to/policy/weights/model")
-
-    Returns:
-        str: Checkpoint root directory (e.g., "/path/to/policy")
-    """
-    weights_dir = os.path.dirname(weights_path)
-    if weights_dir.endswith("weights"):
-        return os.path.dirname(weights_dir)
-    return weights_dir
+def _infer_checkpoint_root(weights_path: str) -> str:
+    """Return checkpoint root given weights dir or weights/model dir."""
+    p = Path(weights_path)
+    if p.name == "model" and p.parent.name == "weights":
+        return str(p.parent.parent)
+    if p.name == "weights":
+        return str(p.parent)
+    return str(p.parent)
```
207-211: Avoid string-suffix path check; use basename

Ensures Windows/POSIX compatibility and avoids trailing-slash pitfalls.

```diff
-    if not weights_path.endswith("/model"):
-        weights_path = os.path.join(weights_path, "model")
+    if os.path.basename(weights_path) != "model":
+        weights_path = os.path.join(weights_path, "model")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
uv.lock is excluded by !**/*.lock
📒 Files selected for processing (21)
- 3rdparty/Automodel-workspace/Automodel (1 hunks)
- README.md (2 hunks)
- docker/Dockerfile (1 hunks)
- examples/configs/grpo_math_1B.yaml (3 hunks)
- nemo_rl/algorithms/dpo.py (1 hunks)
- nemo_rl/algorithms/grpo.py (1 hunks)
- nemo_rl/algorithms/rm.py (1 hunks)
- nemo_rl/algorithms/sft.py (1 hunks)
- nemo_rl/models/policy/dtensor_policy_worker_v2.py (6 hunks)
- nemo_rl/models/policy/lm_policy.py (4 hunks)
- nemo_rl/utils/automodel_checkpoint.py (1 hunks)
- nemo_rl/utils/checkpoint.py (4 hunks)
- pyproject.toml (1 hunks)
- pyrefly.toml (1 hunks)
- tests/functional/L1_Functional_Tests_GPU.sh (1 hunks)
- tests/functional/test_automodel_extra_installed_correctly.sh (1 hunks)
- tests/unit/L0_Unit_Tests_Generation.sh (1 hunks)
- tests/unit/L0_Unit_Tests_Other.sh (1 hunks)
- tests/unit/L0_Unit_Tests_Policy.sh (1 hunks)
- tests/unit/conftest.py (1 hunks)
- tests/unit/utils/test_automodel_checkpoint.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (16)
- 3rdparty/Automodel-workspace/Automodel
- pyrefly.toml
- nemo_rl/algorithms/sft.py
- nemo_rl/algorithms/grpo.py
- examples/configs/grpo_math_1B.yaml
- nemo_rl/algorithms/dpo.py
- nemo_rl/algorithms/rm.py
- docker/Dockerfile
- pyproject.toml
- tests/functional/test_automodel_extra_installed_correctly.sh
- tests/unit/conftest.py
- tests/unit/L0_Unit_Tests_Generation.sh
- tests/functional/L1_Functional_Tests_GPU.sh
- README.md
- nemo_rl/utils/checkpoint.py
- tests/unit/L0_Unit_Tests_Policy.sh
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-17T01:52:21.380Z
Learnt from: ffrujeri
PR: NVIDIA-NeMo/RL#1023
File: nemo_rl/utils/checkpoint.py:58-65
Timestamp: 2025-09-17T01:52:21.380Z
Learning: model_state_dict_keys is not intended to be part of the nemo-rl CheckpointingConfig TypedDict - it's handled at the automodel implementation layer, not as a general checkpointing configuration parameter.
Applied to files:
nemo_rl/models/policy/dtensor_policy_worker_v2.py
🧬 Code graph analysis (4)
tests/unit/utils/test_automodel_checkpoint.py (3)
- nemo_rl/utils/automodel_checkpoint.py (3): detect_checkpoint_format (59-91), load_checkpoint (184-240), save_checkpoint (94-181)
- nemo_rl/models/policy/dtensor_policy_worker_v2.py (2): load_checkpoint (1483-1495), save_checkpoint (1442-1481)
- nemo_rl/models/policy/lm_policy.py (1): save_checkpoint (594-627)

nemo_rl/utils/automodel_checkpoint.py (3)
- nemo_rl/utils/checkpoint.py (1): CheckpointingConfig (35-64)
- nemo_rl/models/policy/dtensor_policy_worker_v2.py (2): save_checkpoint (1442-1481), load_checkpoint (1483-1495)
- nemo_rl/models/policy/lm_policy.py (1): save_checkpoint (594-627)

nemo_rl/models/policy/lm_policy.py (2)
- nemo_rl/utils/checkpoint.py (1): CheckpointingConfig (35-64)
- nemo_rl/distributed/worker_groups.py (1): run_all_workers_single_data (705-749)

nemo_rl/models/policy/dtensor_policy_worker_v2.py (4)
- nemo_rl/utils/automodel_checkpoint.py (2): load_checkpoint (184-240), save_checkpoint (94-181)
- nemo_rl/models/policy/dtensor_policy_worker.py (2): load_checkpoint (1507-1517), save_checkpoint (1487-1505)
- nemo_rl/models/policy/lm_policy.py (1): save_checkpoint (594-627)
- nemo_rl/utils/checkpoint.py (1): CheckpointingConfig (35-64)
🪛 Ruff (0.12.2)
tests/unit/utils/test_automodel_checkpoint.py
25-25: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
98-98: Avoid equality comparisons to False; use not is_peft: for false checks
Replace with not is_peft
(E712)
117-117: Avoid equality comparisons to False; use not is_peft: for false checks
Replace with not is_peft
(E712)
131-131: Avoid equality comparisons to False; use not is_peft: for false checks
Replace with not is_peft
(E712)
145-145: Avoid equality comparisons to True; use is_peft: for truth checks
Replace with is_peft
(E712)
152-152: Avoid equality comparisons to False; use not is_peft: for false checks
Replace with not is_peft
(E712)
179-179: Avoid equality comparisons to False; use not is_peft: for false checks
Replace with not is_peft
(E712)
209-209: Avoid equality comparisons to False; use not call_args[1]["checkpoint_config"].is_peft: for false checks
Replace with not call_args[1]["checkpoint_config"].is_peft
(E712)
249-249: Unused method argument: mock_save_model
(ARG002)
nemo_rl/utils/automodel_checkpoint.py
75-75: Loop control variable root not used within loop body
Rename unused root to _root
(B007)
75-75: Loop control variable dirs not used within loop body
Rename unused dirs to _dirs
(B007)
133-136: Avoid specifying long messages outside the exception class
(TRY003)
168-170: Avoid specifying long messages outside the exception class
(TRY003)
231-233: Avoid specifying long messages outside the exception class
(TRY003)
nemo_rl/models/policy/lm_policy.py
82-85: Avoid specifying long messages outside the exception class
(TRY003)
97-100: Avoid specifying long messages outside the exception class
(TRY003)
618-620: Avoid specifying long messages outside the exception class
(TRY003)
nemo_rl/models/policy/dtensor_policy_worker_v2.py
1454-1456: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Lint check
🔇 Additional comments (8)
nemo_rl/utils/automodel_checkpoint.py (1)
75-77: Silence unused loop vars in os.walk

Rename to underscores to satisfy linters.

```diff
-    for root, dirs, files in os.walk(weights_path):
+    for _root, _dirs, files in os.walk(weights_path):
         all_files.extend(files)
```

tests/unit/utils/test_automodel_checkpoint.py (3)
25-27: Remove unused noqa by aliasing the import

Cleaner than suppressing F401.

```diff
-    import nemo_automodel  # noqa: F401
+    import nemo_automodel as _nemo_automodel
```
96-99: Prefer boolean style in asserts

Use truthiness checks.

```diff
-    assert is_peft == False
+    assert not is_peft
```

Also applies to: 115-118, 129-132, 150-153, 177-180
248-256: Address unused patched arg

Prefix it to avoid ARG002.

```diff
-    def test_save_with_tokenizer(self, mock_save_model, mock_model):
+    def test_save_with_tokenizer(self, _mock_save_model, mock_model):
```

nemo_rl/models/policy/lm_policy.py (2)
79-85: Good: explicit backend mutual exclusion. Raising ValueError prevents ambiguous setup.

614-621: Good: guard safetensors to DTensor v2 only. Prevents unsupported code paths.
Please confirm docs/examples flag this constraint in user-facing configs.
nemo_rl/models/policy/dtensor_policy_worker_v2.py (2)
358-362: LGTM: broadcast original state_dict keys across ranks. Keeps key order consistent for saving.

1453-1469: Requiring checkpointing_cfg is appropriate; ensure all callers pass it. Algorithms appear to provide master_config["checkpointing"]; keep this invariant.
Consider a short error hint: “Did you forget to pass policy.save_checkpoint(..., checkpointing_cfg=master_config['checkpointing'])?”
Signed-off-by: Felipe Vieira Frujeri <ffrujeri@nvidia.com>
❌ Submodule Fast-Forward Check Failed

Check based on commit: 87ec980 (PR #1023)

Submodules that need attention:
- Automodel: ❌ Commits have DIVERGED from a common ancestor

Please ensure all submodule commits are fast-forwards of the main branch before merging.
This commit passed as part of this merge-queue run: https://github.com/NVIDIA-NeMo/RL/actions/runs/17828559136. Will manually merge this one.
What does this PR do ?
This PR implements the adapter automodel_checkpoint.py in nemo-rl, introducing checkpointing functionality from the nemo-automodel APIs and making the checkpointing configuration accessible through DTensorPolicyWorkerV2.

The native_checkpoint.py functionality, which DTensorPolicyWorkerV1 still consumes, is preserved for now and may be deprecated in the future.
The current native_checkpoint structure is the following
The one produced by the automodel_checkpoint.py module is
Issues
#578
Usage
Before your PR is "Ready for review"
Pre checks:
Summary by CodeRabbit
New Features
Improvements
Documentation
Tests
Chores