diff --git a/README.md b/README.md index 75bf02f4fe..77ec8274eb 100644 --- a/README.md +++ b/README.md @@ -473,6 +473,24 @@ For detailed instructions on how to set up and launch NeMo RL on Slurm or Kubern NRL_FORCE_REBUILD_VENVS=true uv run examples/run_grpo.py ... ``` +- Large amounts of memory fragmentation might occur when running models without support for FlashAttention2. + If OOM occurs after a few iterations of training, it may help to tweak the allocator settings to reduce memory fragmentation. + To do so, specify [`max_split_size_mb`](https://docs.pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf) + at **either** one of the following places: + 1. Launch training with: + ```sh + # This will globally apply to all ray actors + PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:64 uv run python examples/run_dpo.py ... + ``` + 2. Make the change more permanently by adding this flag in the training configuration: + ```yaml + policy: + # ... + dtensor_cfg: + env_vars: + PYTORCH_CUDA_ALLOC_CONF: "max_split_size_mb:64" + ``` + ## Citation If you use NeMo RL in your research, please cite it using the following BibTeX entry: diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 74a74efc21..cfe2c011e3 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -45,6 +45,8 @@ policy: precision: "bfloat16" dtensor_cfg: + env_vars: + PYTORCH_CUDA_ALLOC_CONF: "" # Refers to https://docs.pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf enabled: true cpu_offload: False sequence_parallel: false @@ -155,9 +157,11 @@ data: logger: log_dir: "logs" # Base directory for all logs wandb_enabled: false # Make sure you do a ``wandb login [Your API key]'' before running + tensorboard_enabled: false mlflow_enabled: false # Disable MLflow logging monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard + num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal wandb: project: "dpo-dev" name: "dpo" diff --git a/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml b/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml new file mode 100644 index 0000000000..084ea843f2 --- /dev/null +++ b/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml @@ -0,0 +1,106 @@ +# DPO Algorithm Configuration +dpo: + max_num_epochs: 1 + max_num_steps: 100 + val_period: 10 + val_batches: 1 + val_global_batch_size: 16 + val_micro_batch_size: 1 + val_at_start: true + seed: 42 + + reference_policy_kl_penalty: 0.1 + preference_average_log_probs: False # whether normalizing log probs according to the sequence length in preference_loss + sft_average_log_probs: ${.preference_average_log_probs} # whether normalizing log probs according to the sequence length in sft_loss + + preference_loss_weight: 1 # the coefficient of the preference loss + sft_loss_weight: 0 # the coefficient of the SFT loss + +checkpointing: + enabled: true + checkpoint_dir: "results/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long" + metric_name: "val_loss" + higher_is_better: false + keep_top_k: null + save_period: 50 + checkpoint_must_save_by: null + +policy: + model_name: "mistralai/Mistral-Nemo-Instruct-2407" + tokenizer: + name: ${policy.model_name} + + # number of preference samples per batch + # each preference sample corresponds to a pair of chosen and rejected responses + 
# so the actual batch size processed by the model is train_global_batch_size * 2 + train_global_batch_size: 8 + train_micro_batch_size: 1 + + + #logprob_batch_size: ${policy.train_micro_batch_size} + max_total_sequence_length: 12288 + precision: "bfloat16" + + dtensor_cfg: + enabled: true + cpu_offload: false + sequence_parallel: false + activation_checkpointing: true + tensor_parallel_size: 8 + context_parallel_size: 1 + custom_parallel_plan: null + env_vars: + PYTORCH_CUDA_ALLOC_CONF: "max_split_size_mb:64" + + dynamic_batching: + enabled: false + + sequence_packing: + enabled: false + + # makes the training sequence length divisible by the tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} + max_grad_norm: 1.0 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 1.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + # when using Dtensor, we need to set foreach + # and fused to False + foreach: False + fused: False + + scheduler: + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [] + +data: + dataset_name: "HelpSteer3" + shuffle: False + max_input_seq_length: ${policy.max_total_sequence_length} + +logger: + log_dir: "logs/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long" # Base directory for all logs + wandb_enabled: false # Make sure you do a ``wandb login [Your API key]'' before running + tensorboard_enabled: false + mlflow_enabled: false + monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard + num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal + wandb: + project: "nemo-rl" + name: "dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 8 + num_nodes: 1 diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index 59396dcf9e..e5469e3e04 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -661,6 +661,25 @@ def _parallelize_model( for i in range(len(layers)): layers[i].mlp = checkpoint_wrapper(layers[i].mlp) # type: ignore + """ + the extra memory overhead for layer norm seems to be only present + in mistral models, where some intermediate state is converted to float32 + + need to find a better solution for checkpointing + """ + if hasattr(layers[i], "self_attn"): + layers[i].self_attn = checkpoint_wrapper(layers[i].self_attn) # type: ignore + + if hasattr(layers[i], "input_layernorm"): + layers[i].input_layernorm = checkpoint_wrapper( + layers[i].input_layernorm # type: ignore + ) + + if hasattr(layers[i], "post_attention_layernorm"): + layers[i].post_attention_layernorm = checkpoint_wrapper( + layers[i].post_attention_layernorm # type: ignore + ) + mp_policy = MixedPrecisionPolicy( param_dtype=param_dtype, reduce_dtype=torch.float32, diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index 228afd5f95..b909d89ed6 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -19,6 +19,7 @@ class DTensorConfig(TypedDict): enabled: bool + env_vars: NotRequired[dict[str, str]] _v2: NotRequired[bool] cpu_offload: NotRequired[bool] sequence_parallel: NotRequired[bool] diff --git 
a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 534e67a32d..060e0636d5 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -65,7 +65,6 @@ ) from nemo_rl.models.policy.utils import ( configure_dynamo_cache, - configure_expandable_segments, get_gpu_info, get_handle_from_tensor, get_runtime_env_for_policy_worker, @@ -173,9 +172,6 @@ def __init__( # with different order of node_bundles configure_dynamo_cache() - # Only enable expandable_segments on Hopper and newer architectures (compute capability 9.x+) - configure_expandable_segments() - # vars used for refit ## will be initialized in prepare_refit_info self.refit_param_info = None @@ -642,6 +638,8 @@ def train( for mb_idx, mb in enumerate( itertools.chain(mb_iterator, dummy_iterator) ): + torch.cuda.empty_cache() + with torch.autocast(device_type="cuda", dtype=self.dtype): if self.enable_seq_packing: input_ids = mb.get("input_ids").cuda() diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index de435cfda0..dbe07990ac 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -75,7 +75,6 @@ ) from nemo_rl.models.policy.utils import ( configure_dynamo_cache, - configure_expandable_segments, get_gpu_info, get_handle_from_tensor, get_runtime_env_for_policy_worker, @@ -126,9 +125,6 @@ def __init__( # with different order of node_bundles configure_dynamo_cache() - # Only enable expandable_segments on Hopper and newer architectures (compute capability 9.x+) - configure_expandable_segments() - self.cfg = config # torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call torch.distributed.init_process_group(backend="nccl") @@ -570,6 +566,8 @@ def train( for mb_idx, mb in enumerate( itertools.chain(mb_iterator, dummy_iterator) ): + torch.cuda.empty_cache() + with torch.autocast(device_type="cuda", dtype=self.dtype): if self.enable_seq_packing: input_ids = mb.get("input_ids").cuda() diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 6f57762af3..46756c8634 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -118,7 +118,6 @@ ) from nemo_rl.models.policy.utils import ( configure_dynamo_cache, - configure_expandable_segments, get_gpu_info, get_handle_from_tensor, get_megatron_checkpoint_dir, @@ -410,9 +409,6 @@ def __init__( # with different order of node_bundles configure_dynamo_cache() - # Only enable expandable_segments on Hopper and newer architectures (compute capability 9.x+) - configure_expandable_segments() - # cfg["model_name"] is allowed to be either an HF model name or a path to an HF checkpoint # check if hf_model_name is a path hf_model_name = self.cfg["model_name"] diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index e7ad6807d6..42662f7a44 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -165,49 +165,6 @@ def sliding_window_overwrite(model_name: str) -> dict[str, Any]: return overwrite_dict -def configure_expandable_segments() -> None: - """Configure expandable_segments on Hopper and newer architectures (compute capability 9.x+). - - This helps with memory allocation but causes crashes on Ampere GPUs, so we only enable it - on newer architectures. 
If PYTORCH_CUDA_ALLOC_CONF is already set, preserves existing values. - """ - compute_capability = torch.cuda.get_device_properties(0).major - - if compute_capability >= 9: # Hopper+ - existing_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") - - # Check if expandable_segments is already configured - if "expandable_segments" in existing_conf: - print(f"expandable_segments already configured: {existing_conf}") - # Already configured, don't override - return - - # Add expandable_segments to existing configuration - if existing_conf: - # Append to existing configuration - new_conf = f"{existing_conf},expandable_segments:True" - else: - # Set new configuration - new_conf = "expandable_segments:True" - - print(f"Setting PYTORCH_CUDA_ALLOC_CONF to {new_conf}") - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = new_conf - - else: - ## make sure that expandable_segments is not set to True - if "expandable_segments" in os.environ.get("PYTORCH_CUDA_ALLOC_CONF", ""): - conf_items = os.environ["PYTORCH_CUDA_ALLOC_CONF"].split(",") - for item in conf_items: - if item.strip().startswith("expandable_segments"): - key_value = item.split(":") - if len(key_value) == 2 and key_value[1].strip().lower() == "true": - raise RuntimeError( - "expandable_segments is enabled in PYTORCH_CUDA_ALLOC_CONF, " - "but this is not supported on architectures older than Hopper (compute capability < 9). " - "Please set expandable_segments to False." - ) - - def configure_dynamo_cache() -> None: """Disable dynamo autotune_local_cache. diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index e70af2117f..711b8fd596 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -72,10 +72,11 @@ class LoggerConfig(TypedDict): tensorboard_enabled: bool mlflow_enabled: bool wandb: WandbConfig - tensorboard: TensorboardConfig + tensorboard: NotRequired[TensorboardConfig] mlflow: NotRequired[MLflowConfig] monitor_gpus: bool gpu_monitoring: GPUMonitoringConfig + num_val_samples_to_print: NotRequired[int] class LoggerInterface(ABC): diff --git a/tests/test_suites/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.sh b/tests/test_suites/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.sh new file mode 100755 index 0000000000..8f9e22f337 --- /dev/null +++ b/tests/test_suites/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.sh @@ -0,0 +1,40 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=100 +MAX_STEPS=100 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=45 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_dpo.py \ + --config $CONFIG_PATH \ + dpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["1"] > 0.6990' \ + 
'data["train/loss"]["1"] < 0.6992' \ + 'data["train/loss"]["100"] < 0.60' +fi diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index b38557f2e8..21d9610249 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -67,3 +67,6 @@ tests/test_suites/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.sh # Short megatron tests/test_suites/llm/dpo-llama3.1-8b-instruct-4n8g-megatrontp2pp2-quick.sh + +# Long dtensor +tests/test_suites/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.sh diff --git a/tests/unit/models/policy/test_utils.py b/tests/unit/models/policy/test_utils.py index 5712985cd3..8fb4d8f8b2 100644 --- a/tests/unit/models/policy/test_utils.py +++ b/tests/unit/models/policy/test_utils.py @@ -14,151 +14,12 @@ import os import unittest.mock -from unittest.mock import MagicMock, patch from nemo_rl.models.policy.utils import ( - configure_expandable_segments, get_megatron_checkpoint_dir, ) -class TestConfigureExpandableSegments(unittest.TestCase): - """Test cases for configure_expandable_segments function.""" - - def setUp(self): - """Set up test environment.""" - # Store original environment variable - self.original_pytorch_cuda_alloc_conf = os.environ.get( - "PYTORCH_CUDA_ALLOC_CONF" - ) - - def tearDown(self): - """Clean up after tests.""" - # Restore original environment variable - if self.original_pytorch_cuda_alloc_conf is not None: - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ( - self.original_pytorch_cuda_alloc_conf - ) - elif "PYTORCH_CUDA_ALLOC_CONF" in os.environ: - del os.environ["PYTORCH_CUDA_ALLOC_CONF"] - - @patch("torch.cuda.get_device_properties") - def test_hopper_gpu_no_existing_config(self, mock_get_device_properties): - """Test Hopper+ GPU (compute capability >= 9) with no existing PYTORCH_CUDA_ALLOC_CONF.""" - # Mock GPU properties for Hopper+ architecture - mock_device_properties = MagicMock() - mock_device_properties.major = 9 - mock_get_device_properties.return_value = mock_device_properties - - # Ensure no existing config - if "PYTORCH_CUDA_ALLOC_CONF" in os.environ: - del os.environ["PYTORCH_CUDA_ALLOC_CONF"] - - # Call the function - configure_expandable_segments() - - # Verify the environment variable was set correctly - self.assertEqual( - os.environ["PYTORCH_CUDA_ALLOC_CONF"], "expandable_segments:True" - ) - - @patch("torch.cuda.get_device_properties") - def test_hopper_gpu_with_existing_config(self, mock_get_device_properties): - """Test Hopper+ GPU with existing PYTORCH_CUDA_ALLOC_CONF.""" - # Mock GPU properties for Hopper+ architecture - mock_device_properties = MagicMock() - mock_device_properties.major = 9 - mock_get_device_properties.return_value = mock_device_properties - - # Set existing config - existing_config = "max_split_size_mb:128" - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = existing_config - - # Call the function - configure_expandable_segments() - - # Verify the environment variable was updated correctly - expected_config = f"{existing_config},expandable_segments:True" - self.assertEqual(os.environ["PYTORCH_CUDA_ALLOC_CONF"], expected_config) - - @patch("torch.cuda.get_device_properties") - def test_hopper_gpu_already_configured(self, mock_get_device_properties): - """Test Hopper+ GPU with existing config that already has expandable_segments.""" - # Mock GPU properties for Hopper+ architecture - mock_device_properties = MagicMock() - mock_device_properties.major = 9 - mock_get_device_properties.return_value = mock_device_properties - - # Set existing config with expandable_segments 
already present - existing_config = "max_split_size_mb:128,expandable_segments:False" - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = existing_config - - # Call the function - configure_expandable_segments() - - # Verify the environment variable was not changed - self.assertEqual(os.environ["PYTORCH_CUDA_ALLOC_CONF"], existing_config) - - @patch("torch.cuda.get_device_properties") - def test_ampere_gpu_no_config_change(self, mock_get_device_properties): - """Test Ampere GPU (compute capability < 9) should not modify config.""" - # Mock GPU properties for Ampere architecture - mock_device_properties = MagicMock() - mock_device_properties.major = 8 # Ampere - mock_get_device_properties.return_value = mock_device_properties - - # Set existing config - existing_config = "max_split_size_mb:128" - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = existing_config - - # Call the function - configure_expandable_segments() - - # Verify the environment variable was not changed - self.assertEqual(os.environ["PYTORCH_CUDA_ALLOC_CONF"], existing_config) - - @patch("torch.cuda.get_device_properties") - def test_ampere_gpu_no_existing_config(self, mock_get_device_properties): - """Test Ampere GPU with no existing config should not set anything.""" - # Mock GPU properties for Ampere architecture - mock_device_properties = MagicMock() - mock_device_properties.major = 8 # Ampere - mock_get_device_properties.return_value = mock_device_properties - - # Ensure no existing config - if "PYTORCH_CUDA_ALLOC_CONF" in os.environ: - del os.environ["PYTORCH_CUDA_ALLOC_CONF"] - - # Call the function - configure_expandable_segments() - - # Verify the environment variable was not set - self.assertNotIn("PYTORCH_CUDA_ALLOC_CONF", os.environ) - - @patch("torch.cuda.get_device_properties") - def test_ampere_gpu_with_expandable_segments_true_raises_error( - self, mock_get_device_properties - ): - """Test Ampere GPU with expandable_segments:True in config raises RuntimeError.""" - # Mock GPU properties for Ampere architecture - mock_device_properties = MagicMock() - mock_device_properties.major = 8 # Ampere - mock_get_device_properties.return_value = mock_device_properties - - # Set config with expandable_segments:True - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" - - # Call the function and expect RuntimeError - with self.assertRaises(RuntimeError) as context: - configure_expandable_segments() - - # Verify the error message - self.assertIn("expandable_segments is enabled", str(context.exception)) - self.assertIn( - "not supported on architectures older than Hopper", str(context.exception) - ) - - class TestGetMegatronCheckpointDir: """Test cases for the get_megatron_checkpoint_dir function."""
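
The README addition above recommends tuning `max_split_size_mb` when fragmentation-driven OOMs appear after a few training iterations. Below is a minimal, hypothetical sketch (not part of this patch; the helper name is invented) for confirming inside a worker process that `PYTORCH_CUDA_ALLOC_CONF` was picked up, and for spotting fragmentation by comparing allocated versus reserved CUDA memory:

```python
import os

import torch


def log_allocator_state(tag: str) -> None:
    """Hypothetical debugging helper (not part of this patch).

    Prints the allocator configuration and a rough fragmentation signal:
    a large, growing gap between reserved and allocated memory usually
    points at fragmentation in the caching allocator.
    """
    # PYTORCH_CUDA_ALLOC_CONF must be set before the first CUDA allocation
    # in the process for the allocator to honor it.
    print(f"[{tag}] PYTORCH_CUDA_ALLOC_CONF =",
          os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "<unset>"))
    if torch.cuda.is_available():
        allocated_mib = torch.cuda.memory_allocated() / 2**20
        reserved_mib = torch.cuda.memory_reserved() / 2**20
        print(f"[{tag}] allocated={allocated_mib:.0f} MiB "
              f"reserved={reserved_mib:.0f} MiB")
```

Calling such a helper once per training step (for example, next to the `torch.cuda.empty_cache()` calls added to the worker train loops in this patch) makes it easy to see whether the allocator settings are taking effect.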
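
The new `dtensor_cfg.env_vars` key (see the README and `DTensorConfig` changes above) is meant to forward allocator settings such as `PYTORCH_CUDA_ALLOC_CONF` to the policy workers. The actual wiring is not shown in the hunks above; the sketch below only illustrates the intended behavior, under the assumption that an empty value (the default in `examples/configs/dpo.yaml`) leaves the process environment untouched. The function name is hypothetical.

```python
import os
from typing import Mapping, Optional


def apply_dtensor_env_vars(env_vars: Optional[Mapping[str, str]]) -> None:
    """Hypothetical sketch of exporting configured worker env vars.

    Assumption: an empty string means "keep whatever is already set in the
    worker's environment"; only non-empty values are exported.
    """
    for key, value in (env_vars or {}).items():
        if value:
            # Must be applied before the first CUDA allocation for
            # PYTORCH_CUDA_ALLOC_CONF to have any effect.
            os.environ[key] = value


# Example mirroring the new Mistral-NeMo DPO recipe config:
apply_dtensor_env_vars({"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:64"})
```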