12 changes: 10 additions & 2 deletions nemo_rl/models/generation/vllm.py
@@ -1363,13 +1363,21 @@ def __init__(
"tensor_parallel"
) * self.sharding_annotations.get_axis_size("pipeline_parallel")

# Non-colocated setups must use the PACK strategy to avoid uneven node_bundles.
# E.g., with 3 nodes of 8 GPUs each (2 nodes for training, 1 node for inference),
# SPREAD would produce node bundles like 0: [0,3,6] 1: [1,4,7] 2: [2,5], which is incorrect.
strategy = None if self.cfg["colocated"]["enabled"] else "PACK"

# Determine if we need cross-node model parallelism
needs_cross_node_parallelism = (
self.model_parallel_size > cluster.num_gpus_per_node
)

# Initialize placement groups with the appropriate mode
cluster._init_placement_groups(use_unified_pg=needs_cross_node_parallelism)
cluster._init_placement_groups(
strategy=strategy,
use_unified_pg=needs_cross_node_parallelism,
)

# Create worker builder for VllmGenerationWorker
worker_builder = RayWorkerBuilder(
@@ -1381,7 +1389,7 @@ def __init__(
# See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details.
env_vars = {}
if not self.cfg["colocated"]["enabled"]:
os.environ["NCCL_CUMEM_ENABLE"] = "1"
env_vars["NCCL_CUMEM_ENABLE"] = "1"

# Check if we need parallelism-aware worker group creation
if self.model_parallel_size > 1:
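To make the PACK/SPREAD comment above concrete, here is a toy sketch (a simplified assignment model, not Ray's actual placement-group scheduler) of how 8 single-GPU bundles land on 3 nodes of 8 GPUs each under the two strategies; the SPREAD result reproduces the uneven node_bundles described in the comment.

```python
# Toy illustration only: a simplified model, not Ray's real placement-group scheduler.
from collections import defaultdict


def assign_bundles(num_bundles: int, num_nodes: int, gpus_per_node: int, strategy: str):
    nodes: dict[int, list[int]] = defaultdict(list)
    for bundle in range(num_bundles):
        if strategy == "SPREAD":
            node = bundle % num_nodes  # round-robin across all nodes
        else:  # "PACK": fill one node before moving to the next
            node = bundle // gpus_per_node
        nodes[node].append(bundle)
    return dict(nodes)


print(assign_bundles(8, 3, 8, "SPREAD"))  # {0: [0, 3, 6], 1: [1, 4, 7], 2: [2, 5]}
print(assign_bundles(8, 3, 8, "PACK"))    # {0: [0, 1, 2, 3, 4, 5, 6, 7]}
```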
2 changes: 1 addition & 1 deletion nemo_rl/models/generation/vllm_backend.py
@@ -62,7 +62,7 @@ def prepare_refit_info(

MegatronPolicyWorker:
colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype, numel)}
non-colocated inference: not implemented yet
non-colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype)}
"""
self.state_dict_info = state_dict_info # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored

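The docstring change above distinguishes the two metadata layouts a MegatronPolicyWorker can send. Below is a minimal sketch of both shapes; the tensor name and sizes are placeholders, not taken from a real checkpoint.

```python
# Sketch of the two refit-metadata layouts described in the docstring above.
import torch

state_dict = {"model.embed_tokens.weight": torch.empty(8, 4, dtype=torch.bfloat16)}

# Colocated inference: (shape, dtype, numel) per tensor.
colocated_info = {name: (t.shape, t.dtype, t.numel()) for name, t in state_dict.items()}

# Non-colocated inference: (shape, dtype) per tensor; numel is not needed.
non_colocated_info = {name: (t.shape, t.dtype) for name, t in state_dict.items()}
```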
53 changes: 47 additions & 6 deletions nemo_rl/models/policy/megatron_policy_worker.py
@@ -378,6 +378,15 @@ def __init__(
pre_init_communication_queue: Queue,
**kwargs: Any,
):
self.is_generation_colocated = None
if "generation" in config and config["generation"] is not None:
self.is_generation_colocated = config["generation"]["colocated"]["enabled"]

# Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNCCLCommunicator.
# See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details.
if not self.is_generation_colocated:
os.environ["NCCL_CUMEM_ENABLE"] = "1"

self.cfg = config
dtype_map = {
"float32": torch.float32,
@@ -416,7 +425,8 @@ def __init__(
# Ensure clean slate before import
destroy_parallel_state()

if get_rank_safe() == 0:
self.rank = get_rank_safe()
if self.rank == 0:
if pt_checkpoint_exists:
print(
f"Checkpoint already exists at {pretrained_path}. Skipping import."
@@ -725,6 +735,18 @@ def __init__(
## used for streaming update inference engine weights
self._held_gather_buffer = None

def init_collective(self, ip: str, port: int, world_size: int) -> None:
"""Initialize the collective communication."""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup

if self.rank == 0:
pg = StatelessProcessGroup.create(
host=ip, port=port, rank=0, world_size=world_size
)
device = torch.cuda.current_device()
self.model_update_group = PyNcclCommunicator(pg, device=device)

def is_alive(self):
return True

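As context for init_collective above, a hedged sketch of how both ends of the weight-update group could be created with the same vLLM primitives; the world_size and rank layout (training rank 0 plus one rank per inference worker) follow the test added below, and the helper name is hypothetical.

```python
# Hedged sketch, not NeMo-RL code: both sides join the same StatelessProcessGroup,
# with the training worker at rank 0 and inference workers at ranks 1..world_size-1.
import torch
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup


def make_weight_update_group(ip: str, port: int, rank: int, world_size: int) -> PyNcclCommunicator:
    pg = StatelessProcessGroup.create(host=ip, port=port, rank=rank, world_size=world_size)
    return PyNcclCommunicator(pg, device=torch.cuda.current_device())

# Training side (rank 0), as in init_collective above:
#   group = make_weight_update_group(ip, port, rank=0, world_size=world_size)
# Inference side (hypothetical; one call per vLLM worker):
#   group = make_weight_update_group(ip, port, rank=1 + worker_index, world_size=world_size)
```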
@@ -1409,11 +1431,11 @@ def prepare_refit_info(self) -> None:
)
# collect tensor metadata
for name, tensor in gathered_hf_params.items():
refit_param_info_hf[name] = (
tensor.shape,
tensor.dtype,
tensor.numel(),
)
if self.is_generation_colocated:
metadata = (tensor.shape, tensor.dtype, tensor.numel())
else:
metadata = (tensor.shape, tensor.dtype)
refit_param_info_hf[name] = metadata

return refit_param_info_hf

@@ -1526,6 +1548,25 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:

return {device_uuid: serialized}

@torch.no_grad()
def broadcast_weights_for_collective(self) -> None:
"""Broadcast the weights for collective communication."""
for key, _ in self.refit_param_info_mcore:
# gather megatron params
gathered_megatron_params = gather_params(
self.model,
[key],
key_to_global_keys=self.local_key_to_global_keys,
)
# convert to hf params
gathered_hf_params = self.megatron_to_hf_converter.convert(
gathered_megatron_params, self.model.config
)
# broadcast from train rank0 worker to inference workers
if self.rank == 0:
for _, tensor in gathered_hf_params.items():
self.model_update_group.broadcast(tensor, src=0)

def prepare_for_lp_inference(self):
self.model = self.move_model(self.model, "cuda", move_grads=False)
self.model.eval()
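broadcast_weights_for_collective above only covers the sending side (training rank 0). A hedged sketch of what the matching receive loop could look like on an inference worker, assuming the workers walk the parameters in the same order and reuse the (shape, dtype) metadata from prepare_refit_info; the function name is hypothetical and not part of this PR.

```python
# Hedged sketch of the inference-side receive path (assumed, not in this diff).
import torch


def receive_broadcast_weights(model_update_group, state_dict_info):
    received = {}
    for name, (shape, dtype) in state_dict_info.items():
        tensor = torch.empty(shape, dtype=dtype, device="cuda")
        # Pairs with model_update_group.broadcast(tensor, src=0) on the training side.
        model_update_group.broadcast(tensor, src=0)
        received[name] = tensor
    return received
```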
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -70,6 +70,8 @@ mcore = [
"transformer-engine[pytorch]==2.3.0",
"megatron-core",
"nemo-tron",
# Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 is resolved
"vllm==0.10.0",
# Flash-attn version should be selected to satisfy both TE + vLLM requirements (xformers in particular)
# https://github.com/NVIDIA/TransformerEngine/blob/v2.3/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L108
# https://github.com/facebookresearch/xformers/blob/8354497deb2c04c67fbb2e2ad911e86530da0e90/xformers/ops/fmha/flash.py#L76
66 changes: 46 additions & 20 deletions tests/unit/models/generation/test_vllm_generation.py
@@ -23,10 +23,7 @@
from nemo_rl.algorithms.loss_functions import NLLLoss
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.virtual_cluster import (
RayVirtualCluster,
_get_node_ip_and_free_port,
)
from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
from nemo_rl.models.generation import configure_generation_config
from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration
from nemo_rl.models.policy import PolicyConfig
@@ -1200,12 +1197,14 @@ def test_vllm_non_divisible_batch_handling(policy):
@pytest.mark.asyncio
@pytest.mark.parametrize("async_engine", [True, False])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2])
async def test_vllm_refit_non_collocated_update_weights(
@pytest.mark.parametrize("policy_type", ["dtensor", "megatron"])
async def test_vllm_refit_non_colocated_update_weights(
policy_cluster_separate,
tokenizer,
test_input_data,
async_engine,
tensor_parallel_size,
policy_type,
):
# Skip tensor_parallel_size == 2 until we have resources in CI
if tensor_parallel_size == 2:
@@ -1223,33 +1222,49 @@ async def test_vllm_refit_non_collocated_update_weights(
"Test requires at least two GPUs to run policies on separate clusters."
)

# Create Policy on its own cluster
dtensor_config = deepcopy(basic_dtensor_test_config)
dtensor_config["generation"]["colocated"]["enabled"] = False
lm_policy = Policy(policy_cluster_separate, dtensor_config, tokenizer)
# Get policy config
if policy_type == "dtensor":
lm_config = deepcopy(basic_dtensor_test_config)
else:
assert policy_type == "megatron"
lm_config = get_basic_megatron_test_config(tp=1, pp=1, precision="float32")
lm_config["generation"]["colocated"]["enabled"] = False

# Create VllmGeneration policy on its own cluster
# Get vllm config
vllm_config = deepcopy(basic_vllm_test_config)
vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True)
vllm_config["vllm_cfg"]["async_engine"] = async_engine
vllm_config["vllm_cfg"]["tensor_parallel_size"] = tensor_parallel_size
vllm_config["colocated"]["enabled"] = False

# Megatron config with Qwen2.5-0.5B
if policy_type == "megatron":
model_name = "Qwen/Qwen2.5-0.5B"
tokenizer = get_tokenizer({"name": model_name})

lm_config["model_name"] = model_name
lm_config["tokenizer"]["name"] = model_name

vllm_config["model_name"] = model_name
vllm_config["tokenizer"]["name"] = model_name

# Create Policy and VllmGeneration
lm_policy = Policy(policy_cluster_separate, lm_config, tokenizer)
vllm_generation = VllmGeneration(generation_cluster_separate, vllm_config)

# initialize collective communication for update weights
ip, port = ray.get(_get_node_ip_and_free_port.remote())
futures_train = lm_policy.init_collective(ip, port, world_size=2)
futures_inference = vllm_generation.init_collective(ip, port, world_size=2)
ip, port = policy_cluster_separate.get_master_address_and_port()
world_size = tensor_parallel_size + 1
futures_train = lm_policy.init_collective(ip, port, world_size=world_size)
futures_inference = vllm_generation.init_collective(ip, port, world_size=world_size)
ray.get(futures_train + futures_inference)

# prepare refit info
state_dict_info = lm_policy.prepare_refit_info()
vllm_generation.prepare_refit_info(state_dict_info)

print("refitting vllm policy...")
refit_policy_generation(
lm_policy, vllm_generation, vllm_config["colocated"]["enabled"]
)
refit_policy_generation(lm_policy, vllm_generation, False)

# test generate
if async_engine:
Expand All @@ -1258,12 +1273,23 @@ async def test_vllm_refit_non_collocated_update_weights(
)
else:
outputs = vllm_generation.generate(test_input_data, greedy=True)

output_ids = outputs["output_ids"]
generated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
assert generated_texts == [
"Hello, my name is Lina. I'm",
"The capital of France is Paris. The capital of",
], "Output should be the same as the expected output"

if policy_type == "dtensor":
expected_texts = [
"Hello, my name is Lina. I'm",
"The capital of France is Paris. The capital of",
]
else:
expected_texts = [
"Hello, my name is Kaitlin and I",
"The capital of France is Paris. It is the",
]
assert generated_texts == expected_texts, (
"Output should be the same as the expected output"
)

# Clean up
vllm_generation.shutdown()
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default.