Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions nemo_rl/algorithms/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,16 @@ def setup(
if not colocated_inference:
ip, port = train_cluster.get_master_address_and_port()
print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
train_world_size = train_cluster.world_size()
# inference cluster + head node of the train cluster
world_size = inference_nodes * inference_gpus_per_node + 1
world_size = train_world_size + inference_nodes * inference_gpus_per_node
# init collective
futures_train = student_policy.init_collective(ip, port, world_size)
futures_inference = student_generation.init_collective(ip, port, world_size) # type: ignore
futures_train = student_policy.init_collective(
ip, port, world_size, train_world_size=train_world_size
)
futures_inference = student_generation.init_collective(
ip, port, world_size, train_world_size=train_world_size
) # type: ignore
# wait for all futures to complete
ray.get(futures_train + futures_inference)

Expand Down
14 changes: 10 additions & 4 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,11 +435,17 @@ def setup(
if not colocated_inference:
ip, port = train_cluster.get_master_address_and_port()
print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
# inference cluster + head node of the train cluster
world_size = inference_nodes * inference_gpus_per_node + 1
# world includes all training workers and all inference workers
train_world_size = train_cluster.world_size()
inference_world_size = inference_nodes * inference_gpus_per_node
world_size = train_world_size + inference_world_size
# init collective
futures_train = policy.init_collective(ip, port, world_size)
futures_inference = policy_generation.init_collective(ip, port, world_size) # type: ignore
futures_train = policy.init_collective(
ip, port, world_size, train_world_size=train_world_size
)
futures_inference = policy_generation.init_collective(
ip, port, world_size, train_world_size=train_world_size
) # type: ignore
# wait for all futures to complete
ray.get(futures_train + futures_inference)

Expand Down
10 changes: 8 additions & 2 deletions nemo_rl/models/generation/vllm/vllm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,20 @@

class VllmInternalWorkerExtension:
def init_collective(
self, rank_prefix: int, ip: str, port: int, world_size: int
self,
rank_prefix: int,
ip: str,
port: int,
world_size: int,
train_world_size: int,
) -> None:
"""Initialize the collective communication."""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup

local_rank = torch.distributed.get_rank()
rank = rank_prefix + local_rank + 1 # 1 is the head node of the train cluster
# Place vLLM ranks after all training ranks so all training workers can join
rank = train_world_size + rank_prefix + local_rank

pg = StatelessProcessGroup.create(
host=ip, port=port, rank=rank, world_size=world_size
Expand Down
9 changes: 7 additions & 2 deletions nemo_rl/models/generation/vllm/vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def _post_init(self):
return results

def init_collective(
self, ip: str, port: int, world_size: int
self, ip: str, port: int, world_size: int, *, train_world_size: int
) -> list[ray.ObjectRef]:
"""Initialize the collective communication."""
if not self.worker_group or not self.worker_group.workers:
Expand All @@ -395,7 +395,12 @@ def init_collective(
method_name,
rank_prefix=rank_prefix_list,
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"],
common_kwargs={"ip": ip, "port": port, "world_size": world_size},
common_kwargs={
"ip": ip,
"port": port,
"world_size": world_size,
"train_world_size": train_world_size,
},
)

# this function should co-work with lm_policy, so we should wait for all futures to complete outside
Expand Down
8 changes: 7 additions & 1 deletion nemo_rl/models/generation/vllm/vllm_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,12 @@ def post_init(self):
self.vllm_device_ids = self.report_device_id()

def init_collective(
self, rank_prefix: int, ip: str, port: int, world_size: int
self,
rank_prefix: int,
ip: str,
port: int,
world_size: int,
train_world_size: int,
) -> None:
self.llm.collective_rpc(
"init_collective",
Expand All @@ -486,6 +491,7 @@ def init_collective(
ip,
port,
world_size,
train_world_size,
),
)

Expand Down
8 changes: 7 additions & 1 deletion nemo_rl/models/generation/vllm/vllm_worker_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,12 @@ def _setup_vllm_server(self) -> "tuple[threading.Thread, str, uvicorn.Server]":
return thread, base_url, server

async def init_collective_async(
self, rank_prefix: int, ip: str, port: int, world_size: int
self,
rank_prefix: int,
ip: str,
port: int,
world_size: int,
train_world_size: int,
) -> None:
await self.llm.collective_rpc(
"init_collective",
Expand All @@ -402,6 +407,7 @@ async def init_collective_async(
ip,
port,
world_size,
train_world_size,
),
)

Expand Down
20 changes: 10 additions & 10 deletions nemo_rl/models/policy/dtensor_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,17 +501,18 @@ def train_context(cp_context: Optional[Generator[None, None, None]] = None):

yield

def init_collective(self, ip: str, port: int, world_size: int) -> None:
def init_collective(
self, ip: str, port: int, world_size: int, *, train_world_size: int
) -> None:
"""Initialize the collective communication."""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup

if self.rank == 0:
pg = StatelessProcessGroup.create(
host=ip, port=port, rank=0, world_size=world_size
)
device = torch.cuda.current_device()
self.model_update_group = PyNcclCommunicator(pg, device=device)
pg = StatelessProcessGroup.create(
host=ip, port=port, rank=self.rank, world_size=world_size
)
device = torch.cuda.current_device()
self.model_update_group = PyNcclCommunicator(pg, device=device)

def is_alive(self) -> bool:
return True
Expand Down Expand Up @@ -1808,9 +1809,8 @@ def broadcast_weights_for_collective(self) -> None:
for _, tensor in self.model.state_dict().items():
if isinstance(tensor, DTensor):
tensor = tensor.full_tensor()
if self.rank == 0:
tensor = tensor.to(self.dtype, non_blocking=True)
self.model_update_group.broadcast(tensor.data, src=0)
tensor = tensor.to(self.dtype, non_blocking=True)
self.model_update_group.broadcast(tensor.data, src=0)

# Manually move model to cpu for cpu offload case
# cpu offload needs model on CPU before model forward
Expand Down
21 changes: 10 additions & 11 deletions nemo_rl/models/policy/dtensor_policy_worker_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,17 +459,17 @@ def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor:
logits.div_(self.cfg["generation"]["temperature"])
return logits

def init_collective(self, ip: str, port: int, world_size: int) -> None:
"""Initialize the collective communication."""
def init_collective(
self, ip: str, port: int, world_size: int, *, train_world_size: int
) -> None:
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup

if self.rank == 0:
pg = StatelessProcessGroup.create(
host=ip, port=port, rank=0, world_size=world_size
)
device = torch.cuda.current_device()
self.model_update_group = PyNcclCommunicator(pg, device=device)
pg = StatelessProcessGroup.create(
host=ip, port=port, rank=self.rank, world_size=world_size
)
device = torch.cuda.current_device()
self.model_update_group = PyNcclCommunicator(pg, device=device)

def is_alive(self) -> bool:
return True
Expand Down Expand Up @@ -1770,9 +1770,8 @@ def broadcast_weights_for_collective(self) -> None:
for _, tensor in self.model.state_dict().items():
if isinstance(tensor, DTensor):
tensor = tensor.full_tensor()
if self.rank == 0:
tensor = tensor.to(self.dtype, non_blocking=True)
self.model_update_group.broadcast(tensor.data, src=0)
tensor = tensor.to(self.dtype, non_blocking=True)
self.model_update_group.broadcast(tensor.data, src=0)

# Manually move model to cpu for cpu offload case
# cpu offload needs model on CPU before model forward
Expand Down
2 changes: 1 addition & 1 deletion nemo_rl/models/policy/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def shutdown(self) -> bool:
class ColocatablePolicyInterface(PolicyInterface):
@abstractmethod
def init_collective(
self, ip: str, port: int, world_size: int
self, ip: str, port: int, world_size: int, *, train_world_size: int
) -> list[ray.ObjectRef]:
pass

Expand Down
8 changes: 6 additions & 2 deletions nemo_rl/models/policy/lm_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,15 @@ def __init__(
self.cfg = config

def init_collective(
self, ip: str, port: int, world_size: int
self, ip: str, port: int, world_size: int, *, train_world_size: int
) -> list[ray.ObjectRef]:
"""Initialize the collective communication."""
futures = self.worker_group.run_all_workers_single_data(
"init_collective", ip=ip, port=port, world_size=world_size
"init_collective",
ip=ip,
port=port,
world_size=world_size,
train_world_size=train_world_size,
)
# this function should co-work with vllm, so we should wait for all futures to complete outside
return futures
Expand Down
22 changes: 12 additions & 10 deletions nemo_rl/models/policy/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,17 +820,20 @@ def __init__(
## used for streaming update inference engine weights
self._held_gather_buffer = None

def init_collective(self, ip: str, port: int, world_size: int) -> None:
def init_collective(
self, ip: str, port: int, world_size: int, *, train_world_size: int
) -> None:
"""Initialize the collective communication."""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup

if self.rank == 0:
pg = StatelessProcessGroup.create(
host=ip, port=port, rank=0, world_size=world_size
)
device = torch.cuda.current_device()
self.model_update_group = PyNcclCommunicator(pg, device=device)
# world_size = train_world_size + inference_world_size
# variable train_world_size is used in inference cluster
pg = StatelessProcessGroup.create(
host=ip, port=port, rank=self.rank, world_size=world_size
)
device = torch.cuda.current_device()
self.model_update_group = PyNcclCommunicator(pg, device=device)

def is_alive(self):
return True
Expand Down Expand Up @@ -1735,10 +1738,9 @@ def broadcast_weights_for_collective(self) -> None:
[self.model],
show_progress=False,
)
# broadcast from train rank0 worker to inference workers
# broadcast from train rank 0 to all other ranks (training and inference)
for _, tensor in hf_params_generator:
if self.rank == 0:
self.model_update_group.broadcast(tensor, src=0)
self.model_update_group.broadcast(tensor, src=0)

def prepare_for_lp_inference(self):
self.model = self.move_model(self.model, "cuda", move_grads=False)
Expand Down
109 changes: 109 additions & 0 deletions tests/unit/algorithms/test_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,115 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_single_node():
setup(master_config, tokenizer, dataset, None)


def test_distillation_setup_non_colocated_smoke(monkeypatch):
    """Smoke test: calling setup with a non-colocated config should succeed.

    Patches out all heavyweight collaborators (Ray cluster, policy/teacher
    workers, vLLM generation, logger, checkpointing, dataloader) so that
    ``distillation.setup`` can run its wiring logic — including the
    non-colocated collective-init path — without GPUs or Ray.
    """
    from unittest.mock import MagicMock, patch

    import nemo_rl.algorithms.distillation as distil_mod

    # Two-node cluster config: inference is non-colocated and gets its own
    # 8-GPU node (colocated.enabled=False, resources: 1 node x 8 GPUs),
    # while training uses the remaining cluster resources.
    master_config = {
        "policy": {
            "generation": {
                "backend": "vllm",
                "colocated": {
                    "enabled": False,
                    "resources": {
                        "gpus_per_node": 8,  # inference uses 8 GPUs on its node
                        "num_nodes": 1,
                    },
                },
            },
            "dtensor_cfg": {
                "enabled": False,
            },
            "model_name": "test-policy",
        },
        "teacher": {
            "model_name": "test-teacher",
            "dtensor_cfg": {
                "enabled": False,
            },
        },
        "loss_fn": {
            "kl_type": "forward",
            "mixed_kl_weight": 0.5,
            "zero_outside_topk": False,
        },
        "distillation": {
            "seed": 42,
            "topk_logits_k": 64,
            "num_prompts_per_step": 1,
            "val_period": 0,
            "val_at_start": False,
        },
        "data": {"shuffle": False},
        "logger": {},
        "checkpointing": {},
        "cluster": {"num_nodes": 2, "gpus_per_node": 8},
    }

    tokenizer = MagicMock()
    dataset = MagicMock()
    # setup() sizes the dataloader from len(dataset); one sample is enough.
    dataset.__len__ = MagicMock(return_value=1)

    # Skip tokenizer/vocab equality check inside setup
    monkeypatch.setenv("NRL_SKIP_DISTILLATION_TOKENIZER_CHECK", "1")

    ip_port = ("127.0.0.1", 12345)

    # Minimal stand-in for RayVirtualCluster: only the two methods setup()
    # calls on the non-colocated path.
    class DummyCluster:
        def __init__(self, *args, **kwargs):
            pass

        def world_size(self):
            return 1

        def get_master_address_and_port(self):
            return ip_port

    # Minimal stand-in for Policy (student/teacher training workers).
    class DummyPolicy:
        def __init__(self, *args, **kwargs):
            pass

        def prepare_refit_info(self):
            return {}

        def init_collective(self, *args, **kwargs):
            # setup() passes these futures to ray.get (mocked below).
            return [MagicMock()]

    # Minimal stand-in for VllmGeneration (inference workers).
    class DummyVllmGeneration:
        def __init__(self, *args, **kwargs):
            pass

        def finish_generation(self):
            return None

        def prepare_refit_info(self, *args, **kwargs):
            return None

        def init_collective(self, *args, **kwargs):
            return [MagicMock()]

    with (
        patch.object(distil_mod, "RayVirtualCluster", DummyCluster),
        patch.object(distil_mod, "Logger"),
        patch.object(distil_mod, "CheckpointManager") as mock_ckpt_mgr,
        patch.object(distil_mod, "StatefulDataLoader"),
        patch.object(distil_mod, "Policy", DummyPolicy),
        patch.object(distil_mod, "VllmGeneration", DummyVllmGeneration),
        patch.object(distil_mod, "ray") as mock_ray,
    ):
        # Fresh run: no checkpoint to resume from.
        mock_ckpt_mgr.return_value.get_latest_checkpoint_path.return_value = None
        # ray.get on the init_collective futures becomes a no-op.
        mock_ray.get = MagicMock(return_value=None)

        # Should not raise
        result = distil_mod.setup(master_config, tokenizer, dataset, None)

        # Basic shape check of returned tuple
        assert isinstance(result, tuple)


def test_noncolocated_inference_requires_explicit_gpus_per_node_multi_node():
"""Test that non-colocated inference requires explicit gpus_per_node when cluster.num_nodes>1."""
from unittest.mock import MagicMock, patch
Expand Down
Loading
Loading