diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py
index fa07030863..a99022e30f 100644
--- a/nemo_rl/models/generation/vllm.py
+++ b/nemo_rl/models/generation/vllm.py
@@ -346,7 +346,16 @@ def _patch_vllm_init_workers_ray():
             self.llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**llm_kwargs))
         else:
             self.llm = vllm.LLM(**llm_kwargs)
-        self.vllm_device_ids = self.report_device_id()
+
+        # Placeholder only: the real value is filled in later by post_init() / post_init_async().
+        # Per the original note it is consumed by update_weights_from_ipc_handles (not shown in this diff).
+        self.vllm_device_ids = None
+
+    def post_init(self):
+        self.vllm_device_ids = self.report_device_id()  # sync-engine variant (selected by _post_init)
+
+    async def post_init_async(self):
+        self.vllm_device_ids = await self.report_device_id_async()  # async-engine variant, awaited
 
     def init_collective(
         self, rank_prefix: int, ip: str, port: int, world_size: int
@@ -1386,6 +1395,10 @@ def __init__(
             env_vars=env_vars,
         )
 
+        # Run the deferred worker-side setup (collective RPC calls on the VllmGenerationWorker actors).
+        # It was moved out of the worker constructor because the async engine needs it done this way.
+        self._post_init()
+
         # Number of data parallel groups is the number of tied worker groups
         self.dp_size = self.worker_group.dp_size
 
@@ -1526,6 +1539,19 @@ def _report_device_id(self) -> list[list[str]]:
         results = ray.get(futures)
         return results
 
+    def _post_init(self):
+        # Pick the worker method that matches the engine type: coroutine for async, plain call otherwise
+        method_name = (
+            "post_init_async" if self.cfg["vllm_cfg"]["async_engine"] else "post_init"
+        )
+        # No per-worker data is passed; presumably only rank 0 along the TP/PP axes runs it - confirm
+        futures = self.worker_group.run_all_workers_single_data(
+            method_name, run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"]
+        )
+        # Block until every dispatched call has finished, then return whatever the workers returned
+        results = ray.get(futures)
+        return results
+
     def init_collective(
         self, ip: str, port: int, world_size: int
     ) -> list[ray.ObjectRef]: