diff --git a/tests/npu/run_qwen2_5_05b_grpo.sh b/tests/npu/run_qwen2_5_05b_grpo.sh index ed44063d59d..6ccaf7b4379 100644 --- a/tests/npu/run_qwen2_5_05b_grpo.sh +++ b/tests/npu/run_qwen2_5_05b_grpo.sh @@ -39,4 +39,6 @@ python3 -m verl.trainer.main_ppo \ trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=5 \ - trainer.total_epochs=1 $@ \ No newline at end of file + trainer.total_epochs=1 \ + trainer.total_training_steps=2 \ + trainer.device=npu $@ \ No newline at end of file diff --git a/tests/npu/run_qwen2_5_32b_grpo.sh b/tests/npu/run_qwen2_5_32b_grpo.sh index d83e36b843f..461b27b80fd 100644 --- a/tests/npu/run_qwen2_5_32b_grpo.sh +++ b/tests/npu/run_qwen2_5_32b_grpo.sh @@ -40,4 +40,5 @@ python3 -m verl.trainer.main_ppo \ trainer.nnodes=2 \ trainer.save_freq=-1 \ trainer.test_freq=10 \ - trainer.total_epochs=15 $@ \ No newline at end of file + trainer.total_epochs=15 \ + trainer.device=npu $@ \ No newline at end of file diff --git a/tests/npu/run_qwen2_5_7b_grpo.sh b/tests/npu/run_qwen2_5_7b_grpo.sh index 8ee7445b469..ff173e2b5f6 100644 --- a/tests/npu/run_qwen2_5_7b_grpo.sh +++ b/tests/npu/run_qwen2_5_7b_grpo.sh @@ -41,4 +41,5 @@ python3 -m verl.trainer.main_ppo \ trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=5 \ - trainer.total_epochs=5 $@ \ No newline at end of file + trainer.total_epochs=5 \ + trainer.device=npu $@ \ No newline at end of file diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 6c4e81f0695..1741524eee4 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -295,6 +295,7 @@ trainer: max_critic_ckpt_to_keep: null # The timeout for ray worker group to wait for the register center to be ready ray_wait_register_center_timeout: 300 + device: cuda ray_init: num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 7a6d65e019a..e1fd4761524 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -253,6 +253,7 @@ trainer: max_critic_ckpt_to_keep: null # The timeout for ray worker group to wait for the register center to be ready ray_wait_register_center_timeout: 300 + device: cuda ray_init: num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 77e20343993..6cd00b83302 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -22,7 +22,6 @@ from verl.trainer.ppo.ray_trainer import RayPPOTrainer from verl.trainer.ppo.reward import load_reward_manager -from verl.utils.device import is_cuda_available def get_custom_reward_fn(config): @@ -179,7 +178,7 @@ def run(self, config): val_dataset=val_dataset, collate_fn=collate_fn, train_sampler=train_sampler, - device_name="cuda" if is_cuda_available else "npu", + device_name=config.trainer.device, ) trainer.init_workers() trainer.fit()