Merged
Commits
28 commits
7929d39
memory optimizations for Nemotron12B 12k seqlen DPO training
ybgao-nvidia Aug 14, 2025
02bee2a
implement suggested changes
ybgao-nvidia Aug 18, 2025
fb8c1bb
add copyright
ybgao-nvidia Aug 18, 2025
87f858b
make lint pass
ybgao-nvidia Aug 18, 2025
b637b03
Merge branch 'main' into ybgao/aug13-dpo-12k-memory
ybgao-nvidia Aug 18, 2025
00d5d9a
update configuration key and README
ybgao-nvidia Aug 19, 2025
9be4a2c
fix allocator setting
ybgao-nvidia Aug 19, 2025
63a82de
update readme and lint
ybgao-nvidia Aug 19, 2025
a2cdc5a
disable expandable segments by default
ybgao-nvidia Aug 20, 2025
1a38935
Merge branch 'main' into ybgao/aug13-dpo-12k-memory
ybgao-nvidia Aug 20, 2025
34124c8
Merge branch 'main' into ybgao/aug13-dpo-12k-memory
ybgao-nvidia Aug 20, 2025
d3c9ad7
Update README.md
ybgao-nvidia Aug 20, 2025
ae128cc
remove configure_expandable_segments
ybgao-nvidia Aug 20, 2025
caaa87f
Update README.md
ybgao-nvidia Aug 20, 2025
60e2909
fix config schema
ybgao-nvidia Aug 20, 2025
0b2164a
Merge branch 'main' into ybgao/aug13-dpo-12k-memory
ybgao-nvidia Aug 20, 2025
35918c7
add test script
ybgao-nvidia Aug 21, 2025
a94b4fb
Merge branch 'main' into ybgao/aug13-dpo-12k-memory
ybgao-nvidia Aug 21, 2025
e1d0447
will tests pass now?
ybgao-nvidia Aug 21, 2025
51a2607
make tests pass
ybgao-nvidia Aug 23, 2025
e7f2b26
Merge branch 'main' into ybgao/aug13-dpo-12k-memory
ybgao-nvidia Aug 24, 2025
28cae68
include field in logger config
ybgao-nvidia Aug 25, 2025
5562dac
Merge branch 'main' into ybgao/aug13-dpo-12k-memory
ybgao-nvidia Aug 25, 2025
38aca9c
remove expandable segments from v2
ybgao-nvidia Aug 25, 2025
189868b
please pass :(
ybgao-nvidia Aug 25, 2025
7573f6d
Merge branch 'main' into ybgao/aug13-dpo-12k-memory
ybgao-nvidia Aug 26, 2025
3271a08
empty cache
ybgao-nvidia Aug 26, 2025
b97abd2
Merge branch 'main' into ybgao/aug13-dpo-12k-memory
ybgao-nvidia Aug 26, 2025
18 changes: 18 additions & 0 deletions README.md
@@ -473,6 +473,24 @@ For detailed instructions on how to set up and launch NeMo RL on Slurm or Kubernetes
NRL_FORCE_REBUILD_VENVS=true uv run examples/run_grpo.py ...
```

- Significant memory fragmentation can occur when running models that lack FlashAttention2 support.
If OOM occurs after a few training iterations, tweaking the allocator settings may help reduce fragmentation.
To do so, specify [`max_split_size_mb`](https://docs.pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf)
in **either** of the following places:
1. Launch training with:
```sh
# This applies globally to all Ray actors
PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:64 uv run python examples/run_dpo.py ...
```
2. Make the change permanent by adding the setting to the training configuration:
```yaml
policy:
# ...
dtensor_cfg:
env_vars:
PYTORCH_CUDA_ALLOC_CONF: "max_split_size_mb:64"
```

## Citation

If you use NeMo RL in your research, please cite it using the following BibTeX entry:
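For context on the README note above: fragmentation-driven OOMs show up as a large gap between the memory the allocator has reserved and the memory actually held by live tensors. A minimal diagnostic sketch (not part of this PR) using only public `torch.cuda` APIs:

```python
# Hypothetical helper for gauging allocator fragmentation; not part of this PR.
import os

import torch


def report_cuda_fragmentation(device: int = 0) -> None:
    """Print the effective allocator config and the reserved-vs-allocated gap."""
    # The allocator reads PYTORCH_CUDA_ALLOC_CONF once, at first CUDA use,
    # so this only reflects what was set before the process touched the GPU.
    print("PYTORCH_CUDA_ALLOC_CONF =", os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "<unset>"))

    allocated = torch.cuda.memory_allocated(device)
    reserved = torch.cuda.memory_reserved(device)
    # Reserved-but-unallocated memory is cached (and potentially fragmented) space.
    print(
        f"allocated: {allocated / 2**30:.2f} GiB, reserved: {reserved / 2**30:.2f} GiB, "
        f"cached-but-unused: {(reserved - allocated) / 2**30:.2f} GiB"
    )


if __name__ == "__main__":
    if torch.cuda.is_available():
        report_cuda_fragmentation()
```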
4 changes: 4 additions & 0 deletions examples/configs/dpo.yaml
@@ -45,6 +45,8 @@ policy:
precision: "bfloat16"

dtensor_cfg:
env_vars:
PYTORCH_CUDA_ALLOC_CONF: "" # See https://docs.pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf
enabled: true
cpu_offload: False
sequence_parallel: false
@@ -155,9 +157,11 @@ data:
logger:
log_dir: "logs" # Base directory for all logs
wandb_enabled: false # Make sure you do a ``wandb login [Your API key]'' before running

tensorboard_enabled: false
mlflow_enabled: false # Disable MLflow logging
monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard
num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
wandb:
project: "dpo-dev"
name: "dpo"
@@ -0,0 +1,106 @@
# DPO Algorithm Configuration
dpo:
max_num_epochs: 1
max_num_steps: 100
val_period: 10
val_batches: 1
val_global_batch_size: 16
val_micro_batch_size: 1
val_at_start: true
seed: 42

reference_policy_kl_penalty: 0.1
preference_average_log_probs: False # whether to normalize log probs by sequence length in preference_loss
sft_average_log_probs: ${.preference_average_log_probs} # whether to normalize log probs by sequence length in sft_loss

preference_loss_weight: 1 # the coefficient of the preference loss
sft_loss_weight: 0 # the coefficient of the SFT loss

checkpointing:
enabled: true
checkpoint_dir: "results/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long"
metric_name: "val_loss"
higher_is_better: false
keep_top_k: null
save_period: 50
checkpoint_must_save_by: null

policy:
model_name: "mistralai/Mistral-Nemo-Instruct-2407"
tokenizer:
name: ${policy.model_name}

# number of preference samples per batch
# each preference sample corresponds to a pair of chosen and rejected responses
# so the actual batch size processed by the model is train_global_batch_size * 2
train_global_batch_size: 8
train_micro_batch_size: 1


#logprob_batch_size: ${policy.train_micro_batch_size}
max_total_sequence_length: 12288
precision: "bfloat16"

dtensor_cfg:
enabled: true
cpu_offload: false
sequence_parallel: false
activation_checkpointing: true
tensor_parallel_size: 8
context_parallel_size: 1
custom_parallel_plan: null
env_vars:
PYTORCH_CUDA_ALLOC_CONF: "max_split_size_mb:64"

dynamic_batching:
enabled: false

sequence_packing:
enabled: false

# makes the training sequence length divisible by the tensor parallel size
# this is useful for sequence parallel training
make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
max_grad_norm: 1.0

optimizer:
name: "torch.optim.AdamW"
kwargs:
lr: 1.0e-6
weight_decay: 0.01
betas: [0.9, 0.999]
eps: 1e-8
# when using Dtensor, we need to set foreach
# and fused to False
foreach: False
fused: False

scheduler:
- name: "torch.optim.lr_scheduler.ConstantLR"
kwargs:
factor: 1.0
total_iters: 10000000000
- milestones: []

data:
dataset_name: "HelpSteer3"
shuffle: False
max_input_seq_length: ${policy.max_total_sequence_length}

logger:
log_dir: "logs/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long" # Base directory for all logs
wandb_enabled: false # Make sure you do a ``wandb login [Your API key]'' before running
tensorboard_enabled: false
mlflow_enabled: false
monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard
num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
wandb:
project: "nemo-rl"
name: "dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long"
gpu_monitoring:
collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
gpus_per_node: 8
num_nodes: 1
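As a sanity check on this recipe's footprint (an illustration, not part of the PR): each preference sample is a chosen/rejected pair, so the batch settings above imply 16 sequences per optimizer step, each up to 12,288 tokens.

```python
# Back-of-the-envelope token count per optimizer step for this recipe
# (upper bound: assumes every sequence hits the 12288-token cap).
train_global_batch_size = 8       # preference samples per step (from the config)
sequences_per_sample = 2          # chosen + rejected
max_total_sequence_length = 12288

sequences_per_step = train_global_batch_size * sequences_per_sample
max_tokens_per_step = sequences_per_step * max_total_sequence_length
print(sequences_per_step, max_tokens_per_step)  # 16 sequences, 196608 tokens
```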
19 changes: 19 additions & 0 deletions nemo_rl/models/dtensor/parallelize.py
@@ -661,6 +661,25 @@ def _parallelize_model(
for i in range(len(layers)):
layers[i].mlp = checkpoint_wrapper(layers[i].mlp) # type: ignore

"""
the extra memory overhead for layer norm seems to be only present
in mistral models, where some intermediate state is converted to float32

need to find a better solution for checkpointing
"""
if hasattr(layers[i], "self_attn"):
layers[i].self_attn = checkpoint_wrapper(layers[i].self_attn) # type: ignore

if hasattr(layers[i], "input_layernorm"):
layers[i].input_layernorm = checkpoint_wrapper(
layers[i].input_layernorm # type: ignore
)

if hasattr(layers[i], "post_attention_layernorm"):
layers[i].post_attention_layernorm = checkpoint_wrapper(
layers[i].post_attention_layernorm # type: ignore
)

mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype,
reduce_dtype=torch.float32,
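For reference, the wrapping above uses PyTorch's composable activation-checkpoint wrapper. A standalone sketch of the same pattern on a toy decoder block (an illustration with a made-up `ToyDecoderLayer`; the real code operates on the HF model's layers):

```python
# Standalone illustration of per-submodule activation checkpointing; not the PR's code path.
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper


class ToyDecoderLayer(nn.Module):
    def __init__(self, hidden: int = 64) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(hidden, num_heads=4, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden))
        self.input_layernorm = nn.LayerNorm(hidden)
        self.post_attention_layernorm = nn.LayerNorm(hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.input_layernorm(x)
        attn_out = self.self_attn(h, h, h)[0]
        x = x + attn_out
        return x + self.mlp(self.post_attention_layernorm(x))


layer = ToyDecoderLayer()
# Wrap the memory-heavy submodules so their activations are recomputed in backward
# instead of being stored, mirroring the hasattr-guarded wrapping in the diff above.
for name in ("self_attn", "mlp", "input_layernorm", "post_attention_layernorm"):
    if hasattr(layer, name):
        setattr(layer, name, checkpoint_wrapper(getattr(layer, name)))

x = torch.randn(2, 16, 64, requires_grad=True)
layer(x).sum().backward()
```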
1 change: 1 addition & 0 deletions nemo_rl/models/policy/__init__.py
@@ -19,6 +19,7 @@

class DTensorConfig(TypedDict):
enabled: bool
env_vars: NotRequired[dict[str, str]]
_v2: NotRequired[bool]
cpu_offload: NotRequired[bool]
sequence_parallel: NotRequired[bool]
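For context, `env_vars` is a per-worker mapping of environment variable names to values. A hedged sketch of how such a mapping could be applied before any CUDA work (the helper name and mechanism are assumptions for illustration; the actual plumbing in NeMo RL may go through Ray's runtime environment instead):

```python
# Hypothetical illustration of consuming DTensorConfig.env_vars; not the repo's actual mechanism.
import os
from typing import Mapping


def apply_env_vars(env_vars: Mapping[str, str] | None) -> None:
    """Export configured environment variables, skipping empty values.

    Must run before the first CUDA allocation for settings such as
    PYTORCH_CUDA_ALLOC_CONF to take effect.
    """
    for key, value in (env_vars or {}).items():
        if value:  # an empty string in the YAML means "leave the default alone"
            os.environ[key] = value


# Example mirroring the recipe above:
apply_env_vars({"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:64"})
```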
6 changes: 2 additions & 4 deletions nemo_rl/models/policy/dtensor_policy_worker.py
@@ -65,7 +65,6 @@
)
from nemo_rl.models.policy.utils import (
configure_dynamo_cache,
configure_expandable_segments,
get_gpu_info,
get_handle_from_tensor,
get_runtime_env_for_policy_worker,
@@ -173,9 +172,6 @@ def __init__(
# with different order of node_bundles
configure_dynamo_cache()

# Only enable expandable_segments on Hopper and newer architectures (compute capability 9.x+)
configure_expandable_segments()

# vars used for refit
## will be initialized in prepare_refit_info
self.refit_param_info = None
@@ -642,6 +638,8 @@ def train(
for mb_idx, mb in enumerate(
itertools.chain(mb_iterator, dummy_iterator)
):
torch.cuda.empty_cache()

with torch.autocast(device_type="cuda", dtype=self.dtype):
if self.enable_seq_packing:
input_ids = mb.get("input_ids").cuda()
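For context on the `torch.cuda.empty_cache()` call added above: it returns the allocator's unused cached blocks to the driver, trading a little allocation speed for a cleaner pool at each micro-batch boundary. A toy version of the loop pattern (requires a CUDA device; not the worker's real training step):

```python
# Toy micro-batch loop illustrating empty_cache at each iteration boundary;
# not the policy worker's actual training step.
import torch
import torch.nn as nn

device = "cuda"  # requires a CUDA device
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-6)

microbatches = [torch.randn(4, 1024, device=device) for _ in range(8)]

for mb_idx, mb in enumerate(microbatches):
    # Release cached blocks left over from the previous micro-batch so large,
    # oddly-sized activations are less likely to fail due to fragmentation.
    torch.cuda.empty_cache()

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(mb).pow(2).mean()
    loss.backward()

opt.step()
opt.zero_grad()
```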
6 changes: 2 additions & 4 deletions nemo_rl/models/policy/dtensor_policy_worker_v2.py
@@ -75,7 +75,6 @@
)
from nemo_rl.models.policy.utils import (
configure_dynamo_cache,
configure_expandable_segments,
get_gpu_info,
get_handle_from_tensor,
get_runtime_env_for_policy_worker,
@@ -126,9 +125,6 @@ def __init__(
# with different order of node_bundles
configure_dynamo_cache()

# Only enable expandable_segments on Hopper and newer architectures (compute capability 9.x+)
configure_expandable_segments()

self.cfg = config
# torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call
torch.distributed.init_process_group(backend="nccl")
@@ -570,6 +566,8 @@ def train(
for mb_idx, mb in enumerate(
itertools.chain(mb_iterator, dummy_iterator)
):
torch.cuda.empty_cache()

with torch.autocast(device_type="cuda", dtype=self.dtype):
if self.enable_seq_packing:
input_ids = mb.get("input_ids").cuda()
4 changes: 0 additions & 4 deletions nemo_rl/models/policy/megatron_policy_worker.py
@@ -118,7 +118,6 @@
)
from nemo_rl.models.policy.utils import (
configure_dynamo_cache,
configure_expandable_segments,
get_gpu_info,
get_handle_from_tensor,
get_megatron_checkpoint_dir,
@@ -410,9 +409,6 @@ def __init__(
# with different order of node_bundles
configure_dynamo_cache()

# Only enable expandable_segments on Hopper and newer architectures (compute capability 9.x+)
configure_expandable_segments()

# cfg["model_name"] is allowed to be either an HF model name or a path to an HF checkpoint
# check if hf_model_name is a path
hf_model_name = self.cfg["model_name"]
43 changes: 0 additions & 43 deletions nemo_rl/models/policy/utils.py
@@ -165,49 +165,6 @@ def sliding_window_overwrite(model_name: str) -> dict[str, Any]:
return overwrite_dict


def configure_expandable_segments() -> None:
"""Configure expandable_segments on Hopper and newer architectures (compute capability 9.x+).

This helps with memory allocation but causes crashes on Ampere GPUs, so we only enable it
on newer architectures. If PYTORCH_CUDA_ALLOC_CONF is already set, preserves existing values.
"""
compute_capability = torch.cuda.get_device_properties(0).major

if compute_capability >= 9: # Hopper+
existing_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")

# Check if expandable_segments is already configured
if "expandable_segments" in existing_conf:
print(f"expandable_segments already configured: {existing_conf}")
# Already configured, don't override
return

# Add expandable_segments to existing configuration
if existing_conf:
# Append to existing configuration
new_conf = f"{existing_conf},expandable_segments:True"
else:
# Set new configuration
new_conf = "expandable_segments:True"

print(f"Setting PYTORCH_CUDA_ALLOC_CONF to {new_conf}")
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = new_conf

else:
## make sure that expandable_segments is not set to True
if "expandable_segments" in os.environ.get("PYTORCH_CUDA_ALLOC_CONF", ""):
conf_items = os.environ["PYTORCH_CUDA_ALLOC_CONF"].split(",")
for item in conf_items:
if item.strip().startswith("expandable_segments"):
key_value = item.split(":")
if len(key_value) == 2 and key_value[1].strip().lower() == "true":
raise RuntimeError(
"expandable_segments is enabled in PYTORCH_CUDA_ALLOC_CONF, "
"but this is not supported on architectures older than Hopper (compute capability < 9). "
"Please set expandable_segments to False."
)


def configure_dynamo_cache() -> None:
"""Disable dynamo autotune_local_cache.

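With the automatic toggle removed, users on Hopper or newer GPUs who still want expandable segments can opt in explicitly, for example via the new `env_vars` field or the environment. A hedged sketch of an equivalent manual check, mirroring the deleted helper (illustration only; it must run before the process first allocates CUDA memory):

```python
# Illustrative opt-in for expandable_segments, echoing the removed helper; not part of the PR.
import os

import torch


def maybe_enable_expandable_segments() -> None:
    """Append expandable_segments:True for Hopper+ GPUs (compute capability >= 9)."""
    if not torch.cuda.is_available():
        return
    if torch.cuda.get_device_properties(0).major < 9:
        return  # known to cause crashes on Ampere and older architectures

    conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
    if "expandable_segments" in conf:
        return  # respect an explicit user setting
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
        f"{conf},expandable_segments:True" if conf else "expandable_segments:True"
    )
```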
3 changes: 2 additions & 1 deletion nemo_rl/utils/logger.py
@@ -72,10 +72,11 @@ class LoggerConfig(TypedDict):
tensorboard_enabled: bool
mlflow_enabled: bool
wandb: WandbConfig
tensorboard: TensorboardConfig
tensorboard: NotRequired[TensorboardConfig]
mlflow: NotRequired[MLflowConfig]
monitor_gpus: bool
gpu_monitoring: GPUMonitoringConfig
num_val_samples_to_print: NotRequired[int]


class LoggerInterface(ABC):
@@ -0,0 +1,40 @@
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
source $SCRIPT_DIR/common.env

# ===== BEGIN CONFIG =====
NUM_NODES=1
STEPS_PER_RUN=100
MAX_STEPS=100
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
NUM_MINUTES=45
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd $PROJECT_ROOT
uv run examples/run_dpo.py \
--config $CONFIG_PATH \
dpo.max_num_steps=$MAX_STEPS \
logger.log_dir=$LOG_DIR \
logger.wandb_enabled=True \
logger.wandb.project=nemo-rl \
logger.wandb.name=$EXP_NAME \
logger.monitor_gpus=True \
logger.tensorboard_enabled=True \
checkpointing.enabled=True \
checkpointing.checkpoint_dir=$CKPT_DIR \
$@ \
2>&1 | tee $RUN_LOG

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
uv run tests/check_metrics.py $JSON_METRICS \
'data["train/loss"]["1"] > 0.6990' \
'data["train/loss"]["1"] < 0.6992' \
'data["train/loss"]["100"] < 0.60'
fi
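For reference, the `jq` guard and `check_metrics.py` calls above gate on the dumped TensorBoard JSON. A minimal Python equivalent of the same check (assuming the JSON has the `data["train/loss"][step]` shape used in the script):

```python
# Minimal stand-in for the jq guard plus metric assertions above; assumes the same JSON layout.
import json
import sys

MAX_STEPS = 100

with open(sys.argv[1]) as f:  # path to the dumped metrics JSON
    data = json.load(f)

loss = data["train/loss"]
if max(int(step) for step in loss) >= MAX_STEPS:
    assert 0.6990 < loss["1"] < 0.6992, f"unexpected initial loss: {loss['1']}"
    assert loss["100"] < 0.60, f"loss did not drop enough: {loss['100']}"
    print("metrics check passed")
else:
    print("target step not reached; skipping metric checks")
```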
3 changes: 3 additions & 0 deletions tests/test_suites/nightly.txt
@@ -67,3 +67,6 @@ tests/test_suites/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.sh

# Short megatron
tests/test_suites/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.sh

# Long dtensor
tests/test_suites/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.sh