diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py
index 3bf64b2652..5edbd6816c 100644
--- a/nemo_rl/models/generation/vllm.py
+++ b/nemo_rl/models/generation/vllm.py
@@ -359,6 +359,19 @@ def init_collective(self, data: int, ip: str, port: int, world_size: int) -> Non
             ),
         )

+    async def init_collective_async(
+        self, data: int, ip: str, port: int, world_size: int
+    ) -> None:
+        await self.llm.collective_rpc(
+            "init_collective",
+            args=(
+                data,
+                ip,
+                port,
+                world_size,
+            ),
+        )
+
     def llm(self):
         return self.llm

@@ -979,7 +992,7 @@ def update_weights_from_collective(self, data: dict[str, Any]) -> bool:

         if self.cfg["vllm_cfg"]["async_engine"]:
             raise RuntimeError(
-                "update_weights_from_collective cannot be used with async_engine=True. Use update_weights_from_ipc_handles_async instead."
+                "update_weights_from_collective can only be used with async_engine=False. Use update_weights_from_collective_async instead."
             )

         result_or_coro = self.llm.collective_rpc(
@@ -1000,12 +1013,72 @@ def update_weights_from_collective(self, data: dict[str, Any]) -> bool:
             traceback.print_exc()
             return False

+    async def update_weights_from_collective_async(self, data: dict[str, Any]) -> bool:
+        """Async version of update_weights_from_collective."""
+        try:
+            assert self.llm is not None, (
+                "Attempting to update weights with either an uninitialized vLLM or non-model-owner"
+            )
+
+            if not self.cfg["vllm_cfg"]["async_engine"]:
+                raise RuntimeError(
+                    "update_weights_from_collective_async can only be used with async_engine=True. Use update_weights_from_collective instead."
+                )
+
+            result_or_coro = await self.llm.collective_rpc(
+                "update_weights_from_collective", args=(data,)
+            )
+
+            if asyncio.iscoroutine(result_or_coro):
+                worker_results = await result_or_coro
+            else:
+                worker_results = result_or_coro
+
+            worker_result = worker_results[0]
+
+            if not worker_result:
+                print(
+                    f"Error: Worker failed to update weights. Result: {worker_result}"
+                )
+                return False
+            return True
+        except Exception as e:
+            print(f"Exception during collective_rpc for weight update: {e}")
+            import traceback
+
+            traceback.print_exc()
+            return False
+
     def reset_prefix_cache(self):
         """Reset the prefix cache of vLLM engine."""
+        assert self.llm is not None, (
+            "Attempting to reset prefix cache with either an uninitialized vLLM or non-model-owner"
+        )
+
+        if self.cfg["vllm_cfg"]["async_engine"]:
+            raise RuntimeError(
+                "reset_prefix_cache can only be used with async_engine=False. Use reset_prefix_cache_async instead."
+            )
+
         self.llm.llm_engine.reset_prefix_cache()
         gc.collect()
         torch.cuda.empty_cache()

+    async def reset_prefix_cache_async(self):
+        """Async version of reset_prefix_cache."""
+        assert self.llm is not None, (
+            "Attempting to reset prefix cache with either an uninitialized vLLM or non-model-owner"
+        )
+
+        if not self.cfg["vllm_cfg"]["async_engine"]:
+            raise RuntimeError(
+                "reset_prefix_cache_async can only be used with async_engine=True. Use reset_prefix_cache instead."
+            )
+
+        await self.llm.reset_prefix_cache()
+        gc.collect()
+        torch.cuda.empty_cache()
+
     def sleep(self):
         """Put the vLLM engine to sleep."""
         assert self.llm is not None, (
@@ -1305,6 +1378,13 @@ def init_collective(
         if not self.worker_group or not self.worker_group.workers:
             raise RuntimeError("Worker group is not initialized")

+        # Choose the appropriate method based on async_engine setting
+        method_name = (
+            "init_collective_async"
+            if self.cfg["vllm_cfg"]["async_engine"]
+            else "init_collective"
+        )
+
         # Prepare rank
         total_workers = len(self.worker_group.workers)
         if self.dp_size == 0:
@@ -1316,7 +1396,7 @@
         # Send world_size and rank for init collective to all workers
         futures = self.worker_group.run_all_workers_multiple_data(
-            "init_collective",
+            method_name,
             data=rank_prefix_list,
             run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
             common_kwargs={"ip": ip, "port": port, "world_size": world_size},
         )
@@ -1558,12 +1638,16 @@ def finish_generation(self, *args: Any, **kwargs: Any) -> bool:
         try:
             # Choose the appropriate method based on setting
             # non-colocated only needs reset prefix cache, no need to sleep.
-            if not self.cfg["colocated"]["enabled"]:
-                method_name = "reset_prefix_cache"
-            else:
+            if self.cfg["colocated"]["enabled"]:
                 method_name = (
                     "sleep_async" if self.cfg["vllm_cfg"]["async_engine"] else "sleep"
                 )
+            else:
+                method_name = (
+                    "reset_prefix_cache_async"
+                    if self.cfg["vllm_cfg"]["async_engine"]
+                    else "reset_prefix_cache"
+                )
             # Use run_all_workers_single_data for methods that don't need data
             futures = self.worker_group.run_all_workers_single_data(
                 method_name,
@@ -1636,9 +1720,16 @@ def update_weights_from_collective(
         if not self.worker_group or not self.worker_group.workers:
             raise RuntimeError("Worker group is not initialized")

+        # Choose the appropriate method based on async_engine setting
+        method_name = (
+            "update_weights_from_collective_async"
+            if self.cfg["vllm_cfg"]["async_engine"]
+            else "update_weights_from_collective"
+        )
+
         # Use run_all_workers_single_data to send data to all workers
         futures = self.worker_group.run_all_workers_single_data(
-            "update_weights_from_collective",
+            method_name,
             data=info,
             run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
         )
diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py
index 266be74264..dc1de1b123 100644
--- a/tests/unit/models/generation/test_vllm_generation.py
+++ b/tests/unit/models/generation/test_vllm_generation.py
@@ -379,13 +379,13 @@ def test_vllm_policy_generation(policy, test_input_data, tokenizer):
     )


-async def _generate_async(vllm_policy, tokenizer, test_input_data):
+async def _generate_async(vllm_policy, tokenizer, test_input_data, greedy=False):
     collected_indexed_outputs = []
     # generate_async is restricted to handle only single samples
     input_generator = test_input_data.make_microbatch_iterator(microbatch_size=1)
     for single_item_input in input_generator:
         async for original_idx, single_item_output in vllm_policy.generate_async(
-            single_item_input
+            single_item_input, greedy=greedy
         ):
             collected_indexed_outputs.append((original_idx, single_item_output))

@@ -691,7 +691,7 @@ async def test_vllm_generation_with_hf_training(
     print("Using vLLM policy for fast generation...")
     if async_engine:
         generation_results = await _generate_async(
-            vllm_policy, tokenizer, test_input_data
+            vllm_policy, tokenizer, test_input_data, greedy=True
         )
     else:
         generation_results = vllm_policy.generate(test_input_data, greedy=True)
@@ -1174,11 +1174,14 @@ def test_vllm_non_divisible_batch_handling(policy):
     )


-def test_vllm_refit_non_collocated_handles_update(
+@pytest.mark.asyncio
+@pytest.mark.parametrize("async_engine", [True, False])
+async def test_vllm_refit_non_collocated_update_weights(
     policy_cluster_separate,
     generation_cluster_separate,
     tokenizer,
     test_input_data,
+    async_engine,
 ):
     if (
         policy_cluster_separate.num_gpus_per_node < 1
@@ -1197,6 +1200,7 @@ def test_vllm_refit_non_collocated_handles_update(
     # Create VllmGeneration policy on its own cluster
     vllm_config = deepcopy(basic_vllm_test_config)
     vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True)
+    vllm_config["vllm_cfg"]["async_engine"] = async_engine
     vllm_config["vllm_cfg"]["tensor_parallel_size"] = 1
     vllm_config["colocated"]["enabled"] = False
     vllm_generation = VllmGeneration(generation_cluster_separate, vllm_config)
@@ -1213,7 +1217,12 @@ def test_vllm_refit_non_collocated_handles_update(
     )

     # test generate
-    outputs = vllm_generation.generate(test_input_data, greedy=True)
+    if async_engine:
+        outputs = await _generate_async(
+            vllm_generation, tokenizer, test_input_data, greedy=True
+        )
+    else:
+        outputs = vllm_generation.generate(test_input_data, greedy=True)
     output_ids = outputs["output_ids"]
     generated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
     assert generated_texts == [
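
Note: the snippet below is not part of the patch. It is a minimal, self-contained sketch of the sync/async dispatch pattern this diff applies uniformly to init_collective, update_weights_from_collective, and reset_prefix_cache: pick a method name from cfg["vllm_cfg"]["async_engine"], then drive the result only when the async variant is chosen. _ToyWorker and dispatch are hypothetical stand-ins for the worker and worker-group plumbing in nemo_rl/models/generation/vllm.py.

```python
import asyncio
import inspect


class _ToyWorker:
    """Hypothetical stand-in for the vLLM worker; only the dispatch pattern is real."""

    def __init__(self, async_engine: bool):
        # Same config key the patch consults: cfg["vllm_cfg"]["async_engine"].
        self.cfg = {"vllm_cfg": {"async_engine": async_engine}}

    def reset_prefix_cache(self) -> str:
        # Guard mirrors the patch: the sync variant refuses to run under async_engine=True.
        if self.cfg["vllm_cfg"]["async_engine"]:
            raise RuntimeError(
                "reset_prefix_cache can only be used with async_engine=False. "
                "Use reset_prefix_cache_async instead."
            )
        return "sync reset done"

    async def reset_prefix_cache_async(self) -> str:
        # Mirror-image guard for the async variant.
        if not self.cfg["vllm_cfg"]["async_engine"]:
            raise RuntimeError(
                "reset_prefix_cache_async can only be used with async_engine=True. "
                "Use reset_prefix_cache instead."
            )
        await asyncio.sleep(0)  # stands in for awaiting the async vLLM engine
        return "async reset done"


def dispatch(worker: _ToyWorker) -> str:
    # Pick the method *name* from the flag, as the patch does before handing
    # the name to the worker group (run_all_workers_single_data and friends).
    method_name = (
        "reset_prefix_cache_async"
        if worker.cfg["vllm_cfg"]["async_engine"]
        else "reset_prefix_cache"
    )
    result = getattr(worker, method_name)()
    # The async variant hands back a coroutine that still has to be driven.
    if inspect.iscoroutine(result):
        result = asyncio.run(result)
    return result


print(dispatch(_ToyWorker(async_engine=False)))  # sync reset done
print(dispatch(_ToyWorker(async_engine=True)))   # async reset done
```

Keeping the guard in both variants, as the patch does, makes a misrouted call fail fast with a pointer to the correct method instead of silently returning an un-awaited coroutine.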