diff --git a/nemo_rl/algorithms/distillation.py b/nemo_rl/algorithms/distillation.py
index 75659e55b8..fe71e56b02 100644
--- a/nemo_rl/algorithms/distillation.py
+++ b/nemo_rl/algorithms/distillation.py
@@ -427,11 +427,16 @@ def setup(
 
     if not colocated_inference:
         ip, port = train_cluster.get_master_address_and_port()
         print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
+        train_world_size = train_cluster.world_size()
-        # inference cluster + head node of the train cluster
-        world_size = inference_nodes * inference_gpus_per_node + 1
+        # world includes all training workers and all inference workers
+        world_size = train_world_size + inference_nodes * inference_gpus_per_node
         # init collective
-        futures_train = student_policy.init_collective(ip, port, world_size)
-        futures_inference = student_generation.init_collective(ip, port, world_size)  # type: ignore
+        futures_train = student_policy.init_collective(
+            ip, port, world_size, train_world_size=train_world_size
+        )
+        futures_inference = student_generation.init_collective(
+            ip, port, world_size, train_world_size=train_world_size
+        )  # type: ignore
         # wait for all futures to complete
         ray.get(futures_train + futures_inference)
diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index 1b10cd6375..971e5f7f52 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -435,11 +435,17 @@ def setup(
 
     if not colocated_inference:
         ip, port = train_cluster.get_master_address_and_port()
         print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
-        # inference cluster + head node of the train cluster
-        world_size = inference_nodes * inference_gpus_per_node + 1
+        # world includes all training workers and all inference workers
+        train_world_size = train_cluster.world_size()
+        inference_world_size = inference_nodes * inference_gpus_per_node
+        world_size = train_world_size + inference_world_size
         # init collective
-        futures_train = policy.init_collective(ip, port, world_size)
-        futures_inference = policy_generation.init_collective(ip, port, world_size)  # type: ignore
+        futures_train = policy.init_collective(
+            ip, port, world_size, train_world_size=train_world_size
+        )
+        futures_inference = policy_generation.init_collective(
+            ip, port, world_size, train_world_size=train_world_size
+        )  # type: ignore
         # wait for all futures to complete
         ray.get(futures_train + futures_inference)
diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py
index 5c3b125514..895506e4b4 100644
--- a/nemo_rl/models/generation/vllm/vllm_backend.py
+++ b/nemo_rl/models/generation/vllm/vllm_backend.py
@@ -32,14 +32,20 @@ class VllmInternalWorkerExtension:
 
     def init_collective(
-        self, rank_prefix: int, ip: str, port: int, world_size: int
+        self,
+        rank_prefix: int,
+        ip: str,
+        port: int,
+        world_size: int,
+        train_world_size: int,
     ) -> None:
         """Initialize the collective communication."""
         from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
         from vllm.distributed.utils import StatelessProcessGroup
 
         local_rank = torch.distributed.get_rank()
-        rank = rank_prefix + local_rank + 1  # 1 is the head node of the train cluster
+        # Place vLLM ranks after all training ranks so all training workers can join
+        rank = train_world_size + rank_prefix + local_rank
 
         pg = StatelessProcessGroup.create(
             host=ip, port=port, rank=rank, world_size=world_size
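Note: to make the new group arithmetic concrete, here is a tiny standalone sketch (illustrative sizes, not part of the patch). Under the old scheme only the head node of the train cluster joined the group, so world_size was the inference GPU count plus one; now every training worker joins:

# Illustrative sketch: 2 training nodes x 4 GPUs, 1 inference node x 4 GPUs.
train_world_size = 2 * 4             # train_cluster.world_size()
inference_world_size = 1 * 4         # inference_nodes * inference_gpus_per_node
old_world_size = inference_world_size + 1                  # head node only
new_world_size = train_world_size + inference_world_size   # all workers join
print(old_world_size, new_world_size)  # 5 12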
diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py
index f67cbea41a..48ed02425f 100644
--- a/nemo_rl/models/generation/vllm/vllm_generation.py
+++ b/nemo_rl/models/generation/vllm/vllm_generation.py
@@ -368,7 +368,7 @@ def _post_init(self):
         return results
 
     def init_collective(
-        self, ip: str, port: int, world_size: int
+        self, ip: str, port: int, world_size: int, *, train_world_size: int
     ) -> list[ray.ObjectRef]:
         """Initialize the collective communication."""
         if not self.worker_group or not self.worker_group.workers:
@@ -395,7 +395,12 @@ def init_collective(
             method_name,
             rank_prefix=rank_prefix_list,
             run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
-            common_kwargs={"ip": ip, "port": port, "world_size": world_size},
+            common_kwargs={
+                "ip": ip,
+                "port": port,
+                "world_size": world_size,
+                "train_world_size": train_world_size,
+            },
         )
 
         # this function should co-work with lm_policy, so we should wait for all futures to complete outside
diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py
index 11f7716c94..b17478bc0f 100644
--- a/nemo_rl/models/generation/vllm/vllm_worker.py
+++ b/nemo_rl/models/generation/vllm/vllm_worker.py
@@ -477,7 +477,12 @@ def post_init(self):
         self.vllm_device_ids = self.report_device_id()
 
     def init_collective(
-        self, rank_prefix: int, ip: str, port: int, world_size: int
+        self,
+        rank_prefix: int,
+        ip: str,
+        port: int,
+        world_size: int,
+        train_world_size: int,
     ) -> None:
         self.llm.collective_rpc(
             "init_collective",
@@ -486,6 +491,7 @@ def init_collective(
                 ip,
                 port,
                 world_size,
+                train_world_size,
             ),
         )
 
diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py
index c145cf074d..2e7bc9d082 100644
--- a/nemo_rl/models/generation/vllm/vllm_worker_async.py
+++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py
@@ -393,7 +393,12 @@ def _setup_vllm_server(self) -> "tuple[threading.Thread, str, uvicorn.Server]":
         return thread, base_url, server
 
     async def init_collective_async(
-        self, rank_prefix: int, ip: str, port: int, world_size: int
+        self,
+        rank_prefix: int,
+        ip: str,
+        port: int,
+        world_size: int,
+        train_world_size: int,
     ) -> None:
         await self.llm.collective_rpc(
             "init_collective",
@@ -402,6 +407,7 @@ async def init_collective_async(
                 ip,
                 port,
                 world_size,
+                train_world_size,
             ),
         )
 
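Note: the new train_world_size argument is threaded from VllmGeneration.init_collective through the worker's collective_rpc into VllmInternalWorkerExtension, where each engine offsets its local ranks. A toy sketch of the resulting global rank layout, assuming (as the offset arithmetic above implies) that each engine's rank_prefix counts the inference GPUs placed before it; the sizes are hypothetical:

# Toy sketch of inference global ranks under the new scheme.
train_world_size = 16    # e.g. two 8-GPU training nodes occupy ranks 0..15
gpus_per_engine = 4      # tensor_parallel * pipeline_parallel
ranks = []
for engine_idx in range(2):
    rank_prefix = engine_idx * gpus_per_engine   # assumed prefix convention
    for local_rank in range(gpus_per_engine):
        ranks.append(train_world_size + rank_prefix + local_rank)
print(ranks)  # [16, 17, 18, 19, 20, 21, 22, 23]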
diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index 1a8ef6547c..cfe524be8d 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -501,17 +501,18 @@ def train_context(cp_context: Optional[Generator[None, None, None]] = None):
             yield
 
-    def init_collective(self, ip: str, port: int, world_size: int) -> None:
+    def init_collective(
+        self, ip: str, port: int, world_size: int, *, train_world_size: int
+    ) -> None:
         """Initialize the collective communication."""
         from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
         from vllm.distributed.utils import StatelessProcessGroup
 
-        if self.rank == 0:
-            pg = StatelessProcessGroup.create(
-                host=ip, port=port, rank=0, world_size=world_size
-            )
-            device = torch.cuda.current_device()
-            self.model_update_group = PyNcclCommunicator(pg, device=device)
+        pg = StatelessProcessGroup.create(
+            host=ip, port=port, rank=self.rank, world_size=world_size
+        )
+        device = torch.cuda.current_device()
+        self.model_update_group = PyNcclCommunicator(pg, device=device)
 
     def is_alive(self) -> bool:
         return True
@@ -1808,9 +1809,8 @@ def broadcast_weights_for_collective(self) -> None:
         for _, tensor in self.model.state_dict().items():
             if isinstance(tensor, DTensor):
                 tensor = tensor.full_tensor()
-            if self.rank == 0:
-                tensor = tensor.to(self.dtype, non_blocking=True)
-                self.model_update_group.broadcast(tensor.data, src=0)
+            tensor = tensor.to(self.dtype, non_blocking=True)
+            self.model_update_group.broadcast(tensor.data, src=0)
 
         # Manually move model to cpu for cpu offload case
         # cpu offload needs model on CPU before model forward
diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py
index 6d623af8a7..ee5cea0b5d 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py
@@ -459,17 +459,17 @@ def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
         logits.div_(self.cfg["generation"]["temperature"])
         return logits
 
-    def init_collective(self, ip: str, port: int, world_size: int) -> None:
-        """Initialize the collective communication."""
+    def init_collective(
+        self, ip: str, port: int, world_size: int, *, train_world_size: int
+    ) -> None:
         from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
         from vllm.distributed.utils import StatelessProcessGroup
 
-        if self.rank == 0:
-            pg = StatelessProcessGroup.create(
-                host=ip, port=port, rank=0, world_size=world_size
-            )
-            device = torch.cuda.current_device()
-            self.model_update_group = PyNcclCommunicator(pg, device=device)
+        pg = StatelessProcessGroup.create(
+            host=ip, port=port, rank=self.rank, world_size=world_size
+        )
+        device = torch.cuda.current_device()
+        self.model_update_group = PyNcclCommunicator(pg, device=device)
 
     def is_alive(self) -> bool:
         return True
@@ -1770,9 +1770,8 @@ def broadcast_weights_for_collective(self) -> None:
         for _, tensor in self.model.state_dict().items():
             if isinstance(tensor, DTensor):
                 tensor = tensor.full_tensor()
-            if self.rank == 0:
-                tensor = tensor.to(self.dtype, non_blocking=True)
-                self.model_update_group.broadcast(tensor.data, src=0)
+            tensor = tensor.to(self.dtype, non_blocking=True)
+            self.model_update_group.broadcast(tensor.data, src=0)
 
         # Manually move model to cpu for cpu offload case
         # cpu offload needs model on CPU before model forward
diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py
index 3980805bee..fef56f0ea2 100644
--- a/nemo_rl/models/policy/interfaces.py
+++ b/nemo_rl/models/policy/interfaces.py
@@ -141,7 +141,7 @@ def shutdown(self) -> bool:
 
 class ColocatablePolicyInterface(PolicyInterface):
     @abstractmethod
     def init_collective(
-        self, ip: str, port: int, world_size: int
+        self, ip: str, port: int, world_size: int, *, train_world_size: int
     ) -> list[ray.ObjectRef]:
         pass
diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py
index 9ee977458a..5d08003ad9 100644
--- a/nemo_rl/models/policy/lm_policy.py
+++ b/nemo_rl/models/policy/lm_policy.py
@@ -234,11 +234,15 @@ def __init__(
         self.cfg = config
 
     def init_collective(
-        self, ip: str, port: int, world_size: int
+        self, ip: str, port: int, world_size: int, *, train_world_size: int
     ) -> list[ray.ObjectRef]:
         """Initialize the collective communication."""
         futures = self.worker_group.run_all_workers_single_data(
-            "init_collective", ip=ip, port=port, world_size=world_size
+            "init_collective",
+            ip=ip,
+            port=port,
+            world_size=world_size,
+            train_world_size=train_world_size,
         )
         # this function should co-work with vllm, so we should wait for all futures to complete outside
         return futures
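Note: the policy-worker hunks above and below drop the "if self.rank == 0:" guards because a broadcast is a collective operation: every member of the group must call it, with non-src ranks receiving into their own buffers. The same contract can be demonstrated with plain torch.distributed on CPU (a standalone analogy, not the patch's PyNcclCommunicator):

# Standalone analogy using the CPU "gloo" backend: all ranks call
# broadcast; rank 0 supplies the payload, the others receive it.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29531"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    tensor = torch.arange(4.0) if rank == 0 else torch.zeros(4)
    dist.broadcast(tensor, src=0)  # collective: every rank participates
    assert torch.equal(tensor, torch.arange(4.0))
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)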
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 620d02a41b..326ae9fe61 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -820,17 +820,20 @@ def __init__(
         ## used for streaming update inference engine weights
         self._held_gather_buffer = None
 
-    def init_collective(self, ip: str, port: int, world_size: int) -> None:
+    def init_collective(
+        self, ip: str, port: int, world_size: int, *, train_world_size: int
+    ) -> None:
         """Initialize the collective communication."""
         from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
         from vllm.distributed.utils import StatelessProcessGroup
 
-        if self.rank == 0:
-            pg = StatelessProcessGroup.create(
-                host=ip, port=port, rank=0, world_size=world_size
-            )
-            device = torch.cuda.current_device()
-            self.model_update_group = PyNcclCommunicator(pg, device=device)
+        # world_size = train_world_size + inference_world_size;
+        # train_world_size is consumed on the inference side to offset vLLM ranks
+        pg = StatelessProcessGroup.create(
+            host=ip, port=port, rank=self.rank, world_size=world_size
+        )
+        device = torch.cuda.current_device()
+        self.model_update_group = PyNcclCommunicator(pg, device=device)
 
     def is_alive(self):
         return True
@@ -1735,10 +1738,9 @@ def broadcast_weights_for_collective(self) -> None:
             [self.model],
             show_progress=False,
         )
-        # broadcast from train rank0 worker to inference workers
+        # broadcast from train rank 0 to all other ranks (training and inference)
         for _, tensor in hf_params_generator:
-            if self.rank == 0:
-                self.model_update_group.broadcast(tensor, src=0)
+            self.model_update_group.broadcast(tensor, src=0)
 
     def prepare_for_lp_inference(self):
         self.model = self.move_model(self.model, "cuda", move_grads=False)
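Note: one consequence of broadcasting over the whole group, rather than from rank 0 to the inference workers only, is that non-src training ranks now also receive a copy of every tensor. A back-of-envelope estimate with hypothetical sizes:

# Hypothetical sizes: total bytes received per refit when train rank 0
# broadcasts bf16 weights to the rest of the group.
param_count = 8e9                              # an 8B-parameter model
payload = param_count * 2                      # bf16 = 2 bytes per param
train_world_size, inference_world_size = 8, 8
receivers = (train_world_size - 1) + inference_world_size
print(f"{payload * receivers / 1e9:.0f} GB")   # 240 GB across all receivers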
diff --git a/tests/unit/algorithms/test_distillation.py b/tests/unit/algorithms/test_distillation.py
index a5cbd324e5..d51e9c1eed 100644
--- a/tests/unit/algorithms/test_distillation.py
+++ b/tests/unit/algorithms/test_distillation.py
@@ -415,6 +415,115 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_single_node():
         setup(master_config, tokenizer, dataset, None)
 
 
+def test_distillation_setup_non_colocated_smoke(monkeypatch):
+    """Smoke test: calling setup with a non-colocated config should succeed."""
+    from unittest.mock import MagicMock, patch
+
+    import nemo_rl.algorithms.distillation as distil_mod
+
+    # Two-node cluster: training on one node, inference on the other
+    master_config = {
+        "policy": {
+            "generation": {
+                "backend": "vllm",
+                "colocated": {
+                    "enabled": False,
+                    "resources": {
+                        "gpus_per_node": 8,  # inference on 8 GPUs
+                        "num_nodes": 1,
+                    },
+                },
+            },
+            "dtensor_cfg": {
+                "enabled": False,
+            },
+            "model_name": "test-policy",
+        },
+        "teacher": {
+            "model_name": "test-teacher",
+            "dtensor_cfg": {
+                "enabled": False,
+            },
+        },
+        "loss_fn": {
+            "kl_type": "forward",
+            "mixed_kl_weight": 0.5,
+            "zero_outside_topk": False,
+        },
+        "distillation": {
+            "seed": 42,
+            "topk_logits_k": 64,
+            "num_prompts_per_step": 1,
+            "val_period": 0,
+            "val_at_start": False,
+        },
+        "data": {"shuffle": False},
+        "logger": {},
+        "checkpointing": {},
+        "cluster": {"num_nodes": 2, "gpus_per_node": 8},
+    }
+
+    tokenizer = MagicMock()
+    dataset = MagicMock()
+    dataset.__len__ = MagicMock(return_value=1)
+
+    # Skip tokenizer/vocab equality check inside setup
+    monkeypatch.setenv("NRL_SKIP_DISTILLATION_TOKENIZER_CHECK", "1")
+
+    ip_port = ("127.0.0.1", 12345)
+
+    class DummyCluster:
+        def __init__(self, *args, **kwargs):
+            pass
+
+        def world_size(self):
+            return 1
+
+        def get_master_address_and_port(self):
+            return ip_port
+
+    class DummyPolicy:
+        def __init__(self, *args, **kwargs):
+            pass
+
+        def prepare_refit_info(self):
+            return {}
+
+        def init_collective(self, *args, **kwargs):
+            return [MagicMock()]
+
+    class DummyVllmGeneration:
+        def __init__(self, *args, **kwargs):
+            pass
+
+        def finish_generation(self):
+            return None
+
+        def prepare_refit_info(self, *args, **kwargs):
+            return None
+
+        def init_collective(self, *args, **kwargs):
+            return [MagicMock()]
+
+    with (
+        patch.object(distil_mod, "RayVirtualCluster", DummyCluster),
+        patch.object(distil_mod, "Logger"),
+        patch.object(distil_mod, "CheckpointManager") as mock_ckpt_mgr,
+        patch.object(distil_mod, "StatefulDataLoader"),
+        patch.object(distil_mod, "Policy", DummyPolicy),
+        patch.object(distil_mod, "VllmGeneration", DummyVllmGeneration),
+        patch.object(distil_mod, "ray") as mock_ray,
+    ):
+        mock_ckpt_mgr.return_value.get_latest_checkpoint_path.return_value = None
+        mock_ray.get = MagicMock(return_value=None)
+
+        # Should not raise
+        result = distil_mod.setup(master_config, tokenizer, dataset, None)
+
+        # Basic shape check of returned tuple
+        assert isinstance(result, tuple)
+
+
 def test_noncolocated_inference_requires_explicit_gpus_per_node_multi_node():
     """Test that non-colocated inference requires explicit gpus_per_node when cluster.num_nodes>1."""
     from unittest.mock import MagicMock, patch
diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py
index 140efca56b..6690cbea2a 100644
--- a/tests/unit/models/generation/test_vllm_generation.py
+++ b/tests/unit/models/generation/test_vllm_generation.py
@@ -944,8 +944,15 @@ async def test_vllm_generation_with_hf_training_non_colocated(
     # Refit
     # initialize collective communication for update weights
     ip, port = policy_cluster_separate.get_master_address_and_port()
-    futures_train = lm_policy.init_collective(ip, port, world_size=2)
-    futures_inference = vllm_policy.init_collective(ip, port, world_size=2)
+    train_world_size = policy_cluster_separate.world_size()
+    inference_world_size = generation_cluster_separate.world_size()
+    world_size = train_world_size + inference_world_size
+    futures_train = lm_policy.init_collective(
+        ip, port, world_size=world_size, train_world_size=train_world_size
+    )
+    futures_inference = vllm_policy.init_collective(
+        ip, port, world_size=world_size, train_world_size=train_world_size
+    )
     ray.get(futures_train + futures_inference)
 
     # prepare refit info
@@ -1747,9 +1754,15 @@ async def test_vllm_refit_non_colocated_update_weights(
 
     # initialize collective communication for update weights
     ip, port = policy_cluster_separate.get_master_address_and_port()
-    world_size = tensor_parallel_size + 1
-    futures_train = lm_policy.init_collective(ip, port, world_size=world_size)
-    futures_inference = vllm_generation.init_collective(ip, port, world_size=world_size)
+    train_world_size = policy_cluster_separate.world_size()
+    inference_world_size = generation_cluster_separate.world_size()
+    world_size = train_world_size + inference_world_size
+    futures_train = lm_policy.init_collective(
+        ip, port, world_size=world_size, train_world_size=train_world_size
+    )
+    futures_inference = vllm_generation.init_collective(
+        ip, port, world_size=world_size, train_world_size=train_world_size
+    )
     ray.get(futures_train + futures_inference)
 
     # prepare refit info
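Note: the tests above now size the group from the two clusters instead of hard-coding it. One invariant worth asserting (a sketch, not part of the patch) is that the combined scheme tiles ranks 0..world_size-1 with no gaps or collisions:

# Sketch: training ranks plus engine-offset inference ranks are
# unique and contiguous.
def assign_ranks(train_world_size: int, engine_sizes: list[int]) -> list[int]:
    ranks = list(range(train_world_size))   # training workers keep their ranks
    prefix = 0
    for size in engine_sizes:               # one entry per vLLM engine
        ranks += [train_world_size + prefix + lr for lr in range(size)]
        prefix += size
    return ranks

assert assign_ranks(8, [4, 4]) == list(range(16))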