Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 97 additions & 6 deletions nemo_rl/models/generation/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,19 @@ def init_collective(self, data: int, ip: str, port: int, world_size: int) -> Non
),
)

async def init_collective_async(
    self, data: int, ip: str, port: int, world_size: int
) -> None:
    """Async variant of init_collective: forward collective-group init parameters to every vLLM worker via RPC."""
    rpc_args = (data, ip, port, world_size)
    await self.llm.collective_rpc("init_collective", args=rpc_args)

def llm(self):
    """Return the underlying vLLM engine object held by this worker."""
    engine = self.llm
    return engine

Expand Down Expand Up @@ -979,7 +992,7 @@ def update_weights_from_collective(self, data: dict[str, Any]) -> bool:

if self.cfg["vllm_cfg"]["async_engine"]:
raise RuntimeError(
"update_weights_from_collective cannot be used with async_engine=True. Use update_weights_from_ipc_handles_async instead."
"update_weights_from_collective can only be used with async_engine=False. Use update_weights_from_collective_async instead."
)

result_or_coro = self.llm.collective_rpc(
Expand All @@ -1000,12 +1013,72 @@ def update_weights_from_collective(self, data: dict[str, Any]) -> bool:
traceback.print_exc()
return False

async def update_weights_from_collective_async(self, data: dict[str, Any]) -> bool:
    """Async version of update_weights_from_collective.

    Broadcasts the weight-update request to all workers via collective RPC
    and reports success as a bool; any failure is logged and returns False.
    """
    try:
        assert self.llm is not None, (
            "Attempting to update weights with either an uninitialized vLLM or non-model-owner"
        )

        # This path is only valid for the async engine; the sync engine has its own method.
        if not self.cfg["vllm_cfg"]["async_engine"]:
            raise RuntimeError(
                "update_weights_from_collective_async can only be used with async_engine=True. Use update_weights_from_collective instead."
            )

        rpc_result = await self.llm.collective_rpc(
            "update_weights_from_collective", args=(data,)
        )

        # Defensive: unwrap once more in case the RPC handed back a coroutine.
        if asyncio.iscoroutine(rpc_result):
            worker_results = await rpc_result
        else:
            worker_results = rpc_result

        worker_result = worker_results[0]
        if not worker_result:
            print(
                f"Error: Worker failed to update weights. Result: {worker_result}"
            )
            return False
        return True
    except Exception as e:
        print(f"Exception during collective_rpc for weight update: {e}")
        import traceback

        traceback.print_exc()
        return False

def reset_prefix_cache(self):
    """Reset the prefix cache of vLLM engine (synchronous engine only)."""
    engine = self.llm
    assert engine is not None, (
        "Attempting to reset prefix cache with either an uninitialized vLLM or non-model-owner"
    )

    # The async engine exposes its own coroutine-based reset path.
    if self.cfg["vllm_cfg"]["async_engine"]:
        raise RuntimeError(
            "reset_prefix_cache can only be used with async_engine=False. Use reset_prefix_cache_async instead."
        )

    engine.llm_engine.reset_prefix_cache()
    # Eagerly release freed Python objects and cached CUDA blocks.
    gc.collect()
    torch.cuda.empty_cache()

async def reset_prefix_cache_async(self):
    """Async version of reset_prefix_cache."""
    engine = self.llm
    assert engine is not None, (
        "Attempting to reset prefix cache with either an uninitialized vLLM or non-model-owner"
    )

    # Only valid for the async engine; the sync engine resets via its llm_engine.
    if not self.cfg["vllm_cfg"]["async_engine"]:
        raise RuntimeError(
            "reset_prefix_cache_async can only be used with async_engine=True. Use reset_prefix_cache instead."
        )

    await engine.reset_prefix_cache()
    # Eagerly release freed Python objects and cached CUDA blocks.
    gc.collect()
    torch.cuda.empty_cache()

def sleep(self):
"""Put the vLLM engine to sleep."""
assert self.llm is not None, (
Expand Down Expand Up @@ -1305,6 +1378,13 @@ def init_collective(
if not self.worker_group or not self.worker_group.workers:
raise RuntimeError("Worker group is not initialized")

# Choose the appropriate method based on async_engine setting
method_name = (
"init_collective_async"
if self.cfg["vllm_cfg"]["async_engine"]
else "init_collective"
)

# Prepare rank
total_workers = len(self.worker_group.workers)
if self.dp_size == 0:
Expand All @@ -1316,7 +1396,7 @@ def init_collective(

# Send world_size and rank for init collective to all workers
futures = self.worker_group.run_all_workers_multiple_data(
"init_collective",
method_name,
data=rank_prefix_list,
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
common_kwargs={"ip": ip, "port": port, "world_size": world_size},
Expand Down Expand Up @@ -1558,12 +1638,16 @@ def finish_generation(self, *args: Any, **kwargs: Any) -> bool:
try:
# Choose the appropriate method based on setting
# non-colocated only needs reset prefix cache, no need to sleep.
if not self.cfg["colocated"]["enabled"]:
method_name = "reset_prefix_cache"
else:
if self.cfg["colocated"]["enabled"]:
method_name = (
"sleep_async" if self.cfg["vllm_cfg"]["async_engine"] else "sleep"
)
else:
method_name = (
"reset_prefix_cache_async"
if self.cfg["vllm_cfg"]["async_engine"]
else "reset_prefix_cache"
)
# Use run_all_workers_single_data for methods that don't need data
futures = self.worker_group.run_all_workers_single_data(
method_name,
Expand Down Expand Up @@ -1636,9 +1720,16 @@ def update_weights_from_collective(
if not self.worker_group or not self.worker_group.workers:
raise RuntimeError("Worker group is not initialized")

# Choose the appropriate method based on async_engine setting
method_name = (
"update_weights_from_collective_async"
if self.cfg["vllm_cfg"]["async_engine"]
else "update_weights_from_collective"
)

# Use run_all_workers_single_data to send data to all workers
futures = self.worker_group.run_all_workers_single_data(
"update_weights_from_collective",
method_name,
data=info,
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
)
Expand Down
19 changes: 14 additions & 5 deletions tests/unit/models/generation/test_vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,13 +379,13 @@ def test_vllm_policy_generation(policy, test_input_data, tokenizer):
)


async def _generate_async(vllm_policy, tokenizer, test_input_data):
async def _generate_async(vllm_policy, tokenizer, test_input_data, greedy=False):
collected_indexed_outputs = []
# generate_async is restricted to handle only single samples
input_generator = test_input_data.make_microbatch_iterator(microbatch_size=1)
for single_item_input in input_generator:
async for original_idx, single_item_output in vllm_policy.generate_async(
single_item_input
single_item_input, greedy=greedy
):
collected_indexed_outputs.append((original_idx, single_item_output))

Expand Down Expand Up @@ -691,7 +691,7 @@ async def test_vllm_generation_with_hf_training(
print("Using vLLM policy for fast generation...")
if async_engine:
generation_results = await _generate_async(
vllm_policy, tokenizer, test_input_data
vllm_policy, tokenizer, test_input_data, greedy=True
)
else:
generation_results = vllm_policy.generate(test_input_data, greedy=True)
Expand Down Expand Up @@ -1174,11 +1174,14 @@ def test_vllm_non_divisible_batch_handling(policy):
)


def test_vllm_refit_non_collocated_handles_update(
@pytest.mark.asyncio
@pytest.mark.parametrize("async_engine", [True, False])
async def test_vllm_refit_non_collocated_update_weights(
policy_cluster_separate,
generation_cluster_separate,
tokenizer,
test_input_data,
async_engine,
):
if (
policy_cluster_separate.num_gpus_per_node < 1
Expand All @@ -1197,6 +1200,7 @@ def test_vllm_refit_non_collocated_handles_update(
# Create VllmGeneration policy on its own cluster
vllm_config = deepcopy(basic_vllm_test_config)
vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True)
vllm_config["vllm_cfg"]["async_engine"] = async_engine
vllm_config["vllm_cfg"]["tensor_parallel_size"] = 1
vllm_config["colocated"]["enabled"] = False
vllm_generation = VllmGeneration(generation_cluster_separate, vllm_config)
Expand All @@ -1213,7 +1217,12 @@ def test_vllm_refit_non_collocated_handles_update(
)

# test generate
outputs = vllm_generation.generate(test_input_data, greedy=True)
if async_engine:
outputs = await _generate_async(
vllm_generation, tokenizer, test_input_data, greedy=True
)
else:
outputs = vllm_generation.generate(test_input_data, greedy=True)
output_ids = outputs["output_ids"]
generated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
assert generated_texts == [
Expand Down
Loading