Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 1 addition & 13 deletions docs/design-docs/checkpointing.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,7 @@
## Checkpoint Format
NeMo-RL provides two checkpoint formats for HuggingFace models: Torch distributed and HuggingFace format. Torch distributed is used by default for efficiency, and HuggingFace format is provided for compatibility with HuggingFace's `AutoModel.from_pretrained` API. Note that HuggingFace format checkpoints save only the model weights, ignoring the optimizer states. It is recommended to use Torch distributed format to save intermediate checkpoints and to save a HuggingFace checkpoint only at the end of training.

There are two ways to get a NeMo-RL checkpoint in HuggingFace format.

1. (Recommended) Save the HuggingFace checkpoint directly by passing `save_hf=True` to `HFPolicy`'s `save_checkpoint`:

```python
policy.save_checkpoint(
weights_path=<WHERE_TO_SAVE_MODEL_WEIGHTS>,
optimizer_path=<WHERE_TO_SAVE_OPTIM_STATE>,
save_torch_dist=True,
save_hf=True,
)
```
2. Convert a Torch distributed checkpoint to HuggingFace format after training. We provide a conversion script for this purpose.
A checkpoint converter is provided to convert a Torch distributed checkpoint to HuggingFace format after training:

```python
uv run examples/convert_dcp_to_hf.py --config=<YAML CONFIG USED DURING TRAINING> <ANY CONFIG OVERRIDES USED DURING TRAINING> --dcp-ckpt-path=<PATH TO DIST CHECKPOINT TO CONVERT> --hf-ckpt-path=<WHERE TO SAVE HF CHECKPOINT>
Expand Down
12 changes: 3 additions & 9 deletions nemo_rl/algorithms/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,14 +446,6 @@ def dpo_train(
% master_config["checkpointing"]["save_period"]
== 0
): # +1 because step is 0-indexed
is_last_checkpoint = (
min(
len(train_dataloader) * max_num_epochs,
master_config["dpo"]["max_num_steps"],
)
- (total_steps + 1)
< master_config["checkpointing"]["save_period"]
)
dpo_save_state["step"] = (current_step + 1) % len(train_dataloader)
dpo_save_state["total_steps"] = total_steps + 1
dpo_save_state["epoch"] = current_epoch
Expand All @@ -470,7 +462,9 @@ def dpo_train(
optimizer_path=os.path.join(
checkpoint_path, "policy", "optimizer"
),
save_hf=is_last_checkpoint,
tokenizer_path=os.path.join(
checkpoint_path, "policy", "tokenizer"
),
)
torch.save(
train_dataloader.state_dict(),
Expand Down
7 changes: 0 additions & 7 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,12 +524,6 @@ def grpo_train(
): # +1 because step is 0-indexed
policy.prepare_for_training()

is_last_checkpoint = (
min(len(dataloader), master_config["grpo"]["max_num_steps"])
- (step + 1)
< master_config["checkpointing"]["save_period"]
)

grpo_save_state["step"] = step + 1
grpo_save_state["val_reward"] = val_metrics["accuracy"]
grpo_save_state["consumed_samples"] = consumed_samples
Expand All @@ -546,7 +540,6 @@ def grpo_train(
tokenizer_path=os.path.join(
checkpoint_path, "policy", "tokenizer"
),
save_hf=is_last_checkpoint,
)
torch.save(
dataloader.state_dict(),
Expand Down
10 changes: 0 additions & 10 deletions nemo_rl/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,15 +447,6 @@ def sft_train(
% master_config["checkpointing"]["save_period"]
== 0
): # +1 because step is 0-indexed
is_last_checkpoint = (
min(
len(train_dataloader) * max_num_epochs,
master_config["sft"]["max_num_steps"],
)
- (total_steps + 1)
< master_config["checkpointing"]["save_period"]
)

sft_save_state["step"] = (current_step + 1) % len(train_dataloader)
sft_save_state["total_steps"] = total_steps + 1
sft_save_state["epoch"] = current_epoch
Expand All @@ -476,7 +467,6 @@ def sft_train(
tokenizer_path=os.path.join(
checkpoint_path, "policy", "tokenizer"
),
save_hf=is_last_checkpoint,
)
torch.save(
train_dataloader.state_dict(),
Expand Down
7 changes: 1 addition & 6 deletions nemo_rl/models/policy/dtensor_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,13 +714,10 @@ def save_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
save_torch_dist: bool = True,
save_hf: bool = False,
):
"""Save a checkpoint of the model.

the HuggingFace checkpoint is saved only if `save_hf` is True,
and the optimizer states are saved only if `optimizer` and `optimizer_path` are provided.
the optimizer states are saved only if `optimizer` and `optimizer_path` are provided.
"""
save_checkpoint(
model=self.model,
Expand All @@ -730,8 +727,6 @@ def save_checkpoint(
optimizer_path=optimizer_path,
tokenizer=self.tokenizer if tokenizer_path else None,
tokenizer_path=tokenizer_path,
save_torch_dist=save_torch_dist,
save_hf=save_hf,
)

def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None):
Expand Down
13 changes: 1 addition & 12 deletions nemo_rl/models/policy/fsdp1_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,8 +910,6 @@ def save_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
save_torch_dist: bool = True,
save_hf: bool = False,
):
"""Save a checkpoint of the model.

Expand All @@ -921,19 +919,12 @@ def save_checkpoint(
__0_1.distcp
__1_0.distcp
...
weights_path-hf/
config.json
generation_config.json
model-00001-of-<TOTAL_SHARDS>.safetensors
...
model.safetensors.index.json
optimizer_path/
__0_0.distcp
__1_0.distcp
...

the HuggingFace checkpoint is saved only if `save_hf` is True,
and the optimizer states are saved only if `optimizer` and `optimizer_path` are provided.
the optimizer states are saved only if `optimizer` and `optimizer_path` are provided.
"""
save_checkpoint(
model=self.model,
Expand All @@ -943,8 +934,6 @@ def save_checkpoint(
optimizer_path=optimizer_path,
tokenizer=self.tokenizer if tokenizer_path else None,
tokenizer_path=tokenizer_path,
save_torch_dist=save_torch_dist,
save_hf=save_hf,
)

def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None):
Expand Down
4 changes: 0 additions & 4 deletions nemo_rl/models/policy/hf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,17 +307,13 @@ def save_checkpoint(
weights_path: str,
optimizer_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
save_torch_dist: bool = True,
save_hf: bool = False,
):
"""Save a checkpoint of the model."""
futures = self.worker_group.run_all_workers_single_data(
"save_checkpoint",
weights_path,
optimizer_path,
tokenizer_path,
save_torch_dist,
save_hf,
only_on="all_tied_workers",
)
ray.get(futures)
Expand Down
44 changes: 9 additions & 35 deletions nemo_rl/utils/native_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Checkpoint management utilities for HF models."""

import os
from pathlib import Path
from typing import Any, Optional

import torch
Expand Down Expand Up @@ -139,8 +138,6 @@ def save_checkpoint(
optimizer_path: Optional[str] = None,
tokenizer: Optional[Any] = None,
tokenizer_path: Optional[str] = None,
save_torch_dist: bool = True,
save_hf: bool = False,
) -> None:
"""Save a checkpoint of the model and optionally optimizer state.

Expand All @@ -150,40 +147,17 @@ def save_checkpoint(
optimizer: Optional optimizer to save
scheduler: Optional scheduler to save
optimizer_path: Path to save optimizer state (required if optimizer provided)
save_torch_dist: Whether to save in PyTorch distributed format
save_hf: Whether to save in HuggingFace format
"""
if save_hf:
if hasattr(model, "_fsdp_wrapped_module"):
model_state_dict = model._fsdp_wrapped_module.state_dict()
else:
model_state_dict = {
k: v.full_tensor()
if isinstance(v, torch.distributed.tensor.DTensor)
else v
for k, v in model.state_dict().items()
}

if torch.distributed.get_rank() == 0:
# Create a new path by appending "-hf" to the weights path
hf_weights_path = f"{Path(weights_path)}-hf"

model.save_pretrained(
hf_weights_path,
state_dict=model_state_dict,
)
model_state = {"model": ModelState(model)}
dcp.save(model_state, checkpoint_id=weights_path)

if save_torch_dist:
model_state = {"model": ModelState(model)}
dcp.save(model_state, checkpoint_id=weights_path)

if optimizer is not None:
if optimizer_path is None:
raise ValueError(
"optimizer_path must be provided when saving optimizer state"
)
optimizer_state = {"optim": OptimizerState(model, optimizer, scheduler)}
dcp.save(optimizer_state, checkpoint_id=optimizer_path)
if optimizer is not None:
if optimizer_path is None:
raise ValueError(
"optimizer_path must be provided when saving optimizer state"
)
optimizer_state = {"optim": OptimizerState(model, optimizer, scheduler)}
dcp.save(optimizer_state, checkpoint_id=optimizer_path)

if tokenizer is not None:
if tokenizer_path is None:
Expand Down
101 changes: 1 addition & 100 deletions tests/unit/utils/test_native_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,77 +283,6 @@ def test_save_and_load_model_and_optimizer(mock_experiment):
check_dict_equality(new_optimizer.state_dict(), optimizer.state_dict())


@pytest.mark.parametrize("num_gpus", [1, 2], ids=["1gpu", "2gpu"])
def test_save_and_load_hf_checkpoint(policy, num_gpus):
## warm up with a forward pass
## this is needed before saving a checkpoint because FSDP does some lazy initialization
input_ids = torch.randint(0, 16000, (4, 128)) # 4 sequences, each of length 128
attention_mask = torch.ones(4, 128)
input_lengths = attention_mask.sum(dim=1).to(torch.int32)
dummy_fwd_dict = BatchedDataDict(
{
"input_ids": input_ids,
"input_lengths": input_lengths,
"attention_mask": attention_mask,
"labels": torch.randint(0, 16000, (4, 128)),
}
)
policy.get_logprobs(dummy_fwd_dict)

with TemporaryDirectory() as tmp_dir:
policy.save_checkpoint(
os.path.join(tmp_dir, "test_hf_and_dcp"),
save_hf=True,
save_torch_dist=True,
tokenizer_path=os.path.join(tmp_dir, "test_hf_and_dcp_tokenizer"),
)

## make sure we save both HF and DCP checkpoints
# Dynamically create the expected set of distcp files based on num_gpus
expected_distcp_files = {f"__{rank}_0.distcp" for rank in range(num_gpus)}
expected_files = expected_distcp_files.union({".metadata"})

assert (
set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp"))) == expected_files
)
assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp_tokenizer"))) == {
"tokenizer_config.json",
"tokenizer.json",
"special_tokens_map.json",
}

converted_model = AutoModelForCausalLM.from_pretrained(
os.path.join(tmp_dir, "test_hf_and_dcp-hf")
)

hf_save_dir = os.path.join(tmp_dir, "test_hf_and_dcp-hf")
hf_files = set(os.listdir(hf_save_dir))

# Check the HF saved files structure: could be single or sharded
expected_common_hf_files = {"config.json", "generation_config.json"}
if "model.safetensors" in hf_files:
# Single file format (1 GPU or smaller model)
expected_hf_files = expected_common_hf_files.union({"model.safetensors"})
else:
# Sharded format (>=2 GPUs or larger model)
expected_hf_files = expected_common_hf_files.union(
{
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
"model.safetensors.index.json",
}
)
assert hf_files == expected_hf_files

coverted_model = AutoModelForCausalLM.from_pretrained(hf_save_dir)
original_model = AutoModelForCausalLM.from_pretrained(
simple_policy_config["model_name"]
)

## make sure converted model matches the original
check_dict_equality(converted_model.state_dict(), original_model.state_dict())


@pytest.mark.parametrize("num_gpus", [1, 2], ids=["1gpu", "2gpu"])
def test_convert_dcp_to_hf(policy, num_gpus):
## warm up with a forward pass
Expand All @@ -374,8 +303,6 @@ def test_convert_dcp_to_hf(policy, num_gpus):
with TemporaryDirectory() as tmp_dir:
policy.save_checkpoint(
os.path.join(tmp_dir, "test_hf_and_dcp"),
save_hf=True,
save_torch_dist=True,
)

# Dynamically create the expected set of distcp files based on num_gpus
Expand All @@ -387,25 +314,6 @@ def test_convert_dcp_to_hf(policy, num_gpus):
set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp"))) == expected_files
)

# Check the HF saved files structure: could be single or sharded
hf_save_dir = os.path.join(tmp_dir, "test_hf_and_dcp-hf")
hf_files = set(os.listdir(hf_save_dir))
expected_common_hf_files = {"config.json", "generation_config.json"}

if "model.safetensors" in hf_files:
# Single file format (1 GPU or smaller model)
expected_hf_files = expected_common_hf_files.union({"model.safetensors"})
else:
# Sharded format (>=2 GPUs or larger model)
expected_hf_files = expected_common_hf_files.union(
{
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
"model.safetensors.index.json",
}
)
assert hf_files == expected_hf_files

offline_converted_model_path = convert_dcp_to_hf(
os.path.join(tmp_dir, "test_hf_and_dcp"),
os.path.join(tmp_dir, "test_hf_and_dcp-hf-offline"),
Expand All @@ -423,18 +331,11 @@ def test_convert_dcp_to_hf(policy, num_gpus):
offline_converted_model_path
)

online_converted_model = AutoModelForCausalLM.from_pretrained(
os.path.join(tmp_dir, "test_hf_and_dcp-hf")
)
original_model = AutoModelForCausalLM.from_pretrained(
simple_policy_config["model_name"]
)

## make sure both conversions results in the same state dict
check_dict_equality(
online_converted_model.state_dict(), offline_converted_model.state_dict()
)
# Ensure the offline one is different from the original
# Ensure the offline checkpoint is different from the original
assert_recursive_dict_different(
offline_converted_model.state_dict(), original_model.state_dict()
)