Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion recipe/one_step_off_policy/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ def sync_rollout_weights(self):
if torch.distributed.get_rank() == 0:
tensor.copy_(origin_data)

collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
if device_name == "npu":
self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())
else:
collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")

if self._is_rollout:
if rollout_name == "vllm":
Expand Down
11 changes: 9 additions & 2 deletions recipe/one_step_off_policy/megatron_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@

from recipe.one_step_off_policy.distributed_util import vllm_stateless_init_process_group
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.device import get_torch_device
from verl.utils.device import (
get_device_name,
get_torch_device,
)
from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu
from verl.utils.ray_utils import get_event_loop
from verl.workers.megatron_workers import (
Expand All @@ -36,6 +39,7 @@
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

device_name = get_device_name()

__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker", "RewardModelWorker"]

Expand Down Expand Up @@ -89,7 +93,10 @@ def sync_rollout_weights(self):
if self._is_actor and torch.distributed.get_rank() == 0:
tensor.copy_(weight)

collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
if device_name == "npu":
self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())
else:
collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")

if self._is_rollout:
if rollout_name == "vllm":
Expand Down
35 changes: 26 additions & 9 deletions recipe/one_step_off_policy/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,20 +255,37 @@ def _init_models(self):
self._create_weight_sync_group()

def _create_weight_sync_group(self):
    """Create the communication group used to push actor weights to rollout workers.

    Two paths, selected by ``self.device_name``:
      * ``"npu"`` — build an explicit weight-sync process group rooted at actor
        rank 0, with rollout workers occupying the ranks after the actor workers.
      * otherwise — create a Ray collective group (NCCL backend) named
        ``"actor_rollout"`` spanning all actor + rollout workers; the workers'
        ``sync_rollout_weights`` broadcasts over this group.
    """
    from verl.utils.device import get_nccl_backend

    actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers
    n_workers = len(actor_rollout_workers)

    if self.device_name == "npu":
        # Rendezvous endpoint is hosted on the first actor worker's node.
        master_address = ray.get(self.actor_wg.workers[0]._get_node_ip.remote())
        master_port = ray.get(self.actor_wg.workers[0]._get_free_port.remote())
        # NOTE(review): the actor-side call is deliberately NOT ray.get()-ed —
        # presumably rank 0 blocks inside group init until every rank has joined,
        # so the rollout-side join must be issued before waiting. Confirm against
        # the worker's create_weight_sync_group implementation.
        self.actor_wg.create_weight_sync_group(
            master_address,
            master_port,
            0,  # actor workers occupy ranks [0, len(actor_wg.workers))
            n_workers,
        )
        ray.get(
            self.rollout_wg.create_weight_sync_group(
                master_address,
                master_port,
                len(self.actor_wg.workers),  # rollout ranks start after actor ranks
                n_workers,
            )
        )
    else:
        # Create Ray collective group for fallback communication
        collective.create_collective_group(
            actor_rollout_workers,
            n_workers,
            list(range(0, n_workers)),
            backend=get_nccl_backend(),
            group_name="actor_rollout",
        )

def _init_async_rollout_manager(self):
# create async rollout manager and request scheduler
Expand Down