diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py
index 59a3cc5eed..8c9b6b99c9 100644
--- a/nemo_rl/models/generation/vllm.py
+++ b/nemo_rl/models/generation/vllm.py
@@ -1363,13 +1363,21 @@ def __init__(
             "tensor_parallel"
         ) * self.sharding_annotations.get_axis_size("pipeline_parallel")
 
+        # Non-colocated mode needs the PACK strategy to avoid uneven node_bundles.
+        # E.g., with 3 nodes of 8 GPUs each (2 nodes for training, 1 node for inference),
+        # SPREAD would produce node bundles like 0: [0,3,6], 1: [1,4,7], 2: [2,5], which is not correct.
+        strategy = None if self.cfg["colocated"]["enabled"] else "PACK"
+
         # Determine if we need cross-node model parallelism
         needs_cross_node_parallelism = (
             self.model_parallel_size > cluster.num_gpus_per_node
         )
 
         # Initialize placement groups with the appropriate mode
-        cluster._init_placement_groups(use_unified_pg=needs_cross_node_parallelism)
+        cluster._init_placement_groups(
+            strategy=strategy,
+            use_unified_pg=needs_cross_node_parallelism,
+        )
 
         # Create worker builder for VllmGenerationWorker
         worker_builder = RayWorkerBuilder(
@@ -1381,7 +1389,7 @@ def __init__(
         # See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details.
         env_vars = {}
         if not self.cfg["colocated"]["enabled"]:
-            os.environ["NCCL_CUMEM_ENABLE"] = "1"
+            env_vars["NCCL_CUMEM_ENABLE"] = "1"
 
         # Check if we need parallelism-aware worker group creation
         if self.model_parallel_size > 1:
diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py
index 7ede65f3c9..1861f7643d 100644
--- a/nemo_rl/models/generation/vllm_backend.py
+++ b/nemo_rl/models/generation/vllm_backend.py
@@ -62,7 +62,7 @@ def prepare_refit_info(
 
         MegatronPolicyWorker:
             colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype, numel)}
-            non-colocated inference: not implemented yet
+            non-colocated inference: state_dict_info is a dict of {tensor_name: (shape, dtype)}
         """
         self.state_dict_info = state_dict_info  # pyrefly: ignore[implicitly-defined-attribute]  This class does not define __init__ so assignments like this should be ignored
 
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 35c18eb701..07c9594322 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -378,6 +378,15 @@ def __init__(
         pre_init_communication_queue: Queue,
         **kwargs: Any,
     ):
+        self.is_generation_colocated = None
+        if "generation" in config and config["generation"] is not None:
+            self.is_generation_colocated = config["generation"]["colocated"]["enabled"]
+
+        # Explicitly set NCCL_CUMEM_ENABLE to 1 to avoid the P2P initialization error for PyNcclCommunicator.
+        # See https://github.com/NVIDIA-NeMo/RL/issues/564 for more details.
+        if not self.is_generation_colocated:
+            os.environ["NCCL_CUMEM_ENABLE"] = "1"
+
         self.cfg = config
         dtype_map = {
             "float32": torch.float32,
@@ -416,7 +425,8 @@ def __init__(
 
         # Ensure clean slate before import
         destroy_parallel_state()
-        if get_rank_safe() == 0:
+        self.rank = get_rank_safe()
+        if self.rank == 0:
             if pt_checkpoint_exists:
                 print(
                     f"Checkpoint already exists at {pretrained_path}. Skipping import."
@@ -725,6 +735,18 @@ def __init__(
         ## used for streaming update inference engine weights
         self._held_gather_buffer = None
 
+    def init_collective(self, ip: str, port: int, world_size: int) -> None:
+        """Initialize the collective communication."""
+        from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+        from vllm.distributed.utils import StatelessProcessGroup
+
+        if self.rank == 0:
+            pg = StatelessProcessGroup.create(
+                host=ip, port=port, rank=0, world_size=world_size
+            )
+            device = torch.cuda.current_device()
+            self.model_update_group = PyNcclCommunicator(pg, device=device)
+
     def is_alive(self):
         return True
 
@@ -1409,11 +1431,11 @@ def prepare_refit_info(self) -> None:
             )
             # collect tensor metadata
             for name, tensor in gathered_hf_params.items():
-                refit_param_info_hf[name] = (
-                    tensor.shape,
-                    tensor.dtype,
-                    tensor.numel(),
-                )
+                if self.is_generation_colocated:
+                    metadata = (tensor.shape, tensor.dtype, tensor.numel())
+                else:
+                    metadata = (tensor.shape, tensor.dtype)
+                refit_param_info_hf[name] = metadata
 
         return refit_param_info_hf
 
@@ -1526,6 +1548,25 @@ def get_weights_ipc_handles(self, *, keys: list[str]) -> dict[str, Any]:
 
         return {device_uuid: serialized}
 
+    @torch.no_grad()
+    def broadcast_weights_for_collective(self) -> None:
+        """Broadcast the weights for collective communication."""
+        for key, _ in self.refit_param_info_mcore:
+            # gather megatron params
+            gathered_megatron_params = gather_params(
+                self.model,
+                [key],
+                key_to_global_keys=self.local_key_to_global_keys,
+            )
+            # convert to hf params
+            gathered_hf_params = self.megatron_to_hf_converter.convert(
+                gathered_megatron_params, self.model.config
+            )
+            # broadcast from train rank 0 worker to inference workers
+            if self.rank == 0:
+                for _, tensor in gathered_hf_params.items():
+                    self.model_update_group.broadcast(tensor, src=0)
+
     def prepare_for_lp_inference(self):
         self.model = self.move_model(self.model, "cuda", move_grads=False)
         self.model.eval()
diff --git a/pyproject.toml b/pyproject.toml
index f0ce1b1f2a..cb7b6f5227 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -70,6 +70,8 @@ mcore = [
     "transformer-engine[pytorch]==2.3.0",
     "megatron-core",
     "nemo-tron",
+    # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 is resolved
+    "vllm==0.10.0",
     # Flash-attn version should be selected to satisfy both TE + vLLM requirements (xformers in particular)
     # https://github.com/NVIDIA/TransformerEngine/blob/v2.3/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L108
     # https://github.com/facebookresearch/xformers/blob/8354497deb2c04c67fbb2e2ad911e86530da0e90/xformers/ops/fmha/flash.py#L76
diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py
index eae44d8dfe..901073c3e1 100644
--- a/tests/unit/models/generation/test_vllm_generation.py
+++ b/tests/unit/models/generation/test_vllm_generation.py
@@ -23,10 +23,7 @@
 from nemo_rl.algorithms.loss_functions import NLLLoss
 from nemo_rl.algorithms.utils import get_tokenizer
 from nemo_rl.distributed.batched_data_dict import BatchedDataDict
-from nemo_rl.distributed.virtual_cluster import (
-    RayVirtualCluster,
-    _get_node_ip_and_free_port,
-)
+from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
 from nemo_rl.models.generation import configure_generation_config
 from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration
 from nemo_rl.models.policy import PolicyConfig
@@ -1200,12 +1197,14 @@ def test_vllm_non_divisible_batch_handling(policy):
 @pytest.mark.asyncio
@pytest.mark.parametrize("async_engine", [True, False]) @pytest.mark.parametrize("tensor_parallel_size", [1, 2]) -async def test_vllm_refit_non_collocated_update_weights( +@pytest.mark.parametrize("policy_type", ["dtensor", "megatron"]) +async def test_vllm_refit_non_colocated_update_weights( policy_cluster_separate, tokenizer, test_input_data, async_engine, tensor_parallel_size, + policy_type, ): # Skip tensor_parallel_size == 2 until we have resources in CI if tensor_parallel_size == 2: @@ -1223,23 +1222,41 @@ async def test_vllm_refit_non_collocated_update_weights( "Test requires at least two GPUs to run policies on separate clusters." ) - # Create Policy on its own cluster - dtensor_config = deepcopy(basic_dtensor_test_config) - dtensor_config["generation"]["colocated"]["enabled"] = False - lm_policy = Policy(policy_cluster_separate, dtensor_config, tokenizer) + # Get policy config + if policy_type == "dtensor": + lm_config = deepcopy(basic_dtensor_test_config) + else: + assert policy_type == "megatron" + lm_config = get_basic_megatron_test_config(tp=1, pp=1, precision="float32") + lm_config["generation"]["colocated"]["enabled"] = False - # Create VllmGeneration policy on its own cluster + # Get vllm config vllm_config = deepcopy(basic_vllm_test_config) vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True) vllm_config["vllm_cfg"]["async_engine"] = async_engine vllm_config["vllm_cfg"]["tensor_parallel_size"] = tensor_parallel_size vllm_config["colocated"]["enabled"] = False + + # Megatron config with Qwen2.5-0.5B + if policy_type == "megatron": + model_name = "Qwen/Qwen2.5-0.5B" + tokenizer = get_tokenizer({"name": model_name}) + + lm_config["model_name"] = model_name + lm_config["tokenizer"]["name"] = model_name + + vllm_config["model_name"] = model_name + vllm_config["tokenizer"]["name"] = model_name + + # Create Policy and VllmGeneration + lm_policy = Policy(policy_cluster_separate, lm_config, tokenizer) vllm_generation = VllmGeneration(generation_cluster_separate, vllm_config) # initialize collective communication for update weights - ip, port = ray.get(_get_node_ip_and_free_port.remote()) - futures_train = lm_policy.init_collective(ip, port, world_size=2) - futures_inference = vllm_generation.init_collective(ip, port, world_size=2) + ip, port = policy_cluster_separate.get_master_address_and_port() + world_size = tensor_parallel_size + 1 + futures_train = lm_policy.init_collective(ip, port, world_size=world_size) + futures_inference = vllm_generation.init_collective(ip, port, world_size=world_size) ray.get(futures_train + futures_inference) # prepare refit info @@ -1247,9 +1264,7 @@ async def test_vllm_refit_non_collocated_update_weights( vllm_generation.prepare_refit_info(state_dict_info) print("refitting vllm policy...") - refit_policy_generation( - lm_policy, vllm_generation, vllm_config["colocated"]["enabled"] - ) + refit_policy_generation(lm_policy, vllm_generation, False) # test generate if async_engine: @@ -1258,12 +1273,23 @@ async def test_vllm_refit_non_collocated_update_weights( ) else: outputs = vllm_generation.generate(test_input_data, greedy=True) + output_ids = outputs["output_ids"] generated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True) - assert generated_texts == [ - "Hello, my name is Lina. I'm", - "The capital of France is Paris. The capital of", - ], "Output should be the same as the expected output" + + if policy_type == "dtensor": + expected_texts = [ + "Hello, my name is Lina. 
I'm", + "The capital of France is Paris. The capital of", + ] + else: + expected_texts = [ + "Hello, my name is Kaitlin and I", + "The capital of France is Paris. It is the", + ] + assert generated_texts == expected_texts, ( + "Output should be the same as the expected output" + ) # Clean up vllm_generation.shutdown() diff --git a/uv.lock b/uv.lock index 9688b33fb9..c2a8c6cfe9 100644 --- a/uv.lock +++ b/uv.lock @@ -2549,6 +2549,7 @@ mcore = [ { name = "megatron-core" }, { name = "nemo-tron" }, { name = "transformer-engine", extra = ["pytorch"] }, + { name = "vllm" }, ] vllm = [ { name = "flash-attn" }, @@ -2619,6 +2620,7 @@ requires-dist = [ { name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'mcore'", specifier = "==2.3.0" }, { name = "transformers", specifier = ">=4.51.0,<4.54.0" }, { name = "triton", index = "https://download.pytorch.org/whl/cu128" }, + { name = "vllm", marker = "extra == 'mcore'", specifier = "==0.10.0" }, { name = "vllm", marker = "extra == 'vllm'", specifier = "==0.10.0" }, { name = "wandb" }, ]