diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py
index bb689ac7f4..a99022e30f 100644
--- a/nemo_rl/models/generation/vllm.py
+++ b/nemo_rl/models/generation/vllm.py
@@ -347,6 +347,16 @@ def _patch_vllm_init_workers_ray():
         else:
             self.llm = vllm.LLM(**llm_kwargs)
 
+        # will be initialized in post_init
+        # used in update_weights_from_ipc_handles
+        self.vllm_device_ids = None
+
+    def post_init(self):
+        self.vllm_device_ids = self.report_device_id()
+
+    async def post_init_async(self):
+        self.vllm_device_ids = await self.report_device_id_async()
+
     def init_collective(
         self, rank_prefix: int, ip: str, port: int, world_size: int
     ) -> None:
@@ -1030,9 +1040,37 @@ def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool:
                 "update_weights_from_ipc_handles cannot be used with async_engine=True. Use update_weights_from_ipc_handles_async instead."
             )
 
-        result_or_coro = self.llm.collective_rpc(
-            "update_weights_from_ipc_handles", args=(ipc_handles,)
-        )
+        if self.tensor_parallel_size == 1:
+            # UniProcExecutor
+            assert len(self.vllm_device_ids) == 1
+            result_or_coro = self.llm.collective_rpc(
+                "update_weights_from_local_ipc_handles",
+                args=(ipc_handles[self.vllm_device_ids[0]],),
+            )
+        else:
+            """
+            DO NOT USE VLLM's collective_rpc: This code causes duplicate IPC data transfer across Ray workers,
+            leading to unnecessary network serialization overhead and potential performance degradation.
+
+            result_or_coro = self.llm.collective_rpc(
+                "update_weights_from_global_ipc_handles", args=(ipc_handles,)
+            )
+            """
+            ray_worker_outputs = []
+            # MultiProcExecutor
+            for worker, device_id in zip(
+                self.llm.llm_engine.model_executor.workers, self.vllm_device_ids
+            ):
+                ray_worker_outputs.append(
+                    worker.execute_method.remote(
+                        "update_weights_from_local_ipc_handles",
+                        ipc_handles[device_id],
+                    )
+                )
+
+            # Gather the results
+            result_or_coro = ray.get(ray_worker_outputs)
+
         worker_result = result_or_coro[0]
         if not worker_result:
@@ -1069,8 +1107,9 @@ async def update_weights_from_ipc_handles_async(
                 "update_weights_from_ipc_handles_async can only be used with async_engine=True. Use update_weights_from_ipc_handles instead."
             )
 
+        # TODO: switch to update_weights_from_local_ipc_handles for better performance once collectively report_device_id is supported in asyncLLM initialization
         result_or_coro = await self.llm.collective_rpc(
-            "update_weights_from_ipc_handles", args=(ipc_handles,)
+            "update_weights_from_global_ipc_handles", args=(ipc_handles,)
         )
 
         if asyncio.iscoroutine(result_or_coro):
@@ -1356,6 +1395,10 @@ def __init__(
             env_vars=env_vars,
         )
 
+        # Call some collective rpc functions in VllmGenerationWorker when initializing the vLLM engine
+        # This is necessary for async engine to work
+        self._post_init()
+
         # Number of data parallel groups is the number of tied worker groups
         self.dp_size = self.worker_group.dp_size
 
@@ -1496,6 +1539,19 @@ def _report_device_id(self) -> list[list[str]]:
         results = ray.get(futures)
         return results
 
+    def _post_init(self):
+        # Choose the appropriate method based on async_engine setting
+        method_name = (
+            "post_init_async" if self.cfg["vllm_cfg"]["async_engine"] else "post_init"
+        )
+        # Use run_all_workers_single_data for methods that don't need data
+        futures = self.worker_group.run_all_workers_single_data(
+            method_name, run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"]
+        )
+        # Wait for all futures to complete
+        results = ray.get(futures)
+        return results
+
     def init_collective(
         self, ip: str, port: int, world_size: int
     ) -> list[ray.ObjectRef]:
diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py
index 01dc68146b..3d6ed0253c 100644
--- a/nemo_rl/models/generation/vllm_backend.py
+++ b/nemo_rl/models/generation/vllm_backend.py
@@ -51,24 +51,34 @@ def report_device_id(self) -> str:
         return get_device_uuid(self.device.index)
 
-    def update_weights_from_ipc_handles(self, ipc_handles):
-        """Update weights from IPC handles.
+    def update_weights_from_global_ipc_handles(self, global_device_ipc_handles):
+        """Update weights from global IPC handles.
 
         Args:
-            ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles.
+            global_device_ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles.
+
+        Returns:
+            bool: True if weights were successfully updated.
+        """
+        device_uuid = self.report_device_id()
+        local_device_ipc_handles = global_device_ipc_handles[device_uuid]
+        return self.update_weights_from_local_ipc_handles(local_device_ipc_handles)
+
+    def update_weights_from_local_ipc_handles(self, local_device_ipc_handles):
+        """Update weights from local IPC handles.
+
+        Args:
+            local_device_ipc_handles (dict): parameter IPC handles for local device.
 
         Returns:
             bool: True if weights were successfully updated.
         """
         try:
-            # Get handles for this device
-            device_uuid = self.report_device_id()
-            handles = ipc_handles[device_uuid]
-            is_tensor_packed = handles[0]
+            is_tensor_packed = local_device_ipc_handles[0]
             if is_tensor_packed:
-                _, all_handles, tensor_metadata = handles
+                _, all_handles, tensor_metadata = local_device_ipc_handles
             else:
-                _, name_and_handle_list = handles
+                _, name_and_handle_list = local_device_ipc_handles
 
             device_id = self.device.index
             weights = []