Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/configs/dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ policy:
precision: "bfloat16"

dtensor_cfg:
_v2: true
env_vars:
PYTORCH_CUDA_ALLOC_CONF: "" # Refers to https://docs.pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf
enabled: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ policy:
max_total_sequence_length: 2048
precision: "bfloat16"
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: False
sequence_parallel: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ policy:
max_total_sequence_length: 8192
precision: "bfloat16"
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: False
sequence_parallel: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ policy:
max_total_sequence_length: 1024
precision: "bfloat16"
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: False
sequence_parallel: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ policy:
precision: "bfloat16"

dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ policy:
logprob_batch_size: 2

dtensor_cfg:
_v2: false
enabled: true
cpu_offload: true
sequence_parallel: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ policy:
logprob_batch_size: 2

dtensor_cfg:
_v2: false
enabled: true
cpu_offload: true
sequence_parallel: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ policy:
precision: "bfloat16"

dtensor_cfg:
_v2: false
enabled: true
cpu_offload: true
sequence_parallel: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ policy:
max_total_sequence_length: 512
precision: bfloat16
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ policy:
max_total_sequence_length: 16384
precision: bfloat16
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ policy:
precision: "bfloat16"

dtensor_cfg:
_v2: false
enabled: true
cpu_offload: true
sequence_parallel: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ policy:
max_total_sequence_length: 4096
precision: bfloat16
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ policy:
max_total_sequence_length: 4096
precision: bfloat16
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ policy:
max_total_sequence_length: 512
precision: bfloat16
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ policy:
max_total_sequence_length: 16384
precision: bfloat16
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ policy:
max_total_sequence_length: 16384
precision: bfloat16
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ policy:
max_total_sequence_length: 4096
precision: bfloat16
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ policy:
max_total_sequence_length: 512
precision: bfloat16
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ policy:
max_total_sequence_length: 4096
precision: bfloat16
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ policy:
max_total_sequence_length: 4096
precision: bfloat16
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ policy:
max_total_sequence_length: 4096
precision: bfloat16
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ policy:
max_total_sequence_length: 1024
precision: bfloat16
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ policy:
max_total_sequence_length: 16000
precision: bfloat16
dtensor_cfg:
_v2: false
enabled: true
cpu_offload: false
sequence_parallel: true
Expand Down
1 change: 1 addition & 0 deletions examples/configs/rm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ policy:
reward_model_type: "bradley_terry" # only "bradley_terry" is currently supported

dtensor_cfg:
_v2: true
enabled: true
cpu_offload: false
sequence_parallel: false
Expand Down
1 change: 1 addition & 0 deletions examples/configs/sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ policy:
precision: "bfloat16"

dtensor_cfg:
_v2: true
enabled: true
cpu_offload: False
sequence_parallel: false
Expand Down
11 changes: 6 additions & 5 deletions nemo_rl/models/policy/lm_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def __init__(
pp_size = 1
cp_size = 1

megatron_enable = bool(config.get("megatron_cfg", {}).get("enabled", False))
dtensor_enable = bool(config.get("dtensor_cfg", {}).get("enabled", False))
megatron_enable = "megatron_cfg" in config and config["megatron_cfg"]["enabled"]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is `enabled` guaranteed to exist? This differs from the original code, which defaulted to False when `enabled` was missing.

dtensor_enable = "dtensor_cfg" in config and config["dtensor_cfg"]["enabled"]
if megatron_enable and dtensor_enable:
raise ValueError(
"Configure either Megatron (policy.megatron_cfg.enabled=true) or "
Expand All @@ -101,7 +101,7 @@ def __init__(
)

# Check if _v2 is enabled in dtensor_cfg (defaults to False for backward compatibility)
use_v2 = config.get("dtensor_cfg", {}).get("_v2", False)
use_v2 = config["dtensor_cfg"]["_v2"]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, I guess you want to require that this field always be set.

if use_v2:
worker_builder_cls = "nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2"
else:
Expand Down Expand Up @@ -646,9 +646,10 @@ def save_checkpoint(
) -> None:
"""Save a checkpoint of the model."""
# Only pass checkpointing_cfg for DTensor v2
use_v2 = self.cfg.get("dtensor_cfg", {}).get("_v2", False)
use_dtensor = "dtensor_cfg" in self.cfg and self.cfg["dtensor_cfg"]["enabled"]
use_dtensor_v2 = use_dtensor and self.cfg["dtensor_cfg"]["_v2"]

if use_v2:
if use_dtensor_v2:
futures = self.worker_group.run_all_workers_single_data(
"save_checkpoint",
weights_path=weights_path,
Expand Down
1 change: 1 addition & 0 deletions tests/functional/test_converter_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def create_test_config() -> Dict[str, Any]:
"max_total_sequence_length": 128,
"precision": "bfloat16",
"dtensor_cfg": {
"_v2": True,
"enabled": True,
"cpu_offload": False,
"sequence_parallel": False,
Expand Down
1 change: 1 addition & 0 deletions tests/unit/models/generation/test_vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
},
},
"dtensor_cfg": {
"_v2": True,
"enabled": True,
"cpu_offload": False,
"sequence_parallel": False,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/models/policy/test_dtensor_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def create_test_config(
},
},
"dtensor_cfg": {
**({"_v2": dtensor_v2} if dtensor_v2 else {}),
"_v2": dtensor_v2,
"enabled": True,
"cpu_offload": cpu_offload,
"sequence_parallel": sp,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/models/policy/test_dtensor_worker_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def create_test_config(
},
},
"dtensor_cfg": {
**({"_v2": dtensor_v2} if dtensor_v2 else {}),
"_v2": dtensor_v2,
"enabled": True,
"cpu_offload": cpu_offload,
"sequence_parallel": sp,
Expand Down
1 change: 1 addition & 0 deletions tests/unit/utils/test_native_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
},
},
"dtensor_cfg": {
"_v2": False,
"enabled": True,
"cpu_offload": False,
"sequence_parallel": False,
Expand Down
Loading