From 65ce317d4a1a07ed28cf8d44579eeb06e544d9dc Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Wed, 10 Sep 2025 23:57:20 -0700 Subject: [PATCH 01/14] Implement automodel_checkpoint adapter with NeMo-Automodel primitives. Signed-off-by: Felipe Vieira Frujeri --- README.md | 7 +- examples/configs/grpo_math_1B.yaml | 6 +- nemo_rl/algorithms/dpo.py | 1 + nemo_rl/algorithms/grpo.py | 1 + nemo_rl/algorithms/rm.py | 1 + nemo_rl/algorithms/sft.py | 1 + .../models/policy/dtensor_policy_worker_v2.py | 27 +- nemo_rl/models/policy/lm_policy.py | 52 ++- nemo_rl/utils/automodel_checkpoint.py | 213 +++++++++ nemo_rl/utils/checkpoint.py | 39 +- pyrefly.toml | 1 + tests/unit/utils/test_automodel_checkpoint.py | 408 ++++++++++++++++++ 12 files changed, 728 insertions(+), 29 deletions(-) create mode 100644 nemo_rl/utils/automodel_checkpoint.py create mode 100644 tests/unit/utils/test_automodel_checkpoint.py diff --git a/README.md b/README.md index 77ec8274eb..c7e47203fd 100644 --- a/README.md +++ b/README.md @@ -76,11 +76,10 @@ What you can expect: Clone **NeMo RL**. ```sh -git clone git@github.com:NVIDIA-NeMo/RL.git nemo-rl +git clone git@github.com:NVIDIA-NeMo/RL.git nemo-rl --recursive cd nemo-rl -# If you are using the Megatron backend, download the pinned versions of Megatron-LM and NeMo submodules -# by running (This is not necessary if you are using the pure Pytorch/DTensor path): +# If you are already cloned without the recursive option, you can initialize the submodules recursively git submodule update --init --recursive # Different branches of the repo can have different pinned versions of these third-party submodules. Ensure @@ -127,7 +126,7 @@ bash tools/build-flash-attn-in-uv-cache.sh > The NeMo RL Dockerfile will warm the uv cache with flash-attn. > See https://docs.nvidia.com/nemo/rl/latest/docker.html for instructions if you are looking for the NeMo RL container. -If sucessful, you should see `✅ flash-attn successfully added to uv cache`. 
+If successful, you should see `✅ flash-attn successfully added to uv cache`. Use `uv run` to launch all commands. It handles pip installing implicitly and ensures your environment is up to date with our lock file. > [!NOTE] diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 6ba7e4d54b..09b381bfff 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -33,6 +33,8 @@ checkpointing: keep_top_k: 3 save_period: 10 checkpoint_must_save_by: null + model_save_format: "safetensors" + save_consolidated: false policy: model_name: "Qwen/Qwen2.5-1.5B" @@ -52,7 +54,7 @@ policy: cpu_offload: False sequence_parallel: false activation_checkpointing: false - tensor_parallel_size: 1 + tensor_parallel_size: 2 context_parallel_size: 1 custom_parallel_plan: null @@ -228,5 +230,5 @@ logger: flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) cluster: - gpus_per_node: 1 + gpus_per_node: 2 num_nodes: 1 diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py index b22f059f74..0a25ab2552 100644 --- a/nemo_rl/algorithms/dpo.py +++ b/nemo_rl/algorithms/dpo.py @@ -651,6 +651,7 @@ def dpo_train( tokenizer_path=os.path.join( checkpoint_path, "policy", "tokenizer" ), + checkpointing_cfg=master_config["checkpointing"], ) torch.save( train_dataloader.state_dict(), diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index e9bc749d60..595599fe16 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -844,6 +844,7 @@ def grpo_train( tokenizer_path=os.path.join( checkpoint_path, "policy", "tokenizer" ), + checkpointing_cfg=master_config["checkpointing"], ) torch.save( dataloader.state_dict(), diff --git a/nemo_rl/algorithms/rm.py b/nemo_rl/algorithms/rm.py index ad646d0021..e59d501c1b 100644 --- a/nemo_rl/algorithms/rm.py +++ b/nemo_rl/algorithms/rm.py @@ -583,6 +583,7 @@ def rm_train( tokenizer_path=os.path.join( checkpoint_path, "policy", 
"tokenizer" ), + checkpointing_cfg=master_config["checkpointing"], ) torch.save( train_dataloader.state_dict(), diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index 71853294d9..dfbaf14ef7 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -523,6 +523,7 @@ def sft_train( tokenizer_path=os.path.join( checkpoint_path, "policy", "tokenizer" ), + checkpointing_cfg=master_config["checkpointing"], ) torch.save( train_dataloader.state_dict(), diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index 18f26afb98..f76f9e30ec 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -84,10 +84,11 @@ import_class_from_path, resolve_model_class, ) -from nemo_rl.utils.native_checkpoint import ( +from nemo_rl.utils.automodel_checkpoint import ( load_checkpoint, save_checkpoint, ) +from nemo_rl.utils.checkpoint import CheckpointingConfig from nemo_rl.utils.nsys import wrap_with_nvtx_name @@ -1433,11 +1434,30 @@ def save_checkpoint( weights_path: str, optimizer_path: Optional[str] = None, tokenizer_path: Optional[str] = None, + checkpointing_cfg: Optional[CheckpointingConfig] = None, ) -> None: """Save a checkpoint of the model. the optimizer states are saved only if `optimizer` and `optimizer_path` are provided. 
""" + if checkpointing_cfg is None: + raise ValueError( + "checkpointing_cfg must be provided when saving checkpoint" + ) + + # Extract only the checkpointing configuration keys that exist + checkpoint_kwargs = { + key: value + for key, value in checkpointing_cfg.items() + if key + in { + "model_save_format", + "save_consolidated", + "is_peft", + "peft_config", + } + } + save_checkpoint( model=self.model, weights_path=weights_path, @@ -1446,10 +1466,13 @@ def save_checkpoint( optimizer_path=optimizer_path, tokenizer=self.tokenizer if tokenizer_path else None, tokenizer_path=tokenizer_path, + **checkpoint_kwargs, ) def load_checkpoint( - self, weights_path: str, optimizer_path: Optional[str] = None + self, + weights_path: str, + optimizer_path: Optional[str] = None, ) -> None: """Load a checkpoint into the model.""" load_checkpoint( diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index c981982f51..4f082e74f6 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -42,6 +42,7 @@ LogprobOutputSpec, ReferenceLogprobOutputSpec, ) +from nemo_rl.utils.checkpoint import CheckpointingConfig from nemo_rl.utils.flops_tracker import ( FLOPTracker, get_default_hf_config, @@ -75,7 +76,13 @@ def __init__( pp_size = 1 cp_size = 1 - megatron_enable = "megatron_cfg" in config and config["megatron_cfg"]["enabled"] + megatron_enable = bool(config.get("megatron_cfg", {}).get("enabled", False)) + dtensor_enable = bool(config.get("dtensor_cfg", {}).get("enabled", False)) + if megatron_enable and dtensor_enable: + raise ValueError( + "Configure either Megatron (policy.megatron_cfg.enabled=true) or " + "DTensor (policy.dtensor_cfg.enabled=true), not both." 
+ ) if megatron_enable: worker_builder_cls = ( "nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker" @@ -86,13 +93,14 @@ def __init__( env_vars = config["megatron_cfg"].get("env_vars", {}) else: - assert config["dtensor_cfg"]["enabled"], ( - "Please either set policy.megatron_cfg.enabled=true to use Megatron training backend " - "or set policy.dtensor_cfg.enabled=true to use DTensor training backend." - ) + if not dtensor_enable: + raise ValueError( + "Please either set policy.megatron_cfg.enabled=true to use Megatron training backend " + "or set policy.dtensor_cfg.enabled=true to use DTensor training backend." + ) # Check if _v2 is enabled in dtensor_cfg (defaults to False for backward compatibility) - use_v2 = config["dtensor_cfg"].get("_v2", False) + use_v2 = config.get("dtensor_cfg", {}).get("_v2", False) if use_v2: worker_builder_cls = "nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2" else: @@ -588,14 +596,34 @@ def save_checkpoint( weights_path: str, optimizer_path: Optional[str] = None, tokenizer_path: Optional[str] = None, + checkpointing_cfg: Optional[CheckpointingConfig] = None, ) -> None: """Save a checkpoint of the model.""" - futures = self.worker_group.run_all_workers_single_data( - "save_checkpoint", - weights_path=weights_path, - optimizer_path=optimizer_path, - tokenizer_path=tokenizer_path, - ) + # Only pass checkpointing_cfg for DTensor v2 + use_v2 = self.cfg.get("dtensor_cfg", {}).get("_v2", False) + + if use_v2: + futures = self.worker_group.run_all_workers_single_data( + "save_checkpoint", + weights_path=weights_path, + optimizer_path=optimizer_path, + tokenizer_path=tokenizer_path, + checkpointing_cfg=checkpointing_cfg, + ) + else: + if ( + checkpointing_cfg is not None + and checkpointing_cfg.get("model_save_format") == "safetensors" + ): + raise ValueError( + "safetensors is only supported with DTensorPolicyWorkerV2 (_v2=true)." 
+ ) + futures = self.worker_group.run_all_workers_single_data( + "save_checkpoint", + weights_path=weights_path, + optimizer_path=optimizer_path, + tokenizer_path=tokenizer_path, + ) ray.get(futures) def shutdown(self) -> bool: diff --git a/nemo_rl/utils/automodel_checkpoint.py b/nemo_rl/utils/automodel_checkpoint.py new file mode 100644 index 0000000000..ed1f743f47 --- /dev/null +++ b/nemo_rl/utils/automodel_checkpoint.py @@ -0,0 +1,213 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint management utilities for HF models.""" + +import os +from typing import Any, Optional + +import torch + +# Apply torch backports for compatibility with torch==2.7.1 +from nemo_automodel.components.checkpoint._torch_backports import apply_patches + +# Import from nemo-automodel +from nemo_automodel.components.checkpoint.checkpointing import ( + CheckpointingConfig, + load_model, + load_optimizer, + save_model, + save_optimizer, +) + +# Apply torch backports for compatibility with torch==2.7.1 +apply_patches() + + +def detect_checkpoint_format(weights_path: str) -> tuple[str, bool]: + """Detect model save format and PEFT status from checkpoint directory. 
+ + Args: + weights_path: Path to the checkpoint directory (e.g., weights/model) + + Returns: + tuple: (model_save_format, is_peft) where: + model_save_format is "torch_save" for DCP or "safetensors" for safetensors + is_peft is True if PEFT/adapter patterns are detected + """ + is_peft = False + model_save_format = "safetensors" + try: + # Iterate through all subdirectories and files recursively + all_files = [] + for root, dirs, files in os.walk(weights_path): + all_files.extend(files) + + if any(f.endswith(".distcp") for f in all_files): + model_save_format = "torch_save" + elif any(f.endswith(".safetensors") for f in all_files): + model_save_format = "safetensors" + elif any(f.endswith((".bin", ".pt", ".pth")) for f in all_files): + model_save_format = "torch_save" + + if not is_peft: + is_peft = any("adapter" in f.lower() for f in all_files) + + except (OSError, PermissionError): + pass + + return model_save_format, is_peft + + +def save_checkpoint( + model: torch.nn.Module, + weights_path: str, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[Any] = None, + optimizer_path: Optional[str] = None, + tokenizer: Optional[Any] = None, + tokenizer_path: Optional[str] = None, + model_save_format: str = "safetensors", + is_peft: bool = False, + peft_config: Optional[Any] = None, + save_consolidated: bool = False, +) -> None: + """Save a checkpoint of the model and optionally optimizer state. 
+ + Args: + model: The PyTorch model to save + weights_path: Path to save model weights + optimizer: Optional optimizer to save + scheduler: Optional scheduler to save + optimizer_path: Path to save optimizer state (required if optimizer provided) + tokenizer: Optional tokenizer to save + tokenizer_path: Path to save tokenizer state (required if tokenizer provided) + model_save_format: Format for saving model ("torch_save" or "safetensors") + is_peft: Whether the model uses PEFT + peft_config: PEFT configuration if is_peft is True + """ + # Create checkpoint config + + valid_formats = {"safetensors", "torch_save"} + if model_save_format not in valid_formats: + raise ValueError( + f"Unsupported model_save_format='{model_save_format}'. " + f"Expected one of {sorted(valid_formats)}." + ) + + # Ensure target directories exist + os.makedirs(weights_path, exist_ok=True) + if optimizer_path: + os.makedirs(optimizer_path, exist_ok=True) + if tokenizer_path: + os.makedirs(tokenizer_path, exist_ok=True) + + checkpoint_config = CheckpointingConfig( + enabled=True, + checkpoint_dir=os.path.dirname(weights_path), + model_save_format=model_save_format, + model_cache_dir="", + model_repo_id="", + save_consolidated=save_consolidated, + is_peft=is_peft, + ) + + # Save model using nemo-automodel API + save_model( + model=model, + weights_path=weights_path, + checkpoint_config=checkpoint_config, + peft_config=peft_config, + tokenizer=tokenizer if tokenizer_path is None else None, + ) + + # Save optimizer if provided + if optimizer is not None: + if optimizer_path is None: + raise ValueError( + "optimizer_path must be provided when saving optimizer state" + ) + save_optimizer( + optimizer=optimizer, + model=model, + weights_path=optimizer_path, + scheduler=scheduler, + ) + + # Save tokenizer separately if tokenizer_path provided + if tokenizer is not None and tokenizer_path is not None: + print(f"Saving tokenizer (or processor) to {tokenizer_path}") + 
tokenizer.save_pretrained(tokenizer_path) + + +def load_checkpoint( + model: torch.nn.Module, + weights_path: str, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[Any] = None, + optimizer_path: Optional[str] = None, +) -> None: + """Load model weights and optionally optimizer state. + + Args: + model: The PyTorch model whose weights to update + weights_path: Path to load model weights from + optimizer: Optional optimizer to load state into + scheduler: Optional scheduler to load state into + optimizer_path: Path to load optimizer state from (required if optimizer provided) + Note: model_save_format and is_peft are not parameters; they are auto- + detected from the checkpoint directory via detect_checkpoint_format(). + """ + print(f"Loading weights from {weights_path}") + + model_save_format, is_peft = detect_checkpoint_format(weights_path) + checkpoint_config = CheckpointingConfig( + enabled=True, + checkpoint_dir=os.path.dirname(weights_path), + model_save_format=model_save_format, + model_cache_dir="", # Not used for basic loading + model_repo_id="", # Not used for basic loading + save_consolidated=False, # Keep original behavior + is_peft=is_peft, + ) + + try: + # Load model using nemo-automodel API + load_model( + model=model, + weights_path=weights_path, + checkpoint_config=checkpoint_config, + ) + except FileNotFoundError as e: + msg = ( + f"Failed to load model from '{weights_path}': {e}\n" + "Note: DTensorPolicyWorkerV2 expects:\n" + " - Model shards under '/weights/model'\n" + " - Optimizer states under '/optimizer/optim'\n" + "Please verify your checkpoint layout."
+ ) + raise FileNotFoundError(msg) from e + + if optimizer is not None: + if optimizer_path is None: + raise ValueError( + "optimizer_path must be provided when loading optimizer state" + ) + print(f"Loading optimizer from {optimizer_path}") + load_optimizer( + optimizer=optimizer, + model=model, + weights_path=optimizer_path, + scheduler=scheduler, + ) diff --git a/nemo_rl/utils/checkpoint.py b/nemo_rl/utils/checkpoint.py index 6f84d7782f..0c1b68a54e 100644 --- a/nemo_rl/utils/checkpoint.py +++ b/nemo_rl/utils/checkpoint.py @@ -41,6 +41,11 @@ class CheckpointingConfig(TypedDict): metric_name (str | None): Name of the metric to use for determining best checkpoints. higher_is_better (bool): Whether higher values of the metric indicate better performance. keep_top_k (Optional[int]): Number of best checkpoints to keep. If None, all checkpoints are kept. + model_save_format (str): Format for saving model ("torch_save" or "safetensors"). + save_consolidated (bool): Whether to save consolidated checkpoints (for HF compatibility). + model_cache_dir (str): Directory for model cache (for safetensors format). + model_repo_id (str): Repository ID for the model (for safetensors format). + is_peft (bool): Whether the model uses PEFT. 
""" enabled: bool @@ -50,6 +55,13 @@ class CheckpointingConfig(TypedDict): save_period: int keep_top_k: NotRequired[int] checkpoint_must_save_by: NotRequired[str | None] + # New nemo-automodel integration fields + model_save_format: NotRequired[str] # Default: "safetensors" + save_consolidated: NotRequired[bool] # Default: False + model_cache_dir: NotRequired[str] # Default: "" + model_repo_id: NotRequired[str] # Default: "" + is_peft: NotRequired[bool] # Default: False + peft_config: NotRequired[Any] # Default: None class CheckpointManager: @@ -84,6 +96,13 @@ def __init__(self, config: CheckpointingConfig): self.higher_is_better = config["higher_is_better"] self.keep_top_k = config["keep_top_k"] + # Store nemo-automodel specific config options + self.model_save_format = config.get("model_save_format", "safetensors") + self.save_consolidated = config.get("save_consolidated", False) + self.model_cache_dir = config.get("model_cache_dir", "") + self.model_repo_id = config.get("model_repo_id", "") + self.is_peft = config.get("is_peft", False) + def init_tmp_checkpoint( self, step: int, @@ -113,10 +132,11 @@ def init_tmp_checkpoint( # save training info with open(save_dir / "training_info.json", "w") as f: # make any numpy items serializable - for k, v in training_info.items(): + serializable_training_info = dict(training_info) + for k, v in serializable_training_info.items(): if isinstance(v, torch.Tensor) or isinstance(v, np.ndarray): - training_info[k] = v.item() - json.dump(training_info, f) + serializable_training_info[k] = v.item() + json.dump(serializable_training_info, f) # save config if run_config is not None: @@ -182,18 +202,18 @@ def remove_old_checkpoints(self, exclude_latest: bool = True) -> None: checkpoint_history.sort(key=lambda x: x[0], reverse=True) else: try: - assert self.metric_name is not None # Type checker hint + metric_name = ( + self.metric_name + ) # Type checker hint - capture the non-None value # sort by metric value first, then by step 
number (for equal metrics, prefer more recent) if self.higher_is_better: # For higher_is_better=True: higher metric values first, then higher step numbers checkpoint_history.sort( - key=lambda x: (x[2][self.metric_name], x[0]), reverse=True + key=lambda x: (x[2][metric_name], x[0]), reverse=True ) else: # For higher_is_better=False: lower metric values first, then higher step numbers for equal values - checkpoint_history.sort( - key=lambda x: (x[2][self.metric_name], -x[0]) - ) + checkpoint_history.sort(key=lambda x: (x[2][metric_name], -x[0])) except KeyError: warnings.warn( f"Metric {self.metric_name} not found in checkpoint history. Keeping most recent k checkpoints." @@ -230,8 +250,9 @@ def get_best_checkpoint_path(self) -> Optional[str]: ) return self.get_latest_checkpoint_path() + metric_name = self.metric_name # Type checker hint - capture the non-None value checkpoint_history.sort( - key=lambda x: x[2][self.metric_name], reverse=self.higher_is_better + key=lambda x: x[2][metric_name], reverse=self.higher_is_better ) return str(checkpoint_history[0][1]) diff --git a/pyrefly.toml b/pyrefly.toml index 51b5574002..5f320107c5 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -108,6 +108,7 @@ project-includes = [ "nemo_rl/utils/checkpoint.py", "nemo_rl/utils/config.py", "nemo_rl/utils/native_checkpoint.py", + "nemo_rl/utils/automodel_checkpoint.py", "nemo_rl/utils/nsys.py", "nemo_rl/utils/nvml.py", "nemo_rl/utils/prefetch_venvs.py", diff --git a/tests/unit/utils/test_automodel_checkpoint.py b/tests/unit/utils/test_automodel_checkpoint.py new file mode 100644 index 0000000000..15e5956343 --- /dev/null +++ b/tests/unit/utils/test_automodel_checkpoint.py @@ -0,0 +1,408 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from tempfile import TemporaryDirectory +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from nemo_rl.utils.automodel_checkpoint import ( + detect_checkpoint_format, + load_checkpoint, + save_checkpoint, +) + + +class TestModel(torch.nn.Module): + """Simple test model with a forward method.""" + + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList( + [ + torch.nn.Linear(4, 4), + torch.nn.LayerNorm(4), + torch.nn.ReLU(), + torch.nn.Linear(4, 1), + ] + ) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +@pytest.fixture +def mock_model(): + """Create a simple mock model for testing.""" + return TestModel() + + +@pytest.fixture +def mock_optimizer(): + """Create a simple mock optimizer for testing.""" + model = torch.nn.Linear(4, 1) + return torch.optim.Adam(model.parameters()) + + +class TestDetectCheckpointFormat: + """Test the detect_checkpoint_format function.""" + + def test_directory_with_safetensors(self): + """Test detection for directories containing safetensors files.""" + with TemporaryDirectory() as tmp_dir: + # Create directory with safetensors files + os.makedirs(os.path.join(tmp_dir, "weights", "model")) + weights_path = os.path.join(tmp_dir, "weights", "model") + + # Create safetensors shard files + with open( + os.path.join( + weights_path, "shard-00001-model-00001-of-00001.safetensors" + ), + "w", + ) as f: + f.write("dummy content") + with open( + os.path.join( + weights_path, "shard-00002-model-00001-of-00001.safetensors" + ), + "w", + ) 
as f: + f.write("dummy content") + + format_type, is_peft = detect_checkpoint_format(weights_path) + assert format_type == "safetensors" + assert is_peft == False + + def test_directory_with_dcp_format(self): + """Test detection for directories with DCP (Distributed Checkpoint) format.""" + with TemporaryDirectory() as tmp_dir: + # Create directory structure like: step_3/policy/optimizer/optim + optim_path = os.path.join(tmp_dir, "step_3", "policy", "optimizer", "optim") + os.makedirs(optim_path) + + # Create DCP files (.distcp + .metadata) + with open(os.path.join(optim_path, "__0_0.distcp"), "w") as f: + f.write("dummy dcp content") + with open(os.path.join(optim_path, "__1_0.distcp"), "w") as f: + f.write("dummy dcp content") + with open(os.path.join(optim_path, ".metadata"), "w") as f: + f.write("dummy metadata") + + format_type, is_peft = detect_checkpoint_format(optim_path) + assert format_type == "torch_save" # DCP uses torch_save format + assert is_peft == False + + def test_directory_with_torch_files(self): + """Test detection for directories containing torch save files.""" + with TemporaryDirectory() as tmp_dir: + model_path = os.path.join(tmp_dir, "model") + os.makedirs(model_path) + + # Create torch save files + with open(os.path.join(model_path, "pytorch_model.bin"), "w") as f: + f.write("dummy content") + + format_type, is_peft = detect_checkpoint_format(model_path) + assert format_type == "torch_save" + assert is_peft == False + + def test_peft_detection_in_filenames(self): + """Test PEFT detection from filenames within directories.""" + with TemporaryDirectory() as tmp_dir: + model_path = os.path.join(tmp_dir, "regular_model") + os.makedirs(model_path) + + # Create file with adapter pattern in name + with open(os.path.join(model_path, "adapter_model.safetensors"), "w") as f: + f.write("dummy content") + + format_type, is_peft = detect_checkpoint_format(model_path) + assert format_type == "safetensors" + assert is_peft == True # Should detect adapter 
in filename + + def test_default_fallback(self): + """Test default behavior for non-existent directories.""" + # Non-existent directory should default to safetensors, no PEFT + format_type, is_peft = detect_checkpoint_format("/non/existent/directory") + assert format_type == "safetensors" + assert is_peft == False + + def test_expected_structure(self): + """Test with the expected folder structure from the user.""" + with TemporaryDirectory() as tmp_dir: + # Create the expected structure: step_3/policy/weights/model + weights_path = os.path.join(tmp_dir, "step_3", "policy", "weights", "model") + os.makedirs(weights_path) + + # Create safetensors shard files as in the example + with open( + os.path.join( + weights_path, "shard-00001-model-00001-of-00001.safetensors" + ), + "w", + ) as f: + f.write("dummy content") + with open( + os.path.join( + weights_path, "shard-00002-model-00001-of-00001.safetensors" + ), + "w", + ) as f: + f.write("dummy content") + + format_type, is_peft = detect_checkpoint_format(weights_path) + assert format_type == "safetensors" + assert is_peft == False + + """Test the save_checkpoint function.""" + + @patch("nemo_rl.utils.automodel_checkpoint.save_model") + @patch("nemo_rl.utils.automodel_checkpoint.save_optimizer") + def test_save_model_only(self, mock_save_optimizer, mock_save_model, mock_model): + """Test saving model weights only.""" + with TemporaryDirectory() as tmp_dir: + weights_path = os.path.join(tmp_dir, "weights") + os.makedirs(os.path.dirname(weights_path), exist_ok=True) + + # Save checkpoint + save_checkpoint( + model=mock_model, + weights_path=weights_path, + model_save_format="safetensors", + is_peft=False, + ) + + # Verify save_model was called correctly + mock_save_model.assert_called_once() + call_args = mock_save_model.call_args + assert call_args[1]["model"] is mock_model + assert call_args[1]["weights_path"] == weights_path + assert ( + call_args[1]["checkpoint_config"].model_save_format.value + == "safetensors" + ) 
+ assert call_args[1]["checkpoint_config"].is_peft == False + + # Verify optimizer saving was not called + mock_save_optimizer.assert_not_called() + + @patch("nemo_rl.utils.automodel_checkpoint.save_model") + @patch("nemo_rl.utils.automodel_checkpoint.save_optimizer") + def test_save_with_optimizer( + self, mock_save_optimizer, mock_save_model, mock_model, mock_optimizer + ): + """Test saving model and optimizer weights.""" + with TemporaryDirectory() as tmp_dir: + weights_path = os.path.join(tmp_dir, "model", "weights") + optimizer_path = os.path.join(tmp_dir, "optimizer", "optim") + os.makedirs(os.path.dirname(weights_path)) + os.makedirs(os.path.dirname(optimizer_path)) + + # Save checkpoint with optimizer + save_checkpoint( + model=mock_model, + weights_path=weights_path, + optimizer=mock_optimizer, + optimizer_path=optimizer_path, + model_save_format="torch_save", + is_peft=True, + ) + + # Verify both model and optimizer saving were called + mock_save_model.assert_called_once() + mock_save_optimizer.assert_called_once() + + # Check optimizer call args + opt_call_args = mock_save_optimizer.call_args + assert opt_call_args[1]["optimizer"] is mock_optimizer + assert opt_call_args[1]["model"] is mock_model + assert opt_call_args[1]["weights_path"] == optimizer_path + + @patch("nemo_rl.utils.automodel_checkpoint.save_model") + def test_save_with_tokenizer(self, mock_save_model, mock_model): + """Test saving with tokenizer.""" + with TemporaryDirectory() as tmp_dir: + weights_path = os.path.join(tmp_dir, "model", "weights") + tokenizer_path = os.path.join(tmp_dir, "tokenizer") + os.makedirs(os.path.dirname(weights_path)) + os.makedirs(tokenizer_path) + + # Create mock tokenizer + mock_tokenizer = MagicMock() + + # Save checkpoint with tokenizer + save_checkpoint( + model=mock_model, + weights_path=weights_path, + tokenizer=mock_tokenizer, + tokenizer_path=tokenizer_path, + ) + + # Verify tokenizer.save_pretrained was called + 
mock_tokenizer.save_pretrained.assert_called_once_with(tokenizer_path) + + +@pytest.fixture +def mock_experiment(): + """Create a real model, optimizer, and scheduler for integration testing.""" + model = TestModel() + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1) + return model, optimizer, scheduler + + +def check_dict_equality(dict1, dict2): + """Recursively check equality of two dictionaries""" + for k in dict1.keys(): + if isinstance(dict1[k], dict): + check_dict_equality(dict1[k], dict2[k]) + elif isinstance(dict1[k], torch.Tensor): + assert torch.allclose(dict1[k], dict2[k]) + else: + assert dict1[k] == dict2[k] + + +class TestSaveLoadIntegration: + """Integration tests that actually save and load checkpoints.""" + + def test_save_and_load_model_only_safetensors(self, mock_experiment): + """Test saving and loading model weights only with safetensors format.""" + test_model, _, _ = mock_experiment + original_state_dict = test_model.state_dict() + + with TemporaryDirectory() as tmp_dir: + weights_path = os.path.join(tmp_dir, "test_model") + + # Save checkpoint + save_checkpoint( + model=test_model, + weights_path=weights_path, + model_save_format="safetensors", + ) + + # Verify files are created + assert os.path.exists(weights_path) + files = os.listdir(os.path.join(weights_path, "model")) + assert any(f.endswith(".safetensors") for f in files) + + # Create a new model with different weights + new_model = TestModel() + # Initialize with different values + for param in new_model.parameters(): + param.data.fill_(999.0) + + # Load the checkpoint + load_checkpoint(model=new_model, weights_path=weights_path) + + # Verify the weights match the original + check_dict_equality(new_model.state_dict(), original_state_dict) + + def test_save_and_load_model_only_torch_save(self, mock_experiment): + """Test saving and loading model weights only with torch_save format.""" + 
test_model, _, _ = mock_experiment + original_state_dict = test_model.state_dict() + + with TemporaryDirectory() as tmp_dir: + weights_path = os.path.join(tmp_dir, "test_model") + + # Save checkpoint + save_checkpoint( + model=test_model, + weights_path=weights_path, + model_save_format="torch_save", + ) + + # Verify files are created + assert os.path.exists(weights_path) + files = os.listdir(os.path.join(weights_path, "model")) + assert any(f.endswith(".distcp") for f in files) + + # Create a new model with different weights + new_model = TestModel() + # Initialize with different values + for param in new_model.parameters(): + param.data.fill_(999.0) + + # Load the checkpoint + load_checkpoint(model=new_model, weights_path=weights_path) + + # Verify the weights match the original + check_dict_equality(new_model.state_dict(), original_state_dict) + + def test_save_and_load_model_and_optimizer(self, mock_experiment): + """Test saving and loading both model and optimizer.""" + test_model, optimizer, scheduler = mock_experiment + + # Take some optimization steps to change optimizer state + for _ in range(5): + loss = torch.nn.functional.mse_loss( + test_model(torch.randn(2, 4)), torch.randn(2, 1) + ) + optimizer.zero_grad() + loss.backward() + optimizer.step() + scheduler.step() + + original_model_state = test_model.state_dict() + original_optimizer_state = optimizer.state_dict() + original_scheduler_state = scheduler.state_dict() + + with TemporaryDirectory() as tmp_dir: + model_path = os.path.join(tmp_dir, "model_and_optimizer", "model") + optimizer_path = os.path.join(tmp_dir, "model_and_optimizer", "optimizer") + os.makedirs(os.path.dirname(model_path), exist_ok=True) + os.makedirs(os.path.dirname(optimizer_path), exist_ok=True) + + # Save checkpoint + save_checkpoint( + model=test_model, + weights_path=model_path, + optimizer=optimizer, + scheduler=scheduler, + optimizer_path=optimizer_path, + ) + + # Verify files are created + assert os.path.exists(model_path) + 
assert os.path.exists(optimizer_path) + + # Create new model, optimizer, and scheduler with different state + new_model = TestModel() + new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.001) + new_scheduler = torch.optim.lr_scheduler.StepLR( + new_optimizer, step_size=4, gamma=0.2 + ) + + # Initialize with different values + for param in new_model.parameters(): + param.data.fill_(999.0) + + # Load the checkpoint + load_checkpoint( + model=new_model, + weights_path=model_path, + optimizer=new_optimizer, + scheduler=new_scheduler, + optimizer_path=optimizer_path, + ) + + # Verify all states match the original + check_dict_equality(new_model.state_dict(), original_model_state) + check_dict_equality(new_optimizer.state_dict(), original_optimizer_state) + assert new_scheduler.state_dict() == original_scheduler_state From 7ee451bd8649fc1a2e98c17ebfd3d0f29ff95d08 Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Sun, 14 Sep 2025 04:39:56 +0000 Subject: [PATCH 02/14] Update Automodel submodule. Signed-off-by: Felipe Vieira Frujeri --- 3rdparty/Automodel-workspace/Automodel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/Automodel-workspace/Automodel b/3rdparty/Automodel-workspace/Automodel index 71162c284d..7b55cabc0a 160000 --- a/3rdparty/Automodel-workspace/Automodel +++ b/3rdparty/Automodel-workspace/Automodel @@ -1 +1 @@ -Subproject commit 71162c284d315193cbb4011081228da2ba943c27 +Subproject commit 7b55cabc0a3b1d8b03b6c1f680c030ea2c8eaa77 From fafca8b95e874ca14b1b6917e741e2377fccc650 Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Fri, 12 Sep 2025 11:40:49 -0700 Subject: [PATCH 03/14] Fix automodel api calling with model_state_dict_keys. 
Signed-off-by: Felipe Vieira Frujeri --- .../models/policy/dtensor_policy_worker_v2.py | 9 +++++ nemo_rl/utils/automodel_checkpoint.py | 30 +++++++++-------- uv.lock | 33 +++++++++++++++---- 3 files changed, 53 insertions(+), 19 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index f76f9e30ec..dd89df9111 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -214,6 +214,7 @@ def __init__( model_class = resolve_model_class(model_config.model_type) full_state_dict = None + model_state_dict_keys = None if self.rank == 0: print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") model = model_class.from_pretrained( @@ -225,6 +226,8 @@ def __init__( ) full_state_dict = model.state_dict() + # Store the original model state dict keys before any parallelization + model_state_dict_keys = list(full_state_dict.keys()) del model print(f"[Rank {self.rank}] Initializing empty model for FSDP...") @@ -350,6 +353,11 @@ def __init__( ), ) + # Broadcast model state dict keys to all ranks and store as instance variable + keys_to_broadcast = [model_state_dict_keys] + torch.distributed.broadcast_object_list(keys_to_broadcast, src=0) + self.model_state_dict_keys = keys_to_broadcast[0] + # Handle tied word embeddings after loading the state dict # We need to actually tie the parameters at the model level is_tied_lm_head = hasattr(self.model, "lm_head") and getattr( @@ -1466,6 +1474,7 @@ def save_checkpoint( optimizer_path=optimizer_path, tokenizer=self.tokenizer if tokenizer_path else None, tokenizer_path=tokenizer_path, + model_state_dict_keys=self.model_state_dict_keys, **checkpoint_kwargs, ) diff --git a/nemo_rl/utils/automodel_checkpoint.py b/nemo_rl/utils/automodel_checkpoint.py index ed1f743f47..98cc0c4ab3 100644 --- a/nemo_rl/utils/automodel_checkpoint.py +++ b/nemo_rl/utils/automodel_checkpoint.py @@ -18,6 +18,9 @@ from typing 
import Any, Optional import torch +from nemo_automodel.components.checkpoint._backports.filesystem import ( + SerializationFormat, +) # Apply torch backports for compatibility with torch==2.7.1 from nemo_automodel.components.checkpoint._torch_backports import apply_patches @@ -82,6 +85,7 @@ def save_checkpoint( is_peft: bool = False, peft_config: Optional[Any] = None, save_consolidated: bool = False, + model_state_dict_keys: Optional[list[str]] = None, ) -> None: """Save a checkpoint of the model and optionally optimizer state. @@ -96,9 +100,16 @@ def save_checkpoint( model_save_format: Format for saving model ("torch_save" or "safetensors") is_peft: Whether the model uses PEFT peft_config: PEFT configuration if is_peft is True + save_consolidated: Whether to save consolidated checkpoints (for HF compatibility) + model_state_dict_keys: Copy of the model state dict keys before any parallelization. + If None, will be extracted from the model's current state dict. """ # Create checkpoint config + # Extract model state dict keys if not provided + if model_state_dict_keys is None: + model_state_dict_keys = list(model.state_dict().keys()) + valid_formats = {"safetensors", "torch_save"} if model_save_format not in valid_formats: raise ValueError( @@ -121,6 +132,7 @@ def save_checkpoint( model_repo_id="", save_consolidated=save_consolidated, is_peft=is_peft, + model_state_dict_keys=model_state_dict_keys, ) # Save model using nemo-automodel API @@ -166,28 +178,20 @@ def load_checkpoint( optimizer: Optional optimizer to load state into scheduler: Optional scheduler to load state into optimizer_path: Path to load optimizer state from (required if optimizer provided) - model_save_format: Format model was saved in ("torch_save" or "safetensors") - is_peft: Whether the model uses PEFT """ print(f"Loading weights from {weights_path}") model_save_format, is_peft = detect_checkpoint_format(weights_path) - checkpoint_config = CheckpointingConfig( - enabled=True, - 
checkpoint_dir=os.path.dirname(weights_path), - model_save_format=model_save_format, - model_cache_dir="", # Not used for basic loading - model_repo_id="", # Not used for basic loading - save_consolidated=False, # Keep original behavior - is_peft=is_peft, - ) try: + format_enum = SerializationFormat[model_save_format.upper()] + # Load model using nemo-automodel API load_model( model=model, - weights_path=weights_path, - checkpoint_config=checkpoint_config, + model_path=weights_path, + model_save_format=format_enum, + is_peft=is_peft, ) except FileNotFoundError as e: msg = ( diff --git a/uv.lock b/uv.lock index efeb081ab4..05376b4928 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.12" resolution-markers = [ "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", @@ -1972,15 +1972,15 @@ wheels = [ [[package]] name = "liger-kernel" -version = "0.5.8" +version = "0.6.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "torch", version = "2.7.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, { name = "triton", version = "3.3.1", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "(platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a2/55/3a703f337110e2a121a04e503abfeec2c191529cbee18bb1fb630d65642a/liger_kernel-0.5.8.tar.gz", hash = "sha256:3246d7dced89e0f982a52de259d4f78fd10eb9171246b28ae52b63ad09fc0732", size = 3593097, upload-time = "2025-04-12T16:44:32.252Z" } +sdist = { url = "https://files.pythonhosted.org/packages/31/23/be0b4dcac42d77f99406c906567cde22a7a3d71b3f3ffdfda2ac6153ec36/liger_kernel-0.6.2.tar.gz", hash = "sha256:5c5bcffffa769bc26ae838f5a4954170dd5cacde036abb1b383039f39fa5fd69", 
size = 3679495, upload-time = "2025-08-22T00:15:28.456Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3d/40/75d82d90062b60e2aedd0b1741fe5b3dfbfd250aedd25933ef0b177b640e/liger_kernel-0.5.8-py3-none-any.whl", hash = "sha256:3102f99f89e9b9da249c83ea3f12b68680a8e8df0e477d4513e232da9af7d1a0", size = 150758, upload-time = "2025-04-12T16:44:30.791Z" }, + { url = "https://files.pythonhosted.org/packages/94/2c/68d992835e8630c1b94cdcb246ea7eecad790a955037ca3f19b6c01e8215/liger_kernel-0.6.2-py3-none-any.whl", hash = "sha256:303b9bbf5c10f9289c3139afb41e4d989e8c809516624a106b89b064163d971d", size = 192815, upload-time = "2025-08-22T00:15:27.04Z" }, ] [[package]] @@ -2348,6 +2348,21 @@ requires-dist = [ { name = "zarr" }, ] +[[package]] +name = "megatron-fsdp" +version = "0.1.0rc1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "torch", version = "2.7.1", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = "2.7.1+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform != 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a0/be/06ada3d765ebca304e2d87873d6cf00807b43155ed57058abcd813d13a5d/megatron_fsdp-0.1.0rc1.tar.gz", hash = "sha256:4852a1c62bb95b5fc9567165ee7119f2e68bc75d6103af06bd1e6d392a50021f", size = 71600, upload-time = "2025-09-02T21:29:10.757Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/27/26ac0642311ef4690b70718cf482c2b83ea91770cb73056c7aa1f06f8857/megatron_fsdp-0.1.0rc1-py3-none-any.whl", hash = "sha256:c790b31b34de278e2c0fb07aa9eaa7edbdd55492005e857c55bee1450ffd03c9", size = 75936, upload-time = "2025-09-08T04:17:06.049Z" }, +] + [[package]] name = "mistral-common" version = "1.8.4" @@ -2728,6 +2743,7 @@ dependencies = [ { name = "bitsandbytes", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, 
{ name = "datasets" }, { name = "liger-kernel", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, + { name = "megatron-fsdp" }, { name = "pybind11" }, { name = "pyyaml" }, { name = "torch", version = "2.7.1", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, @@ -2742,6 +2758,9 @@ dependencies = [ fa = [ { name = "flash-attn" }, ] +moe = [ + { name = "transformer-engine", extra = ["pytorch"] }, +] vlm = [ { name = "backoff" }, { name = "mistral-common", extra = ["opencv"] }, @@ -2788,7 +2807,8 @@ requires-dist = [ { name = "bitsandbytes", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'", specifier = "==0.45.5" }, { name = "datasets", specifier = ">=4.0.0" }, { name = "flash-attn", marker = "extra == 'fa'", specifier = "<=2.8.3" }, - { name = "liger-kernel", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'", specifier = "==0.5.8" }, + { name = "liger-kernel", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'", specifier = ">=0.5.9" }, + { name = "megatron-fsdp" }, { name = "mistral-common", extras = ["opencv"], marker = "extra == 'vlm'" }, { name = "numba", marker = "extra == 'vlm'" }, { name = "numpy", marker = "extra == 'vlm'" }, @@ -2802,11 +2822,12 @@ requires-dist = [ { name = "torchao" }, { name = "torchcodec", marker = "extra == 'vlm'" }, { name = "torchdata" }, + { name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'moe'", specifier = "==2.5.0" }, { name = "transformers", specifier = "<=4.55.4" }, { name = "transformers", marker = "extra == 'vlm'", specifier = "<=4.55.4" }, { name = "wandb" }, ] -provides-extras = ["vlm", "fa"] +provides-extras = ["vlm", "fa", "moe"] [package.metadata.requires-dev] build = [ From 4f024d997d77793068d4ff55782508a91fca229b Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Mon, 15 Sep 2025 14:21:19 -0700 Subject: [PATCH 04/14] Disable liger_kernels explicitly. 
Signed-off-by: Felipe Vieira Frujeri --- nemo_rl/models/policy/dtensor_policy_worker_v2.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index dd89df9111..b3f0aa6dc2 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -222,6 +222,8 @@ def __init__( device_map="cpu", # load weights onto CPU initially trust_remote_code=True, config=model_config, + attn_implementation=model_config._attn_implementation, + use_liger_kernel=False, torch_dtype=str(model_config.torch_dtype), ) @@ -235,14 +237,10 @@ def __init__( # The actual weights will be broadcast from rank 0. with init_empty_weights(): - # NeMoAutoModelForCausalLM uses flash_attention_2 by default - # so we need to set it to None if sequence packing is disabled - # https://github.com/NVIDIA-NeMo/Automodel/blob/7e748be260651349307862426c0c168cebdeeec3/nemo_automodel/components/_transformers/auto_model.py#L180 self.model = model_class.from_config( model_config, - attn_implementation="flash_attention_2" - if self.enable_seq_packing - else None, + attn_implementation=model_config._attn_implementation, + use_liger_kernel=False, trust_remote_code=True, torch_dtype=str(model_config.torch_dtype), ) From 06ed8501ba41ed629e4c9bd58fada28f0a10b26b Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Mon, 15 Sep 2025 15:53:16 -0700 Subject: [PATCH 05/14] Update tests to support automodel markers.
Signed-off-by: Felipe Vieira Frujeri --- docker/Dockerfile | 1 + pyproject.toml | 1 + tests/functional/L1_Functional_Tests_GPU.sh | 1 + ...est_automodel_extra_installed_correctly.sh | 72 +++++++++++++++++++ tests/unit/L0_Unit_Tests_Generation.sh | 10 ++- tests/unit/L0_Unit_Tests_Other.sh | 10 ++- tests/unit/L0_Unit_Tests_Policy.sh | 10 ++- tests/unit/conftest.py | 27 ++++++- tests/unit/utils/test_automodel_checkpoint.py | 12 ++++ 9 files changed, 138 insertions(+), 6 deletions(-) create mode 100755 tests/functional/test_automodel_extra_installed_correctly.sh diff --git a/docker/Dockerfile b/docker/Dockerfile index b12e1b929f..8f125b2ca7 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -89,6 +89,7 @@ RUN <<"EOF" bash -exu uv sync --link-mode symlink --locked --no-install-project uv sync --link-mode symlink --locked --extra vllm --no-install-project uv sync --link-mode symlink --locked --extra mcore --no-install-project +uv sync --link-mode symlink --locked --extra automodel --no-install-project uv sync --link-mode symlink --locked --all-groups --no-install-project EOF diff --git a/pyproject.toml b/pyproject.toml index 5fa43153c5..9dd0884f2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -213,6 +213,7 @@ markers = [ "run_first: marks tests that should run before others", "mcore: marks tests that require the mcore extra", "hf_gated: marks tests that require HuggingFace token access for gated models", + "automodel: marks tests that require the automodel extra", ] [tool.pyrefly] diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index 9c1e1a86af..8c23fc8c7e 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -26,6 +26,7 @@ time uv run --no-sync bash ./tests/functional/rm.sh time uv run --no-sync bash ./tests/functional/eval.sh time uv run --no-sync bash ./tests/functional/eval_async.sh time uv run --no-sync bash 
./tests/functional/test_mcore_extra_installed_correctly.sh +time uv run --no-sync bash ./tests/functional/test_automodel_extra_installed_correctly.sh time uv run --no-sync bash ./tests/functional/vlm_grpo.sh cd /opt/nemo-rl/tests diff --git a/tests/functional/test_automodel_extra_installed_correctly.sh b/tests/functional/test_automodel_extra_installed_correctly.sh new file mode 100755 index 0000000000..81b1ff124a --- /dev/null +++ b/tests/functional/test_automodel_extra_installed_correctly.sh @@ -0,0 +1,72 @@ +#!/bin/bash +set -eoux pipefail + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +cd $SCRIPT_DIR + +uv sync +# Just the first call with --extra automodel is invoked with --reinstall in case submodules were recently updated/downloaded +uv run --reinstall --extra automodel --no-build-isolation python <<"EOF" +import torch +import transformers +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +# Test basic transformers functionality that automodel extends +config = AutoConfig.from_pretrained("microsoft/DialoGPT-small") +print(f"Loaded config: {config.model_type}") + +# Test nemo_automodel import +try: + import nemo_automodel + from nemo_automodel.components._transformers.auto_model import NeMoAutoModelForCausalLM + print("[NeMo Automodel import successful]") +except ImportError as e: + print(f"[WARNING] NeMo Automodel import failed: {e}") + print("[This may be expected if nemo_automodel is not fully built]") + +# Test flash-attn import (part of automodel extra) +try: + import flash_attn + print(f"[Flash Attention available: {flash_attn.__version__}]") +except ImportError: + print("[WARNING] Flash Attention not available") + +# Test vllm import (part of automodel extra) +try: + import vllm + print(f"[vLLM available: {vllm.__version__}]") +except ImportError: + print("[WARNING] vLLM not available") + +print("[Automodel extra dependencies test successful]") +EOF + +# Test that automodel components can be 
accessed +uv run --extra automodel --no-build-isolation python <<"EOF" +# This must be the first import to get all of the automodel packages added to the path +import nemo_rl + +# Test automodel utilities +try: + from nemo_rl.utils.automodel_checkpoint import detect_checkpoint_format, load_checkpoint, save_checkpoint + print("[Automodel checkpoint utilities import successful]") +except ImportError as e: + print(f"[Automodel checkpoint utilities import failed: {e}]") + +# Test automodel factory +try: + from nemo_rl.models.policy.utils import AUTOMODEL_FACTORY, NEMO_AUTOMODEL_AVAILABLE + print(f"[Automodel factory available: {NEMO_AUTOMODEL_AVAILABLE}]") +except ImportError as e: + print(f"[Automodel factory import failed: {e}]") + +print("[Automodel integration test successful]") +EOF + +# Sync just to return the environment to the original base state +uv sync --link-mode symlink --locked --no-install-project +uv sync --link-mode symlink --locked --extra vllm --no-install-project +uv sync --link-mode symlink --locked --extra mcore --no-install-project +uv sync --link-mode symlink --locked --extra automodel --no-install-project +uv sync --link-mode symlink --locked --all-groups --no-install-project +echo Success diff --git a/tests/unit/L0_Unit_Tests_Generation.sh b/tests/unit/L0_Unit_Tests_Generation.sh index 3f607cc080..4902753858 100644 --- a/tests/unit/L0_Unit_Tests_Generation.sh +++ b/tests/unit/L0_Unit_Tests_Generation.sh @@ -20,10 +20,18 @@ uv run tests/unit/prepare_unit_test_assets.py cd /opt/nemo-rl uv run --no-sync bash -x ./tests/run_unit.sh unit/models/generation/ --cov=nemo_rl --cov-report=term-missing --cov-report=json --hf-gated +# Check and run mcore tests exit_code=$(pytest tests/unit/models/generation/ --collect-only --hf-gated --mcore-only -q >/dev/null 2>&1; echo $?) 
if [[ $exit_code -eq 5 ]]; then echo "No mcore tests to run" - exit 0 else uv run --extra mcore bash -x ./tests/run_unit.sh unit/models/generation/ --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json --hf-gated --mcore-only fi + +# Check and run automodel tests +exit_code=$(pytest tests/unit/models/generation/ --collect-only --hf-gated --automodel -q >/dev/null 2>&1; echo $?) +if [[ $exit_code -eq 5 ]]; then + echo "No automodel tests to run" +else + uv run --extra automodel bash -x ./tests/run_unit.sh unit/models/generation/ --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json --hf-gated --automodel +fi diff --git a/tests/unit/L0_Unit_Tests_Other.sh b/tests/unit/L0_Unit_Tests_Other.sh index a639730044..3421c4aaa1 100644 --- a/tests/unit/L0_Unit_Tests_Other.sh +++ b/tests/unit/L0_Unit_Tests_Other.sh @@ -20,10 +20,18 @@ uv run tests/unit/prepare_unit_test_assets.py cd /opt/nemo-rl uv run --no-sync bash -x ./tests/run_unit.sh unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --cov=nemo_rl --cov-report=term-missing --cov-report=json --hf-gated +# Check and run mcore tests exit_code=$(pytest tests/unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --collect-only --hf-gated --mcore-only -q >/dev/null 2>&1; echo $?) if [[ $exit_code -eq 5 ]]; then echo "No mcore tests to run" - exit 0 else uv run --extra mcore bash -x ./tests/run_unit.sh unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json --hf-gated --mcore-only fi + +# Check and run automodel tests +exit_code=$(pytest tests/unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --collect-only --hf-gated --automodel -q >/dev/null 2>&1; echo $?) 
+if [[ $exit_code -eq 5 ]]; then + echo "No automodel tests to run" +else + uv run --extra automodel bash -x ./tests/run_unit.sh unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json --hf-gated --automodel +fi \ No newline at end of file diff --git a/tests/unit/L0_Unit_Tests_Policy.sh b/tests/unit/L0_Unit_Tests_Policy.sh index 6fe9309fe6..adb099f03d 100644 --- a/tests/unit/L0_Unit_Tests_Policy.sh +++ b/tests/unit/L0_Unit_Tests_Policy.sh @@ -20,10 +20,18 @@ uv run tests/unit/prepare_unit_test_assets.py cd /opt/nemo-rl uv run --no-sync bash -x ./tests/run_unit.sh unit/models/policy/ --cov=nemo_rl --cov-report=term-missing --cov-report=json --hf-gated +# Check and run mcore tests exit_code=$(pytest tests/unit/models/policy/ --collect-only --hf-gated --mcore-only -q >/dev/null 2>&1; echo $?) if [[ $exit_code -eq 5 ]]; then echo "No mcore tests to run" - exit 0 else uv run --extra mcore bash -x ./tests/run_unit.sh unit/models/policy/ --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json --hf-gated --mcore-only fi + +# Check and run automodel tests +exit_code=$(pytest tests/unit/models/policy/ --collect-only --hf-gated --automodel -q >/dev/null 2>&1; echo $?) 
+if [[ $exit_code -eq 5 ]]; then + echo "No automodel tests to run" +else + uv run --extra automodel bash -x ./tests/run_unit.sh unit/models/policy/ --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json --hf-gated --automodel +fi diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 1f7dad4a2b..afd6dfecd7 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -45,12 +45,19 @@ def pytest_addoption(parser): default=False, help="Run ONLY mcore tests (combine with --hf-gated to include mcore+hf_gated tests)", ) + parser.addoption( + "--automodel", + action="store_true", + default=False, + help="Include tests that require the automodel extra", + ) def pytest_collection_modifyitems(config, items): """Modify test collection to skip tests based on markers unless explicitly requested.""" run_hf_gated = config.getoption("--hf-gated") run_mcore_only = config.getoption("--mcore-only") + run_automodel = config.getoption("--automodel") marker_expr = config.getoption("-m", default="") # If user specified -m marker expressions, still prioritize run_first tests @@ -73,16 +80,30 @@ def pytest_collection_modifyitems(config, items): and not item.get_closest_marker("hf_gated") ] elif run_hf_gated: - # Configuration 2: Default tests + hf_gated tests, excluding mcore - new_items = [item for item in items if not item.get_closest_marker("mcore")] + # Configuration 2: Default tests + hf_gated tests, excluding mcore and automodel + new_items = [ + item + for item in items + if not item.get_closest_marker("mcore") + and not item.get_closest_marker("automodel") + ] else: - # Configuration 1: Default only - exclude both hf_gated and mcore + # Configuration 1: Default only - exclude hf_gated, mcore, and automodel new_items = [ item for item in items if not item.get_closest_marker("hf_gated") and not item.get_closest_marker("mcore") + and not item.get_closest_marker("automodel") + ] + + # Add automodel tests if explicitly requested + if run_automodel: + 
automodel_items = [ + item for item in items if item.get_closest_marker("automodel") ] + # Remove duplicates by converting to set and back + new_items = list(set(new_items + automodel_items)) # Ensure run_first tests are prioritized new_items.sort(key=lambda item: 0 if item.get_closest_marker("run_first") else 1) diff --git a/tests/unit/utils/test_automodel_checkpoint.py b/tests/unit/utils/test_automodel_checkpoint.py index 15e5956343..6cd73aea01 100644 --- a/tests/unit/utils/test_automodel_checkpoint.py +++ b/tests/unit/utils/test_automodel_checkpoint.py @@ -19,6 +19,13 @@ import pytest import torch +# Skip entire module if nemo_automodel is not available +pytest_plugins = [] +try: + import nemo_automodel # noqa: F401 +except ImportError: + pytest.skip("nemo_automodel not available", allow_module_level=True) + from nemo_rl.utils.automodel_checkpoint import ( detect_checkpoint_format, load_checkpoint, @@ -59,6 +66,7 @@ def mock_optimizer(): return torch.optim.Adam(model.parameters()) +@pytest.mark.automodel class TestDetectCheckpointFormat: """Test the detect_checkpoint_format function.""" @@ -172,6 +180,7 @@ def test_expected_structure(self): """Test the save_checkpoint function.""" + @pytest.mark.automodel @patch("nemo_rl.utils.automodel_checkpoint.save_model") @patch("nemo_rl.utils.automodel_checkpoint.save_optimizer") def test_save_model_only(self, mock_save_optimizer, mock_save_model, mock_model): @@ -202,6 +211,7 @@ def test_save_model_only(self, mock_save_optimizer, mock_save_model, mock_model) # Verify optimizer saving was not called mock_save_optimizer.assert_not_called() + @pytest.mark.automodel @patch("nemo_rl.utils.automodel_checkpoint.save_model") @patch("nemo_rl.utils.automodel_checkpoint.save_optimizer") def test_save_with_optimizer( @@ -234,6 +244,7 @@ def test_save_with_optimizer( assert opt_call_args[1]["model"] is mock_model assert opt_call_args[1]["weights_path"] == optimizer_path + @pytest.mark.automodel 
@patch("nemo_rl.utils.automodel_checkpoint.save_model") def test_save_with_tokenizer(self, mock_save_model, mock_model): """Test saving with tokenizer.""" @@ -278,6 +289,7 @@ def check_dict_equality(dict1, dict2): assert dict1[k] == dict2[k] +@pytest.mark.automodel class TestSaveLoadIntegration: """Integration tests that actually save and load checkpoints.""" From fd85e319d64cbe94c51a0b87996a427367e92914 Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Tue, 16 Sep 2025 00:19:33 -0700 Subject: [PATCH 06/14] Update tests/unit/conftest.py Co-authored-by: Terry Kong Signed-off-by: Felipe Vieira Frujeri --- tests/unit/conftest.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index afd6dfecd7..3f6d8ccf73 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -57,7 +57,8 @@ def pytest_collection_modifyitems(config, items): """Modify test collection to skip tests based on markers unless explicitly requested.""" run_hf_gated = config.getoption("--hf-gated") run_mcore_only = config.getoption("--mcore-only") - run_automodel = config.getoption("--automodel") + run_automodel_only = config.getoption("--automodel-only") + assert run_mcore_only ^ run_automodel_only, f"--mcore-only and --automodel-only are mutually exclusive markers" marker_expr = config.getoption("-m", default="") # If user specified -m marker expressions, still prioritize run_first tests From 3a5260f4c5747d7314eb4b7ac5403829b63e4144 Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Tue, 16 Sep 2025 07:27:08 +0000 Subject: [PATCH 07/14] Clean up test filtering logic. 
Signed-off-by: Felipe Vieira Frujeri --- tests/unit/conftest.py | 66 ++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 41 deletions(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 3f6d8ccf73..27fd6c1e45 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -46,10 +46,10 @@ def pytest_addoption(parser): help="Run ONLY mcore tests (combine with --hf-gated to include mcore+hf_gated tests)", ) parser.addoption( - "--automodel", + "--automodel-only", action="store_true", default=False, - help="Include tests that require the automodel extra", + help="Run ONLY automodel tests", ) @@ -58,7 +58,7 @@ def pytest_collection_modifyitems(config, items): run_hf_gated = config.getoption("--hf-gated") run_mcore_only = config.getoption("--mcore-only") run_automodel_only = config.getoption("--automodel-only") - assert run_mcore_only ^ run_automodel_only, f"--mcore-only and --automodel-only are mutually exclusive markers" + assert not (run_mcore_only and run_automodel_only), "--mcore-only and --automodel-only are mutually exclusive" marker_expr = config.getoption("-m", default="") # If user specified -m marker expressions, still prioritize run_first tests @@ -66,45 +66,29 @@ def pytest_collection_modifyitems(config, items): items.sort(key=lambda item: 0 if item.get_closest_marker("run_first") else 1) return - # Filter tests based on the desired configurations - new_items = [] - - if run_mcore_only and run_hf_gated: - # Configuration 4: Only mcore tests, including ones with hf_gated - new_items = [item for item in items if item.get_closest_marker("mcore")] - elif run_mcore_only: - # Configuration 3: Only mcore tests, excluding ones with hf_gated - new_items = [ - item - for item in items - if item.get_closest_marker("mcore") - and not item.get_closest_marker("hf_gated") - ] - elif run_hf_gated: - # Configuration 2: Default tests + hf_gated tests, excluding mcore and automodel - new_items = [ - item - for item in items - if 
not item.get_closest_marker("mcore") - and not item.get_closest_marker("automodel") - ] + # Start with all items and apply filters sequentially + new_items = list(items) + + # Filter by hf_gated marker + if not run_hf_gated: + # Exclude hf_gated tests unless explicitly requested + new_items = [item for item in new_items if not item.get_closest_marker("hf_gated")] + + # Filter by mcore marker + if run_mcore_only: + # Include only mcore tests + new_items = [item for item in new_items if item.get_closest_marker("mcore")] + else: + # Exclude mcore tests by default + new_items = [item for item in new_items if not item.get_closest_marker("mcore")] + + # Filter by automodel marker + if run_automodel_only: + # Include only automodel tests + new_items = [item for item in items if item.get_closest_marker("automodel")] else: - # Configuration 1: Default only - exclude hf_gated, mcore, and automodel - new_items = [ - item - for item in items - if not item.get_closest_marker("hf_gated") - and not item.get_closest_marker("mcore") - and not item.get_closest_marker("automodel") - ] - - # Add automodel tests if explicitly requested - if run_automodel: - automodel_items = [ - item for item in items if item.get_closest_marker("automodel") - ] - # Remove duplicates by converting to set and back - new_items = list(set(new_items + automodel_items)) + # Exclude automodel tests by default + new_items = [item for item in new_items if not item.get_closest_marker("automodel")] # Ensure run_first tests are prioritized new_items.sort(key=lambda item: 0 if item.get_closest_marker("run_first") else 1) From f4aab6b554ea345a9adc451c8670f420dd64d842 Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Tue, 16 Sep 2025 07:46:55 +0000 Subject: [PATCH 08/14] Revert attn_implementation argument to Automodel. 
Signed-off-by: Felipe Vieira Frujeri --- nemo_rl/models/policy/dtensor_policy_worker_v2.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index b3f0aa6dc2..d06132522f 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -222,7 +222,6 @@ def __init__( device_map="cpu", # load weights onto CPU initially trust_remote_code=True, config=model_config, - attn_implementation=model_config._attn_implementation, use_liger_kernel=False, torch_dtype=str(model_config.torch_dtype), ) @@ -237,9 +236,14 @@ def __init__( # The actual weights will be broadcast from rank 0. with init_empty_weights(): + # NeMoAutoModelForCausalLM uses flash_attention_2 by default + # so we need to set it to None if sequence packing is disabled + # https://github.com/NVIDIA-NeMo/Automodel/blob/7e748be260651349307862426c0c168cebdeeec3/nemo_automodel/components/_transformers/auto_model.py#L180 self.model = model_class.from_config( model_config, - attn_implementation=model_config._attn_implementation, + attn_implementation="flash_attention_2" + if self.enable_seq_packing + else None, use_liger_kernel=False, trust_remote_code=True, torch_dtype=str(model_config.torch_dtype), From 8f043afe1a4fa71195bf6a4d29b715d2e05c6e20 Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Tue, 16 Sep 2025 07:50:36 +0000 Subject: [PATCH 09/14] Fix linting on conftest.py. 
Signed-off-by: Felipe Vieira Frujeri --- tests/unit/conftest.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 27fd6c1e45..5fb97193bc 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -58,7 +58,9 @@ def pytest_collection_modifyitems(config, items): run_hf_gated = config.getoption("--hf-gated") run_mcore_only = config.getoption("--mcore-only") run_automodel_only = config.getoption("--automodel-only") - assert not (run_mcore_only and run_automodel_only), "--mcore-only and --automodel-only are mutually exclusive" + assert not (run_mcore_only and run_automodel_only), ( + "--mcore-only and --automodel-only are mutually exclusive" + ) marker_expr = config.getoption("-m", default="") # If user specified -m marker expressions, still prioritize run_first tests @@ -72,7 +74,9 @@ def pytest_collection_modifyitems(config, items): # Filter by hf_gated marker if not run_hf_gated: # Exclude hf_gated tests unless explicitly requested - new_items = [item for item in new_items if not item.get_closest_marker("hf_gated")] + new_items = [ + item for item in new_items if not item.get_closest_marker("hf_gated") + ] # Filter by mcore marker if run_mcore_only: @@ -88,7 +92,9 @@ def pytest_collection_modifyitems(config, items): new_items = [item for item in items if item.get_closest_marker("automodel")] else: # Exclude automodel tests by default - new_items = [item for item in new_items if not item.get_closest_marker("automodel")] + new_items = [ + item for item in new_items if not item.get_closest_marker("automodel") + ] # Ensure run_first tests are prioritized new_items.sort(key=lambda item: 0 if item.get_closest_marker("run_first") else 1) From 1f391deb1f4ede8bdfb5de9e754685b0a495492b Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Tue, 16 Sep 2025 22:21:59 +0000 Subject: [PATCH 10/14] Fix pytest command for automodel-only test markers. 
Signed-off-by: Felipe Vieira Frujeri --- tests/unit/L0_Unit_Tests_Generation.sh | 4 ++-- tests/unit/L0_Unit_Tests_Other.sh | 4 ++-- tests/unit/L0_Unit_Tests_Policy.sh | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unit/L0_Unit_Tests_Generation.sh b/tests/unit/L0_Unit_Tests_Generation.sh index 4902753858..5cf9da41b9 100644 --- a/tests/unit/L0_Unit_Tests_Generation.sh +++ b/tests/unit/L0_Unit_Tests_Generation.sh @@ -29,9 +29,9 @@ else fi # Check and run automodel tests -exit_code=$(pytest tests/unit/models/generation/ --collect-only --hf-gated --automodel -q >/dev/null 2>&1; echo $?) +exit_code=$(pytest tests/unit/models/generation/ --collect-only --hf-gated --automodel-only -q >/dev/null 2>&1; echo $?) if [[ $exit_code -eq 5 ]]; then echo "No automodel tests to run" else - uv run --extra automodel bash -x ./tests/run_unit.sh unit/models/generation/ --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json --hf-gated --automodel + uv run --extra automodel bash -x ./tests/run_unit.sh unit/models/generation/ --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json --hf-gated --automodel-only fi diff --git a/tests/unit/L0_Unit_Tests_Other.sh b/tests/unit/L0_Unit_Tests_Other.sh index 3421c4aaa1..815c45561f 100644 --- a/tests/unit/L0_Unit_Tests_Other.sh +++ b/tests/unit/L0_Unit_Tests_Other.sh @@ -29,9 +29,9 @@ else fi # Check and run automodel tests -exit_code=$(pytest tests/unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --collect-only --hf-gated --automodel -q >/dev/null 2>&1; echo $?) +exit_code=$(pytest tests/unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --collect-only --hf-gated --automodel-only -q >/dev/null 2>&1; echo $?) 
if [[ $exit_code -eq 5 ]]; then echo "No automodel tests to run" else - uv run --extra automodel bash -x ./tests/run_unit.sh unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json --hf-gated --automodel + uv run --extra automodel bash -x ./tests/run_unit.sh unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json --hf-gated --automodel-only fi \ No newline at end of file diff --git a/tests/unit/L0_Unit_Tests_Policy.sh b/tests/unit/L0_Unit_Tests_Policy.sh index adb099f03d..a19b7bf3eb 100644 --- a/tests/unit/L0_Unit_Tests_Policy.sh +++ b/tests/unit/L0_Unit_Tests_Policy.sh @@ -29,9 +29,9 @@ else fi # Check and run automodel tests -exit_code=$(pytest tests/unit/models/policy/ --collect-only --hf-gated --automodel -q >/dev/null 2>&1; echo $?) +exit_code=$(pytest tests/unit/models/policy/ --collect-only --hf-gated --automodel-only -q >/dev/null 2>&1; echo $?) if [[ $exit_code -eq 5 ]]; then echo "No automodel tests to run" else - uv run --extra automodel bash -x ./tests/run_unit.sh unit/models/policy/ --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json --hf-gated --automodel + uv run --extra automodel bash -x ./tests/run_unit.sh unit/models/policy/ --cov=nemo_rl --cov-append --cov-report=term-missing --cov-report=json --hf-gated --automodel-only fi From ad33c57f67629cc5cb1a4cf0bee62603fc0d11b8 Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Wed, 17 Sep 2025 00:58:22 +0000 Subject: [PATCH 11/14] Fix loading checkpoint path. 
Signed-off-by: Felipe Vieira Frujeri --- nemo_rl/utils/automodel_checkpoint.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/nemo_rl/utils/automodel_checkpoint.py b/nemo_rl/utils/automodel_checkpoint.py index 98cc0c4ab3..89bfd35eb9 100644 --- a/nemo_rl/utils/automodel_checkpoint.py +++ b/nemo_rl/utils/automodel_checkpoint.py @@ -186,6 +186,11 @@ def load_checkpoint( try: format_enum = SerializationFormat[model_save_format.upper()] + # append /model to the weights_path if it doesn't exist + # TODO: remove this once nemo-automodel is updated + if not weights_path.endswith("/model"): + weights_path = os.path.join(weights_path, "model") + # Load model using nemo-automodel API load_model( model=model, From 46b32e0a0ee13ccadc70724a17a247ae619f511e Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Wed, 17 Sep 2025 02:05:47 +0000 Subject: [PATCH 12/14] Infer checkpoint dir. Signed-off-by: Felipe Vieira Frujeri --- nemo_rl/utils/automodel_checkpoint.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/nemo_rl/utils/automodel_checkpoint.py b/nemo_rl/utils/automodel_checkpoint.py index 89bfd35eb9..a9f0793851 100644 --- a/nemo_rl/utils/automodel_checkpoint.py +++ b/nemo_rl/utils/automodel_checkpoint.py @@ -38,6 +38,24 @@ apply_patches() +def _infer_checkpoint_root(weights_path: str) -> str: + """Infer checkpoint root directory from weights path. + + When weights_path ends with "…/weights/model", we need the parent of + the weights directory (the checkpoint root), not the weights directory itself. 
+ + Args: + weights_path: Path to model weights (e.g., "/path/to/policy/weights/model") + + Returns: + str: Checkpoint root directory (e.g., "/path/to/policy") + """ + weights_dir = os.path.dirname(weights_path) + if weights_dir.endswith("weights"): + return os.path.dirname(weights_dir) + return weights_dir + + def detect_checkpoint_format(weights_path: str) -> tuple[str, bool]: """Detect model save format and PEFT status from checkpoint directory. @@ -126,7 +144,7 @@ def save_checkpoint( checkpoint_config = CheckpointingConfig( enabled=True, - checkpoint_dir=os.path.dirname(weights_path), + checkpoint_dir=_infer_checkpoint_root(weights_path), model_save_format=model_save_format, model_cache_dir="", model_repo_id="", From 5470a298cbfd4fd75a92585095b0e39ec70bd06d Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Wed, 17 Sep 2025 05:46:03 +0000 Subject: [PATCH 13/14] Update type hint for self.metric_name. Signed-off-by: Felipe Vieira Frujeri --- nemo_rl/utils/checkpoint.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/nemo_rl/utils/checkpoint.py b/nemo_rl/utils/checkpoint.py index 0c1b68a54e..e2e314a88a 100644 --- a/nemo_rl/utils/checkpoint.py +++ b/nemo_rl/utils/checkpoint.py @@ -92,7 +92,7 @@ def __init__(self, config: CheckpointingConfig): config (CheckpointingConfig) """ self.checkpoint_dir = Path(config["checkpoint_dir"]) - self.metric_name = config["metric_name"] + self.metric_name: str | None = config["metric_name"] self.higher_is_better = config["higher_is_better"] self.keep_top_k = config["keep_top_k"] @@ -202,18 +202,17 @@ def remove_old_checkpoints(self, exclude_latest: bool = True) -> None: checkpoint_history.sort(key=lambda x: x[0], reverse=True) else: try: - metric_name = ( - self.metric_name - ) # Type checker hint - capture the non-None value # sort by metric value first, then by step number (for equal metrics, prefer more recent) if self.higher_is_better: # For higher_is_better=True: higher metric values 
first, then higher step numbers checkpoint_history.sort( - key=lambda x: (x[2][metric_name], x[0]), reverse=True + key=lambda x: (x[2][self.metric_name], x[0]), reverse=True ) else: # For higher_is_better=False: lower metric values first, then higher step numbers for equal values - checkpoint_history.sort(key=lambda x: (x[2][metric_name], -x[0])) + checkpoint_history.sort( + key=lambda x: (x[2][self.metric_name], -x[0]) + ) except KeyError: warnings.warn( f"Metric {self.metric_name} not found in checkpoint history. Keeping most recent k checkpoints." @@ -250,9 +249,8 @@ def get_best_checkpoint_path(self) -> Optional[str]: ) return self.get_latest_checkpoint_path() - metric_name = self.metric_name # Type checker hint - capture the non-None value checkpoint_history.sort( - key=lambda x: x[2][metric_name], reverse=self.higher_is_better + key=lambda x: x[2][self.metric_name], reverse=self.higher_is_better ) return str(checkpoint_history[0][1]) From 87ec98046032629fe105011af7f6f9217d9667bb Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Wed, 17 Sep 2025 22:15:57 +0000 Subject: [PATCH 14/14] Revert grpo_math_1B config to original state. Signed-off-by: Felipe Vieira Frujeri --- examples/configs/grpo_math_1B.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 09b381bfff..18cabeb7f8 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -54,7 +54,7 @@ policy: cpu_offload: False sequence_parallel: false activation_checkpointing: false - tensor_parallel_size: 2 + tensor_parallel_size: 1 context_parallel_size: 1 custom_parallel_plan: null @@ -230,5 +230,5 @@ logger: flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) cluster: - gpus_per_node: 2 + gpus_per_node: 1 num_nodes: 1