From 9c04fe4b262e9b8355abecb4980e1c091fb7ea81 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Wed, 18 Jun 2025 23:15:25 +0000 Subject: [PATCH 01/10] Use set_model_state_dict and load model on rank 0 Signed-off-by: Parth Chadha --- .../models/policy/dtensor_policy_worker.py | 60 +++++++++++++++---- 1 file changed, 49 insertions(+), 11 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index ec3f7afc0c..40748ac74f 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -30,7 +30,7 @@ from torch.distributed.tensor.experimental._attention import ( set_rotate_method, ) -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from transformers.integrations.accelerate import find_tied_parameters from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM @@ -58,6 +58,10 @@ load_checkpoint, save_checkpoint, ) +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + set_model_state_dict, +) @contextmanager @@ -151,19 +155,41 @@ def __init__( else: raise ValueError(f"Unknown precision: {self.cfg['precision']}") - print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") - self.model = AutoModelForCausalLM.from_pretrained( + model_config = AutoConfig.from_pretrained( model_name, - device_map="cpu", # load weights onto CPU initially - # Always load the model in float32 to keep master weights in float32. - # Keeping the master weights in lower precision has shown to cause issues with convergence. - # https://github.com/NVIDIA/NeMo-RL/issues/279 will fix the issue of CPU OOM for larger models. - torch_dtype=torch.float32, trust_remote_code=True, - **sliding_window_overwrite( - model_name - ), # due to https://github.com/huggingface/transformers/issues/38002 + **sliding_window_overwrite(model_name), # due to https://github.com/huggingface/transformers/issues/38002 ) + + full_state_dict = None + if self.rank == 0: + print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="cpu", # load weights onto CPU initially + # Always load the model in float32 to keep master weights in float32. + # Keeping the master weights in lower precision has shown to cause issues with convergence. + # https://github.com/NVIDIA/NeMo-RL/issues/279 will fix the issue of CPU OOM for larger models. + torch_dtype=torch.float32, + trust_remote_code=True, + config=model_config, + ) + full_state_dict = model.state_dict() + del model + torch.cuda.empty_cache() + + print(f"[Rank {self.rank}] Initializing empty model for FSDP...") + # All ranks initialize model on meta device, so FSDP can shard it. + # The actual weights will be broadcast from rank 0. + from accelerate import init_empty_weights + + with init_empty_weights(): + self.model = AutoModelForCausalLM.from_config( + model_config, + torch_dtype=self.dtype, + trust_remote_code=True, + ) + # caching since this property is not always preserved after FSDP self.num_tied_weights = len(find_tied_parameters(self.model)) self.skip_tie_check = os.environ.get( @@ -217,6 +243,18 @@ def __init__( custom_parallel_plan=self.cfg["dtensor_cfg"]["custom_parallel_plan"], ) + print(f"[Rank {self.rank}] Loading state dict from rank 0...") + # This will broadcast the state dict from rank 0 to all other ranks + # and load it into the FSDP model. + set_model_state_dict( + self.model, + model_state_dict=full_state_dict, + options=StateDictOptions( + full_state_dict=True, + broadcast_from_rank0=True, + ), + ) + if self.cpu_offload: self.model = self.move_buffer_to_device(self.model, "cpu") From fcec2db906f29bff8d8ce983064746c401b41c0f Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Tue, 24 Jun 2025 17:17:04 +0000 Subject: [PATCH 02/10] Fix use of model_config and remove duplicate args Signed-off-by: Parth Chadha --- .../models/policy/dtensor_policy_worker.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 560891f316..0df1b3439a 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -21,7 +21,12 @@ import ray import torch +from accelerate import init_empty_weights from torch import nn +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + set_model_state_dict, +) from torch.distributed.fsdp import ( FSDPModule, ) @@ -59,10 +64,6 @@ load_checkpoint, save_checkpoint, ) -from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - set_model_state_dict, -) @contextmanager @@ -162,8 +163,13 @@ def __init__( model_config = AutoConfig.from_pretrained( model_name, + # Always load the model in float32 to keep master weights in float32. + # Keeping the master weights in lower precision has shown to cause issues with convergence. + torch_dtype=torch.float32, trust_remote_code=True, - **sliding_window_overwrite(model_name), # due to https://github.com/huggingface/transformers/issues/38002 + **sliding_window_overwrite( + model_name + ), # due to https://github.com/huggingface/transformers/issues/38002 ) full_state_dict = None @@ -172,27 +178,19 @@ def __init__( model = AutoModelForCausalLM.from_pretrained( model_name, device_map="cpu", # load weights onto CPU initially - # Always load the model in float32 to keep master weights in float32. - # Keeping the master weights in lower precision has shown to cause issues with convergence. - # https://github.com/NVIDIA/NeMo-RL/issues/279 will fix the issue of CPU OOM for larger models. - torch_dtype=torch.float32, trust_remote_code=True, config=model_config, ) full_state_dict = model.state_dict() del model - torch.cuda.empty_cache() print(f"[Rank {self.rank}] Initializing empty model for FSDP...") # All ranks initialize model on meta device, so FSDP can shard it. # The actual weights will be broadcast from rank 0. - from accelerate import init_empty_weights with init_empty_weights(): self.model = AutoModelForCausalLM.from_config( model_config, - torch_dtype=self.dtype, - trust_remote_code=True, ) # caching since this property is not always preserved after FSDP From 7235370f3854dd9ae95ccaf09a926e9f45e2d0bb Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Thu, 26 Jun 2025 18:19:27 +0000 Subject: [PATCH 03/10] Disable nccl shm to fix https://github.com/NVIDIA-NeMo/RL/issues/564 Signed-off-by: Parth Chadha --- nemo_rl/models/policy/dtensor_policy_worker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 0df1b3439a..d88248d70c 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -142,6 +142,8 @@ def __init__( init_reference_model: bool = True, **kwargs: Any, ): + # Disable NCCL SHM : https://github.com/NVIDIA-NeMo/RL/issues/564 + os.environ["NCCL_SHM_DISABLE"] = "1" self.cfg = config # torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call torch.distributed.init_process_group(backend="nccl") From 8c0e061472d7cd36a79e6afb67082e31ec082692 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Fri, 27 Jun 2025 20:41:55 +0000 Subject: [PATCH 04/10] Manually broadcast buffers; disable nccl shm conditionally Signed-off-by: Parth Chadha --- nemo_rl/models/policy/dtensor_policy_worker.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 3cd213cf88..5f5e532c53 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -142,8 +142,13 @@ def __init__( init_reference_model: bool = True, **kwargs: Any, ): - # Disable NCCL SHM : https://github.com/NVIDIA-NeMo/RL/issues/564 - os.environ["NCCL_SHM_DISABLE"] = "1" + # Disable NCCL SHM if training and generation are not co-located: https://github.com/NVIDIA-NeMo/RL/issues/564 + if ( + config["generation"] is not None + and not config["generation"]["colocated"]["enabled"] + ): + os.environ["NCCL_SHM_DISABLE"] = "1" + self.cfg = config # torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call torch.distributed.init_process_group(backend="nccl") @@ -260,6 +265,10 @@ def __init__( ), ) + # Manually broadcast buffers + for _, buf in self.model.named_buffers(): + torch.distributed.broadcast(buf, src=0, group=self.dp_mesh.get_group()) + if self.cpu_offload: self.model = self.move_buffer_to_device(self.model, "cpu") From 1bef005fd9c6d2060e5a335120ea547e3ffaadc3 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Mon, 30 Jun 2025 07:27:20 -0700 Subject: [PATCH 05/10] Fix incorrect use of broadcast Signed-off-by: Parth Chadha --- nemo_rl/models/policy/dtensor_policy_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 5f5e532c53..f5b60f5d72 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -267,7 +267,7 @@ def __init__( # Manually broadcast buffers for _, buf in self.model.named_buffers(): - torch.distributed.broadcast(buf, src=0, group=self.dp_mesh.get_group()) + torch.distributed.broadcast(buf, src=0) if self.cpu_offload: self.model = self.move_buffer_to_device(self.model, "cpu") From c35e5aa0e8648b3059127a42a66b7e7b8a6d4643 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Mon, 30 Jun 2025 14:38:56 -0700 Subject: [PATCH 06/10] Add nccl p2p disable in env variable to fix non-colocated failing tests Signed-off-by: Parth Chadha --- nemo_rl/models/generation/vllm.py | 4 ++++ nemo_rl/models/policy/dtensor_policy_worker.py | 1 + 2 files changed, 5 insertions(+) diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index f0cd5eb50b..c138607639 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -316,6 +316,10 @@ def _patch_vllm_init_workers_ray(): os.environ["VLLM_USE_V1"] = os.environ.get("NRL_VLLM_USE_V1", "1") os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + if not self.cfg["colocated"]["enabled"]: + os.environ["NCCL_SHM_DISABLE"] = "1" + os.environ["NCCL_P2P_DISABLE"] = "1" + load_format = self.cfg["vllm_cfg"]["load_format"] if ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(self.model_name): load_format = "auto" diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index f5b60f5d72..35c91334bb 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -148,6 +148,7 @@ def __init__( and not config["generation"]["colocated"]["enabled"] ): os.environ["NCCL_SHM_DISABLE"] = "1" + os.environ["NCCL_P2P_DISABLE"] = "1" self.cfg = config # torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call From d680c77de8ed47dacfcf87e0a1c8bbb9181749ef Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Mon, 30 Jun 2025 16:06:44 -0700 Subject: [PATCH 07/10] Add missing colocated field in unit test config Signed-off-by: Parth Chadha --- tests/unit/models/policy/test_dtensor_worker.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index 0a42ea1e9f..91bf140641 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -61,6 +61,13 @@ def create_test_config( "top_k": None, "stop_token_ids": None, "stop_strings": None, + "colocated": { + "enabled": True, + "resources": { + "gpus_per_node": None, + "num_nodes": None, + }, + }, }, "dtensor_cfg": { "enabled": True, From 26782fee729003e35debf3f6c29d31516776a4b7 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Tue, 1 Jul 2025 09:49:34 -0700 Subject: [PATCH 08/10] Use move to device for cpu offload instead of just moving buffers Signed-off-by: Parth Chadha --- nemo_rl/models/policy/dtensor_policy_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 35c91334bb..87f875f2f2 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -271,7 +271,7 @@ def __init__( torch.distributed.broadcast(buf, src=0) if self.cpu_offload: - self.model = self.move_buffer_to_device(self.model, "cpu") + self.model = self.move_to_device(self.model, "cpu") # used for streaming update inference engine weights self._held_sharded_state_dict_reference: Optional[dict[str, torch.Tensor]] = ( From 381707b11202c29df6060b0cf3f897a63c861a9f Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Tue, 1 Jul 2025 22:44:45 -0700 Subject: [PATCH 09/10] Check if generation exists in config before accessing it Signed-off-by: Parth Chadha --- nemo_rl/models/policy/dtensor_policy_worker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 87f875f2f2..078eacd013 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -144,7 +144,8 @@ def __init__( ): # Disable NCCL SHM if training and generation are not co-located: https://github.com/NVIDIA-NeMo/RL/issues/564 if ( - config["generation"] is not None + "generation" in config + and config["generation"] is not None and not config["generation"]["colocated"]["enabled"] ): os.environ["NCCL_SHM_DISABLE"] = "1" From 140876e7c62bba6b822e488938d82351ab21dc65 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Wed, 2 Jul 2025 07:03:41 -0700 Subject: [PATCH 10/10] Update eval.yaml with colocated Signed-off-by: Parth Chadha --- examples/configs/eval.yaml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/configs/eval.yaml b/examples/configs/eval.yaml index 0308f65ed6..e880d98bc7 100644 --- a/examples/configs/eval.yaml +++ b/examples/configs/eval.yaml @@ -22,6 +22,15 @@ generation: pipeline_parallel_size: 1 gpu_memory_utilization: 0.9 max_model_len: 2048 + colocated: + # true: generation shares training GPUs + # false: uses dedicated generation resources + enabled: true + # only relevant when enabled is false + resources: + gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 + num_nodes: null # Decides number of nodes to be dedicated to generation + tokenizer: name: ${generation.model_name} ## specify if you'd like to use a tokenizer different from the model's default