Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
9c04fe4
Use set_model_state_dict and load model on rank 0
parthchadha Jun 18, 2025
0887483
Merge remote-tracking branch 'origin/main' into pchadha/large-model-s…
parthchadha Jun 24, 2025
fcec2db
Fix use of model_config and remove duplicate args
parthchadha Jun 24, 2025
9b7bd0f
Merge remote-tracking branch 'origin/main' into pchadha/large-model-s…
parthchadha Jun 24, 2025
7235370
Disable nccl shm to fix https://github.com/NVIDIA-NeMo/RL/issues/564
parthchadha Jun 26, 2025
dadea4a
Merge remote-tracking branch 'origin/main' into pchadha/large-model-s…
parthchadha Jun 26, 2025
7871f65
Merge remote-tracking branch 'origin/main' into pchadha/large-model-s…
parthchadha Jun 26, 2025
8c0e061
Manually broadcast buffers; disable nccl shm conditionally
parthchadha Jun 27, 2025
b38c2ae
Merge remote-tracking branch 'origin/main' into pchadha/large-model-s…
parthchadha Jun 27, 2025
d81d27c
Merge remote-tracking branch 'origin/main' into pchadha/large-model-s…
parthchadha Jun 30, 2025
1bef005
Fix incorrect use of broadcast
parthchadha Jun 30, 2025
c35e5aa
Add nccl p2p disable in env variable to fix non-colocated failing tests
parthchadha Jun 30, 2025
d680c77
Add missing colocated field in unit test config
parthchadha Jun 30, 2025
26782fe
Use move to device for cpu offload instead of just moving buffers
parthchadha Jul 1, 2025
0f48e39
Merge remote-tracking branch 'origin/main' into pchadha/large-model-s…
parthchadha Jul 1, 2025
381707b
Check if generation exists in config before accessing it
parthchadha Jul 2, 2025
8e89bfb
Merge remote-tracking branch 'origin/main' into pchadha/large-model-s…
parthchadha Jul 2, 2025
140876e
Update eval.yaml with colocated
parthchadha Jul 2, 2025
b226eef
Merge branch 'main' into pchadha/large-model-state-dict-load
parthchadha Jul 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions examples/configs/eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ generation:
pipeline_parallel_size: 1
gpu_memory_utilization: 0.9
max_model_len: 2048
colocated:
# true: generation shares training GPUs
# false: uses dedicated generation resources
enabled: true
# only relevant when enabled is false
resources:
gpus_per_node: null # Number of GPUs per node dedicated to generation when the cluster has a single node (i.e. cluster.num_nodes == 1)
num_nodes: null # Number of nodes dedicated to generation


tokenizer:
name: ${generation.model_name} ## specify if you'd like to use a tokenizer different from the model's default
Expand Down
4 changes: 4 additions & 0 deletions nemo_rl/models/generation/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@ def _patch_vllm_init_workers_ray():
os.environ["VLLM_USE_V1"] = os.environ.get("NRL_VLLM_USE_V1", "1")
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

if not self.cfg["colocated"]["enabled"]:
os.environ["NCCL_SHM_DISABLE"] = "1"
os.environ["NCCL_P2P_DISABLE"] = "1"

load_format = self.cfg["vllm_cfg"]["load_format"]
if ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(self.model_name):
load_format = "auto"
Expand Down
61 changes: 55 additions & 6 deletions nemo_rl/models/policy/dtensor_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@

import ray
import torch
from accelerate import init_empty_weights
from torch import nn
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
set_model_state_dict,
)
from torch.distributed.fsdp import (
FSDPModule,
)
Expand All @@ -30,7 +35,7 @@
from torch.distributed.tensor.experimental._attention import (
set_rotate_method,
)
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.integrations.accelerate import find_tied_parameters
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM

Expand Down Expand Up @@ -137,6 +142,15 @@ def __init__(
init_reference_model: bool = True,
**kwargs: Any,
):
# Disable NCCL SHM if training and generation are not co-located: https://github.com/NVIDIA-NeMo/RL/issues/564
if (
"generation" in config
and config["generation"] is not None
and not config["generation"]["colocated"]["enabled"]
):
os.environ["NCCL_SHM_DISABLE"] = "1"
os.environ["NCCL_P2P_DISABLE"] = "1"

self.cfg = config
# torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call
torch.distributed.init_process_group(backend="nccl")
Expand All @@ -156,19 +170,38 @@ def __init__(
else:
raise ValueError(f"Unknown precision: {self.cfg['precision']}")

print(f"[Rank {self.rank}] Loading model {model_name} on CPU...")
self.model = AutoModelForCausalLM.from_pretrained(
model_config = AutoConfig.from_pretrained(
model_name,
device_map="cpu", # load weights onto CPU initially
# Always load the model in float32 to keep master weights in float32.
# Keeping the master weights in lower precision has shown to cause issues with convergence.
# https://github.com/NVIDIA-NeMo/RL/issues/279 will fix the issue of CPU OOM for larger models.
torch_dtype=torch.float32,
trust_remote_code=True,
**sliding_window_overwrite(
model_name
), # due to https://github.com/huggingface/transformers/issues/38002
)

full_state_dict = None
if self.rank == 0:
print(f"[Rank {self.rank}] Loading model {model_name} on CPU...")
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="cpu", # load weights onto CPU initially
trust_remote_code=True,
config=model_config,
)
full_state_dict = model.state_dict()
del model

print(f"[Rank {self.rank}] Initializing empty model for FSDP...")
# All ranks initialize model on meta device, so FSDP can shard it.
# The actual weights will be broadcast from rank 0.

with init_empty_weights():
self.model = AutoModelForCausalLM.from_config(
model_config,
)

# caching since this property is not always preserved after FSDP
self.num_tied_weights = len(find_tied_parameters(self.model))
self.skip_tie_check = os.environ.get(
Expand Down Expand Up @@ -222,8 +255,24 @@ def __init__(
custom_parallel_plan=self.cfg["dtensor_cfg"]["custom_parallel_plan"],
)

print(f"[Rank {self.rank}] Loading state dict from rank 0...")
# This will broadcast the state dict from rank 0 to all other ranks
# and load it into the FSDP model.
set_model_state_dict(
self.model,
model_state_dict=full_state_dict,
options=StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
),
)

# Manually broadcast buffers
for _, buf in self.model.named_buffers():
torch.distributed.broadcast(buf, src=0)

if self.cpu_offload:
self.model = self.move_buffer_to_device(self.model, "cpu")
self.model = self.move_to_device(self.model, "cpu")

# used for streaming update inference engine weights
self._held_sharded_state_dict_reference: Optional[dict[str, torch.Tensor]] = (
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/models/policy/test_dtensor_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def create_test_config(
"top_k": None,
"stop_token_ids": None,
"stop_strings": None,
"colocated": {
"enabled": True,
"resources": {
"gpus_per_node": None,
"num_nodes": None,
},
},
},
"dtensor_cfg": {
"enabled": True,
Expand Down
Loading