From 56aac874c4d802acf9615a5d1647fd0774cdc496 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Tue, 8 Jul 2025 09:51:14 -0700 Subject: [PATCH 1/4] feat: optimize refit by reducing set of IPC handles sent to each device Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm.py | 36 ++++++++++++++++++++--- nemo_rl/models/generation/vllm_backend.py | 28 ++++++++++++------ 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index bb689ac7f4..ea29acce2a 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -346,6 +346,7 @@ def _patch_vllm_init_workers_ray(): self.llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**llm_kwargs)) else: self.llm = vllm.LLM(**llm_kwargs) + self.vllm_device_ids = self.report_device_id() def init_collective( self, rank_prefix: int, ip: str, port: int, world_size: int @@ -1030,9 +1031,36 @@ def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: "update_weights_from_ipc_handles cannot be used with async_engine=True. Use update_weights_from_ipc_handles_async instead." ) - result_or_coro = self.llm.collective_rpc( - "update_weights_from_ipc_handles", args=(ipc_handles,) - ) + if self.tensor_parallel_size == 1: + # UniProcExecutor + assert len(self.vllm_device_ids) == 1 + result_or_coro = self.llm.collective_rpc( + "update_weights_from_local_ipc_handles", + args=(ipc_handles[self.vllm_device_ids[0]],), + ) + else: + """ + DO NOT USE VLLM's collective_rpc: This code causes duplicate IPC data transfer across Ray workers, + leading to unnecessary network serialization overhead and potential performance degradation. + + result_or_coro = self.llm.collective_rpc( + "update_weights_from_global_ipc_handles", args=(ipc_handles,) + ) + """ + ray_worker_outputs = [] + # MultiProcExecutor + for worker, device_id in zip( + self.llm.llm_engine.model_executor.workers, self.vllm_device_ids + ): + ray_worker_outputs.append( + worker.execute_method.remote( + "update_weights_from_local_ipc_handles", ipc_handles[device_id] + ) + ) + + # Gather the results + result_or_coro = ray.get(ray_worker_outputs) + worker_result = result_or_coro[0] if not worker_result: @@ -1070,7 +1098,7 @@ async def update_weights_from_ipc_handles_async( ) result_or_coro = await self.llm.collective_rpc( - "update_weights_from_ipc_handles", args=(ipc_handles,) + "update_weights_from_global_ipc_handles", args=(ipc_handles,) ) if asyncio.iscoroutine(result_or_coro): diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 01dc68146b..65b05cbea9 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -50,25 +50,35 @@ def report_device_id(self) -> str: from nemo_rl.utils.nvml import get_device_uuid return get_device_uuid(self.device.index) + + def update_weights_from_global_ipc_handles(self, global_device_ipc_handles): + """Update weights from global IPC handles. - def update_weights_from_ipc_handles(self, ipc_handles): - """Update weights from IPC handles. + Args: + global_device_ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles. + + Returns: + bool: True if weights were successfully updated. + """ + device_uuid = self.report_device_id() + local_device_ipc_handles = global_device_ipc_handles[device_uuid] + return self.update_weights_from_local_ipc_handles(local_device_ipc_handles) + + def update_weights_from_local_ipc_handles(self, local_device_ipc_handles): + """Update weights from local IPC handles. Args: - ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles. + local_device_ipc_handles (dict): parameter IPC handles for local device. Returns: bool: True if weights were successfully updated. """ try: - # Get handles for this device - device_uuid = self.report_device_id() - handles = ipc_handles[device_uuid] - is_tensor_packed = handles[0] + is_tensor_packed = local_device_ipc_handles[0] if is_tensor_packed: - _, all_handles, tensor_metadata = handles + _, all_handles, tensor_metadata = local_device_ipc_handles else: - _, name_and_handle_list = handles + _, name_and_handle_list = local_device_ipc_handles device_id = self.device.index weights = [] From 4ae37a08666d7e7eae1f3a41d821539e94a98362 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Mon, 14 Jul 2025 07:57:51 +0000 Subject: [PATCH 2/4] fix lint fix lint Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm.py | 3 ++- nemo_rl/models/generation/vllm_backend.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index ea29acce2a..8b73b3152c 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -1054,7 +1054,8 @@ def update_weights_from_ipc_handles(self, ipc_handles: dict[str, Any]) -> bool: ): ray_worker_outputs.append( worker.execute_method.remote( - "update_weights_from_local_ipc_handles", ipc_handles[device_id] + "update_weights_from_local_ipc_handles", + ipc_handles[device_id], ) ) diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 65b05cbea9..3d6ed0253c 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -50,7 +50,7 @@ def report_device_id(self) -> str: from nemo_rl.utils.nvml import get_device_uuid return get_device_uuid(self.device.index) - + def update_weights_from_global_ipc_handles(self, global_device_ipc_handles): """Update weights from global IPC handles. From c70e0315560cbf73a6c11c1e945a2e753f8a7635 Mon Sep 17 00:00:00 2001 From: Zhiyu Li Date: Mon, 14 Jul 2025 17:50:28 +0000 Subject: [PATCH 3/4] add TODO Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index 8b73b3152c..fa07030863 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -1098,6 +1098,7 @@ async def update_weights_from_ipc_handles_async( "update_weights_from_ipc_handles_async can only be used with async_engine=True. Use update_weights_from_ipc_handles instead." ) + # TODO: switch to update_weights_from_local_ipc_handles for better performance once collectively report_device_id is supported in asyncLLM initialization result_or_coro = await self.llm.collective_rpc( "update_weights_from_global_ipc_handles", args=(ipc_handles,) ) From 22f18b99cc3f5c4d217a3af7526a686385adfe2c Mon Sep 17 00:00:00 2001 From: yuki <48991475+yuki-666@users.noreply.github.com> Date: Tue, 15 Jul 2025 14:56:47 +0800 Subject: [PATCH 4/4] feat: add post_init to support get device_id in async vllm init (#668) Signed-off-by: Yuki Huang Signed-off-by: Zhiyu Li --- nemo_rl/models/generation/vllm.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index fa07030863..a99022e30f 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -346,7 +346,16 @@ def _patch_vllm_init_workers_ray(): self.llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**llm_kwargs)) else: self.llm = vllm.LLM(**llm_kwargs) - self.vllm_device_ids = self.report_device_id() + + # will be initialized in post_init + # used in update_weights_from_ipc_handles + self.vllm_device_ids = None + + def post_init(self): + self.vllm_device_ids = self.report_device_id() + + async def post_init_async(self): + self.vllm_device_ids = await self.report_device_id_async() def init_collective( self, rank_prefix: int, ip: str, port: int, world_size: int @@ -1386,6 +1395,10 @@ def __init__( env_vars=env_vars, ) + # Call some collective rpc functions in VllmGenerationWorker when initializing the vLLM engine + # This is necessary for async engine to work + self._post_init() + # Number of data parallel groups is the number of tied worker groups self.dp_size = self.worker_group.dp_size @@ -1526,6 +1539,19 @@ def _report_device_id(self) -> list[list[str]]: results = ray.get(futures) return results + def _post_init(self): + # Choose the appropriate method based on async_engine setting + method_name = ( + "post_init_async" if self.cfg["vllm_cfg"]["async_engine"] else "post_init" + ) + # Use run_all_workers_single_data for methods that don't need data + futures = self.worker_group.run_all_workers_single_data( + method_name, run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"] + ) + # Wait for all futures to complete + results = ray.get(futures) + return results + def init_collective( self, ip: str, port: int, world_size: int ) -> list[ray.ObjectRef]: