diff --git a/.github/workflows/e2e_ppo_trainer_megatron.yml b/.github/workflows/e2e_ppo_trainer_megatron.yml
index 34d996e8624..b932657e699 100644
--- a/.github/workflows/e2e_ppo_trainer_megatron.yml
+++ b/.github/workflows/e2e_ppo_trainer_megatron.yml
@@ -65,7 +65,7 @@ jobs:
       - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) with validation and saving
         run: |
           ray stop --force
-          VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 bash tests/e2e/run_ppo_trainer_megatron.sh
+          ALL_OFFLOAD=True VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 bash tests/e2e/run_ppo_trainer_megatron.sh
       - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) after resuming
         run: |
           ray stop --force
@@ -107,7 +107,7 @@ jobs:
       - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek)
         run: |
           ray stop --force
-          SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/e2e/run_ppo_trainer_megatron.sh
+          ALL_OFFLOAD=True SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/e2e/run_ppo_trainer_megatron.sh
       - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek)
         run: |
           ray stop --force
@@ -149,7 +149,7 @@ jobs:
       - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) with validation and saving
         run: |
           ray stop --force
-          VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
+          ALL_OFFLOAD=True VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
       - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) after resuming
         run: |
           ray stop --force
@@ -306,3 +306,4 @@ jobs:
         run: |
           rm -rf checkpoints
+
diff --git a/tests/e2e/run_ppo_trainer_megatron.sh b/tests/e2e/run_ppo_trainer_megatron.sh
index a70db50ad99..82b0582c3da 100644
--- a/tests/e2e/run_ppo_trainer_megatron.sh
+++ b/tests/e2e/run_ppo_trainer_megatron.sh
@@ -55,6 +55,20 @@ RM_VPP=${RM_VPP:-$COMMON_VPP}
 RM_CP=${RM_CP:-$COMMON_CP}
 RM_TP=${RM_TP:-$TRAIN_TP}
 
+ALL_OFFLOAD=${ALL_OFFLOAD:-False}
+COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD}
+COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD}
+COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD}
+
+ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
+ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}
+ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}
+REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
+CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
+CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}
+CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}
+RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
+
 CHECKPOINT_CONTENTS=['model','hf_model','optimizer','extra']
 SKIP_SAVE_HF_MODEL=${SKIP_SAVE_HF_MODEL:-0}
 if [ $SKIP_SAVE_HF_MODEL -eq 1 ]; then
@@ -81,6 +95,9 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=$ACTOR_VPP \
     actor_rollout_ref.actor.megatron.context_parallel_size=$ACTOR_CP \
     actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$ACTOR_TP \
+    actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \
+    actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \
+    actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \
     actor_rollout_ref.actor.use_kl_loss=True \
     actor_rollout_ref.actor.kl_loss_coef=0.001 \
     actor_rollout_ref.actor.kl_loss_type=low_var_kl \
@@ -95,6 +112,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     actor_rollout_ref.ref.megatron.context_parallel_size=$REF_CP \
     actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$REF_TP \
     actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
+    actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \
     critic.optim.lr=2e-5 \
     critic.model.path="${MODEL_PATH}" \
     critic.model.enable_gradient_checkpointing=False \
@@ -104,6 +122,9 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     critic.megatron.context_parallel_size=$CRITIC_CP \
     critic.megatron.tensor_model_parallel_size=$CRITIC_TP \
     critic.checkpoint.contents=$CHECKPOINT_CONTENTS \
+    critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \
+    critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \
+    critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \
     reward_model.enable=True \
     reward_model.model.path="${MODEL_PATH}" \
     reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
@@ -111,6 +132,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     reward_model.megatron.virtual_pipeline_model_parallel_size=$RM_VPP \
     reward_model.megatron.context_parallel_size=$RM_CP \
     reward_model.megatron.tensor_model_parallel_size=$RM_TP \
+    reward_model.megatron.param_offload=${RM_PARAM_OFFLOAD} \
     algorithm.use_kl_in_reward=False \
     algorithm.kl_penalty=kl \
     algorithm.kl_ctrl.kl_coef=0.001 \
diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml
index 2b1a0941594..636e287b8b7 100644
--- a/verl/trainer/config/ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/ppo_megatron_trainer.yaml
@@ -227,8 +227,6 @@ reward_model:
   strategy: megatron
   megatron:
     param_offload: False
-    grad_offload: False
-    optimizer_offload: False
     tensor_model_parallel_size: 1
     expert_model_parallel_size: 1
     expert_tensor_parallel_size: null
diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py
index 841d1315a82..ed0a1453ee8 100644
--- a/verl/utils/megatron_utils.py
+++ b/verl/utils/megatron_utils.py
@@ -25,7 +25,7 @@
 from megatron.core.distributed import DistributedDataParallel as DDP
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.enums import ModelType
-from megatron.core.optimizer import OptimizerConfig
+from megatron.core.optimizer import ChainedOptimizer, OptimizerConfig
 from megatron.core.transformer import TransformerConfig
 from megatron.core.transformer.module import Float16Module
 from megatron.core.utils import get_attr_wrapped_model
@@ -296,12 +296,18 @@ def load_megatron_model_to_gpu(models, load_grad=True):
 @torch.no_grad()
 def offload_megatron_copy_params(optimizers):
     """
-    Offload optimizer parameters to CPU
+    Offload optimizer parameters to CPU. Supports both Megatron optimizers
+    and `ChainedOptimizer`, which wraps a list of underlying optimizers.
 
     Args:
-        optimizers: The optimizer containing parameter groups to offload
+        optimizers: The optimizer or ChainedOptimizer instance.
     """
 
+    def _iter_opts(opt):
+        if isinstance(opt, ChainedOptimizer):
+            return opt.chained_optimizers
+        return [opt]
+
     def offload_tensor_to_cpu(tensor):
         if tensor is None:
             return
@@ -321,21 +327,27 @@ def offload_group_to_cpu(group):
         else:
             offload_tensor_to_cpu(group)
 
-    # Offload all parameter groups to CPU
+    # Offload all parameter groups to CPU for each underlying optimizer
 
-    if hasattr(optimizers, "shard_fp32_from_float16_groups"):
-        offload_group_to_cpu(optimizers.shard_fp32_from_float16_groups)
+    for _opt in _iter_opts(optimizers):
+        if hasattr(_opt, "shard_fp32_from_float16_groups"):
+            offload_group_to_cpu(_opt.shard_fp32_from_float16_groups)
 
 
 @torch.no_grad()
 def load_megatron_copy_params(optimizers):
     """
-    Load optimizer parameters back to GPU
+    Load optimizer parameters back to GPU. Handles ChainedOptimizer.
 
     Args:
-        optimizers: The optimizer containing parameter groups to load
+        optimizers: Optimizer or ChainedOptimizer instance.
     """
 
+    def _iter_opts(opt):
+        if isinstance(opt, ChainedOptimizer):
+            return opt.chained_optimizers
+        return [opt]
+
     def load_tensor_to_gpu(tensor):
         if tensor is None:
             return
@@ -356,36 +368,49 @@ def load_group_to_gpu(group):
         else:
             load_tensor_to_gpu(group)
 
-    # Load all parameter groups to GPU
+    # Load all parameter groups to GPU for each underlying optimizer
 
-    if hasattr(optimizers, "shard_fp32_from_float16_groups"):
-        load_group_to_gpu(optimizers.shard_fp32_from_float16_groups)
+    for _opt in _iter_opts(optimizers):
+        if hasattr(_opt, "shard_fp32_from_float16_groups"):
+            load_group_to_gpu(_opt.shard_fp32_from_float16_groups)
 
 
 @torch.no_grad()
 def offload_megatron_optimizer(optimizers):
-    offload_megatron_copy_params(optimizers)
-    opt_state_dict_values = optimizers.optimizer.state.values()
-    for v in opt_state_dict_values:
-        if "exp_avg" in v:
-            v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True)
-        if "exp_avg_sq" in v:
-            v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True)
-    gc.collect()
-    torch.cuda.empty_cache()
+    def _iter_opts(opt):
+        if isinstance(opt, ChainedOptimizer):
+            return opt.chained_optimizers
+        return [opt]
+
+    for _opt in _iter_opts(optimizers):
+        offload_megatron_copy_params(_opt)
+        opt_state_dict_values = _opt.optimizer.state.values()
+        for v in opt_state_dict_values:
+            if "exp_avg" in v:
+                v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True)
+            if "exp_avg_sq" in v:
+                v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True)
+        gc.collect()
+        torch.cuda.empty_cache()
 
 
 @torch.no_grad()
 def load_megatron_optimizer(optimizers):
-    load_megatron_copy_params(optimizers)
-    opt_state_dict_values = optimizers.optimizer.state.values()
-    for v in opt_state_dict_values:
-        if "exp_avg" in v:
-            v["exp_avg"] = v["exp_avg"].to(torch.cuda.current_device(), non_blocking=True)
-        if "exp_avg_sq" in v:
-            v["exp_avg_sq"] = v["exp_avg_sq"].to(torch.cuda.current_device(), non_blocking=True)
-    gc.collect()
-    torch.cuda.empty_cache()
+    def _iter_opts(opt):
+        if isinstance(opt, ChainedOptimizer):
+            return opt.chained_optimizers
+        return [opt]
+
+    for _opt in _iter_opts(optimizers):
+        load_megatron_copy_params(_opt)
+        opt_state_dict_values = _opt.optimizer.state.values()
+        for v in opt_state_dict_values:
+            if "exp_avg" in v:
+                v["exp_avg"] = v["exp_avg"].to(torch.cuda.current_device(), non_blocking=True)
+            if "exp_avg_sq" in v:
+                v["exp_avg_sq"] = v["exp_avg_sq"].to(torch.cuda.current_device(), non_blocking=True)
+        gc.collect()
+        torch.cuda.empty_cache()
 
 
 def print_rank_0(message):