diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 7d9d0ba7b2..e97fd10522 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -45,6 +45,7 @@ policy: precision: "bfloat16" dtensor_cfg: + _v2: true env_vars: PYTORCH_CUDA_ALLOC_CONF: "" # Refers to https://docs.pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf enabled: true diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml index 72dcb9ad1e..c976deba6a 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2.yaml @@ -32,6 +32,7 @@ policy: max_total_sequence_length: 2048 precision: "bfloat16" dtensor_cfg: + _v2: false enabled: true cpu_offload: False sequence_parallel: false diff --git a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp4.yaml b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp4.yaml index 22851b368c..34c37325f9 100644 --- a/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp4.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.1-8b-instruct-4n8g-fsdp2tp4.yaml @@ -32,6 +32,7 @@ policy: max_total_sequence_length: 8192 precision: "bfloat16" dtensor_cfg: + _v2: false enabled: true cpu_offload: False sequence_parallel: false diff --git a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml index 22870f0e66..ecaaa4c5ee 100644 --- a/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml +++ b/examples/configs/recipes/llm/dpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml @@ -33,6 +33,7 @@ policy: max_total_sequence_length: 1024 precision: "bfloat16" dtensor_cfg: + _v2: false enabled: true cpu_offload: False 
sequence_parallel: false diff --git a/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml b/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml index 86a3a6fc97..41c130fb70 100644 --- a/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml +++ b/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml @@ -42,6 +42,7 @@ policy: precision: "bfloat16" dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: false diff --git a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-16K.yaml b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-16K.yaml index 570fecb1b9..dbbfa09e4d 100644 --- a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-16K.yaml +++ b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-16K.yaml @@ -11,6 +11,7 @@ policy: logprob_batch_size: 2 dtensor_cfg: + _v2: false enabled: true cpu_offload: true sequence_parallel: true diff --git a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-24K.yaml b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-24K.yaml index 3cd8fabd6d..d5664e6fd6 100644 --- a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-24K.yaml +++ b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-24K.yaml @@ -11,6 +11,7 @@ policy: logprob_batch_size: 2 dtensor_cfg: + _v2: false enabled: true cpu_offload: true sequence_parallel: true diff --git a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml index f6cc626890..d8d68be324 100644 --- a/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml @@ -48,6 +48,7 @@ policy: precision: "bfloat16" dtensor_cfg: + _v2: false enabled: true cpu_offload: true sequence_parallel: true diff --git a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml 
b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml index 091cb2909a..d9274728c3 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.yaml @@ -43,6 +43,7 @@ policy: max_total_sequence_length: 512 precision: bfloat16 dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: false diff --git a/examples/configs/recipes/llm/grpo-gemma3-27b-it-8n8g-fsdp2tp8-actckpt-long.yaml b/examples/configs/recipes/llm/grpo-gemma3-27b-it-8n8g-fsdp2tp8-actckpt-long.yaml index 4c3351970c..a4f5c9eb3b 100644 --- a/examples/configs/recipes/llm/grpo-gemma3-27b-it-8n8g-fsdp2tp8-actckpt-long.yaml +++ b/examples/configs/recipes/llm/grpo-gemma3-27b-it-8n8g-fsdp2tp8-actckpt-long.yaml @@ -43,6 +43,7 @@ policy: max_total_sequence_length: 16384 precision: bfloat16 dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: false diff --git a/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml b/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml index e1b7c4d809..d5565feb5a 100644 --- a/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml @@ -49,6 +49,7 @@ policy: precision: "bfloat16" dtensor_cfg: + _v2: false enabled: true cpu_offload: true sequence_parallel: true diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.yaml index 17b474bd72..a61331048e 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.yaml @@ -43,6 +43,7 @@ policy: max_total_sequence_length: 4096 precision: bfloat16 dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: false diff --git 
a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml index 1c2b3840ca..29e6607176 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.yaml @@ -43,6 +43,7 @@ policy: max_total_sequence_length: 4096 precision: bfloat16 dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: false diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml index eddf09bf97..f2c8c30b4b 100644 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -43,6 +43,7 @@ policy: max_total_sequence_length: 512 precision: bfloat16 dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: false diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml index 7fd4007279..e8c4e1a631 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt-long.v3.yaml @@ -43,6 +43,7 @@ policy: max_total_sequence_length: 16384 precision: bfloat16 dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: true diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml index f163092404..c229f7fe4e 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml +++ 
b/examples/configs/recipes/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.yaml @@ -43,6 +43,7 @@ policy: max_total_sequence_length: 16384 precision: bfloat16 dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: true diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml index f6ecc1e390..db86a9edb5 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.yaml @@ -43,6 +43,7 @@ policy: max_total_sequence_length: 4096 precision: bfloat16 dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: true diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml index b8f79eb6ae..7edc4372f1 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -43,6 +43,7 @@ policy: max_total_sequence_length: 512 precision: bfloat16 dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: false diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-dynamicbatch.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-dynamicbatch.yaml index d7906b82e0..beb2767e9d 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-dynamicbatch.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-dynamicbatch.yaml @@ -24,6 +24,7 @@ policy: max_total_sequence_length: 4096 precision: bfloat16 dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: false diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-long.yaml 
b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-long.yaml index 1fc0ccec7c..ef8e4ee884 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-long.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-long.yaml @@ -24,6 +24,7 @@ policy: max_total_sequence_length: 4096 precision: bfloat16 dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: false diff --git a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp2sp.yaml b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp2sp.yaml index 8c3f14b531..89b0197a7a 100644 --- a/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp2sp.yaml +++ b/examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp2sp.yaml @@ -24,6 +24,7 @@ policy: max_total_sequence_length: 4096 precision: bfloat16 dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: true diff --git a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml index 165e2fa9a3..338e495e72 100644 --- a/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.yaml @@ -25,6 +25,7 @@ policy: max_total_sequence_length: 1024 precision: bfloat16 dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: false diff --git a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v3.yaml b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v3.yaml index 800d94711e..3d8f70a1ee 100644 --- a/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v3.yaml +++ b/examples/configs/recipes/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v3.yaml @@ -25,6 +25,7 @@ policy: max_total_sequence_length: 16000 precision: bfloat16 dtensor_cfg: + _v2: false enabled: true cpu_offload: false sequence_parallel: true diff --git a/examples/configs/rm.yaml b/examples/configs/rm.yaml 
index 0302cd8236..2c037eea1e 100644 --- a/examples/configs/rm.yaml +++ b/examples/configs/rm.yaml @@ -39,6 +39,7 @@ policy: reward_model_type: "bradley_terry" # only "bradley_terry" is currently supported dtensor_cfg: + _v2: true enabled: true cpu_offload: false sequence_parallel: false diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index b8ef6c1626..e2f49acfe3 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -32,6 +32,7 @@ policy: precision: "bfloat16" dtensor_cfg: + _v2: true enabled: true cpu_offload: False sequence_parallel: false diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 29e2dfdf0d..52f4114499 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -77,8 +77,8 @@ def __init__( pp_size = 1 cp_size = 1 - megatron_enable = bool(config.get("megatron_cfg", {}).get("enabled", False)) - dtensor_enable = bool(config.get("dtensor_cfg", {}).get("enabled", False)) + megatron_enable = "megatron_cfg" in config and config["megatron_cfg"]["enabled"] + dtensor_enable = "dtensor_cfg" in config and config["dtensor_cfg"]["enabled"] if megatron_enable and dtensor_enable: raise ValueError( "Configure either Megatron (policy.megatron_cfg.enabled=true) or " @@ -101,7 +101,7 @@ def __init__( ) - # Check if _v2 is enabled in dtensor_cfg (defaults to False for backward compatibility) - use_v2 = config.get("dtensor_cfg", {}).get("_v2", False) + # Check if _v2 is enabled in dtensor_cfg (the key is required; there is no implicit default) + use_v2 = config["dtensor_cfg"]["_v2"] if use_v2: worker_builder_cls = "nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2" else: @@ -646,9 +646,10 @@ def save_checkpoint( ) -> None: """Save a checkpoint of the model.""" # Only pass checkpointing_cfg for DTensor v2 - use_v2 = self.cfg.get("dtensor_cfg", {}).get("_v2", False) + use_dtensor = "dtensor_cfg" in self.cfg and self.cfg["dtensor_cfg"]["enabled"] + use_dtensor_v2 = use_dtensor and self.cfg["dtensor_cfg"]["_v2"] - if use_v2: + if use_dtensor_v2: 
futures = self.worker_group.run_all_workers_single_data( "save_checkpoint", weights_path=weights_path, diff --git a/tests/functional/test_converter_roundtrip.py b/tests/functional/test_converter_roundtrip.py index 7e6bf8d4a7..0d83c6b677 100644 --- a/tests/functional/test_converter_roundtrip.py +++ b/tests/functional/test_converter_roundtrip.py @@ -70,6 +70,7 @@ def create_test_config() -> Dict[str, Any]: "max_total_sequence_length": 128, "precision": "bfloat16", "dtensor_cfg": { + "_v2": True, "enabled": True, "cpu_offload": False, "sequence_parallel": False, diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 140efca56b..c56939a388 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -104,6 +104,7 @@ }, }, "dtensor_cfg": { + "_v2": True, "enabled": True, "cpu_offload": False, "sequence_parallel": False, diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index c81ae15dcf..82dbefa025 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -66,7 +66,7 @@ def create_test_config( }, }, "dtensor_cfg": { - **({"_v2": dtensor_v2} if dtensor_v2 else {}), + "_v2": dtensor_v2, "enabled": True, "cpu_offload": cpu_offload, "sequence_parallel": sp, diff --git a/tests/unit/models/policy/test_dtensor_worker_v2.py b/tests/unit/models/policy/test_dtensor_worker_v2.py index a16e3afda5..b23c78a4d7 100644 --- a/tests/unit/models/policy/test_dtensor_worker_v2.py +++ b/tests/unit/models/policy/test_dtensor_worker_v2.py @@ -55,7 +55,7 @@ def create_test_config( }, }, "dtensor_cfg": { - **({"_v2": dtensor_v2} if dtensor_v2 else {}), + "_v2": dtensor_v2, "enabled": True, "cpu_offload": cpu_offload, "sequence_parallel": sp, diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py 
index 88003941cb..8e64da8d4c 100755 --- a/tests/unit/utils/test_native_checkpoint.py +++ b/tests/unit/utils/test_native_checkpoint.py @@ -52,6 +52,7 @@ }, }, "dtensor_cfg": { + "_v2": False, "enabled": True, "cpu_offload": False, "sequence_parallel": False,