Subject: one_step_off_policy: add NPU weight-sync path beside the Ray-collective fallback

On NPU, rollout weights are broadcast through a dedicated weight-sync group on
the current device stream (`self._weight_sync_group.broadcast`); on all other
devices the existing Ray collective group ("actor_rollout") remains the
fallback. The trainer builds the NPU group from the actor/rollout worker ranks
(actor ranks first, rollout ranks offset by the actor count) using a master
address/port obtained from actor worker 0; otherwise it keeps creating the Ray
collective group with the NCCL backend.

diff --git a/recipe/one_step_off_policy/fsdp_workers.py b/recipe/one_step_off_policy/fsdp_workers.py
index 0aa9dbbe004..d2f65d45ad8 100644
--- a/recipe/one_step_off_policy/fsdp_workers.py
+++ b/recipe/one_step_off_policy/fsdp_workers.py
@@ -95,7 +95,10 @@ def sync_rollout_weights(self):
 
             if torch.distributed.get_rank() == 0:
                 tensor.copy_(origin_data)
-            collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
+            if device_name == "npu":
+                self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())
+            else:
+                collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
 
         if self._is_rollout:
             if rollout_name == "vllm":
diff --git a/recipe/one_step_off_policy/megatron_workers.py b/recipe/one_step_off_policy/megatron_workers.py
index c2a2407939e..7b11069e6e2 100644
--- a/recipe/one_step_off_policy/megatron_workers.py
+++ b/recipe/one_step_off_policy/megatron_workers.py
@@ -23,7 +23,10 @@
 from recipe.one_step_off_policy.distributed_util import vllm_stateless_init_process_group
 from verl.single_controller.base.decorator import Dispatch, register
-from verl.utils.device import get_torch_device
+from verl.utils.device import (
+    get_device_name,
+    get_torch_device,
+)
 from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu
 from verl.utils.ray_utils import get_event_loop
 from verl.workers.megatron_workers import (
@@ -36,6 +39,7 @@
 
 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+device_name = get_device_name()
 
 __all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker", "RewardModelWorker"]
 
@@ -89,7 +93,10 @@ def sync_rollout_weights(self):
 
             if self._is_actor and torch.distributed.get_rank() == 0:
                 tensor.copy_(weight)
-            collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
+            if device_name == "npu":
+                self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())
+            else:
+                collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
 
         if self._is_rollout:
             if rollout_name == "vllm":
diff --git a/recipe/one_step_off_policy/ray_trainer.py b/recipe/one_step_off_policy/ray_trainer.py
index c3890f61bb9..a48d7fd4164 100644
--- a/recipe/one_step_off_policy/ray_trainer.py
+++ b/recipe/one_step_off_policy/ray_trainer.py
@@ -255,20 +255,37 @@ def _init_models(self):
 
         self._create_weight_sync_group()
 
     def _create_weight_sync_group(self):
-        # TODO: NPU support
         from verl.utils.device import get_nccl_backend
 
         actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers
         n_workers = len(actor_rollout_workers)
-        # Create Ray collective group for fallback communication
-        collective.create_collective_group(
-            actor_rollout_workers,
-            n_workers,
-            list(range(0, n_workers)),
-            backend=get_nccl_backend(),
-            group_name="actor_rollout",
-        )
+        if self.device_name == "npu":
+            master_address = ray.get(self.actor_wg.workers[0]._get_node_ip.remote())
+            master_port = ray.get(self.actor_wg.workers[0]._get_free_port.remote())
+            self.actor_wg.create_weight_sync_group(
+                master_address,
+                master_port,
+                0,
+                n_workers,
+            )
+            ray.get(
+                self.rollout_wg.create_weight_sync_group(
+                    master_address,
+                    master_port,
+                    len(self.actor_wg.workers),
+                    n_workers,
+                )
+            )
+        else:
+            # Create Ray collective group for fallback communication
+            collective.create_collective_group(
+                actor_rollout_workers,
+                n_workers,
+                list(range(0, n_workers)),
+                backend=get_nccl_backend(),
+                group_name="actor_rollout",
+            )
 
     def _init_async_rollout_manager(self):
         # create async rollout manager and request scheduler