1 change: 1 addition & 0 deletions examples/configs/dpo.yaml
Collaborator: Could you add this key to all the configs/recipes?

@@ -48,6 +48,7 @@ policy:
dtensor_cfg:
  enabled: true
  cpu_offload: False
  torch_compile: False
  sequence_parallel: false
  activation_checkpointing: false
  tensor_parallel_size: 1
1 change: 1 addition & 0 deletions examples/configs/grpo-deepscaler-1.5b-8K.yaml
@@ -46,6 +46,7 @@ policy:
dtensor_cfg:
  enabled: true
  cpu_offload: False
  torch_compile: False
  sequence_parallel: false
  activation_checkpointing: false
  tensor_parallel_size: 1
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -46,6 +46,7 @@ policy:
dtensor_cfg:
  enabled: true
  cpu_offload: False
  torch_compile: False
  sequence_parallel: false
  activation_checkpointing: false
  tensor_parallel_size: 1
1 change: 1 addition & 0 deletions examples/configs/sft.yaml
@@ -35,6 +35,7 @@ policy:
dtensor_cfg:
  enabled: true
  cpu_offload: False
  torch_compile: False
  sequence_parallel: false
  activation_checkpointing: false
  tensor_parallel_size: 1
1 change: 1 addition & 0 deletions examples/configs/sft_openmathinstruct2.yaml
@@ -31,6 +31,7 @@ policy:
dtensor_cfg:
  enabled: true
  cpu_offload: False
  torch_compile: False
  sequence_parallel: false
  activation_checkpointing: false
  tensor_parallel_size: 4
7 changes: 7 additions & 0 deletions nemo_rl/models/policy/dtensor_policy_worker.py
@@ -133,6 +133,7 @@ def __init__(
model_name = self.cfg["model_name"]

self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"]
self.torch_compile = self.cfg["dtensor_cfg"]["torch_compile"]
self.max_grad_norm = self.cfg["max_grad_norm"]

if self.cfg["precision"] == "float32":
@@ -195,6 +196,9 @@ def __init__(
custom_parallel_plan=self.cfg["dtensor_cfg"]["custom_parallel_plan"],
)

if self.torch_compile:
    self.model = torch.compile(model)

Could you try model.compile() instead? That should fix the _orig_mod issue, and it is now the recommended way of compiling a model. We'll work on adding warnings and publicizing this to raise awareness.


if self.cpu_offload:
    self.model = self.move_buffer_to_device(self.model, "cpu")

@@ -736,6 +740,9 @@ def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]:
full_tensor = tensor.full_tensor()
else:
full_tensor = tensor
# torch.compile wraps the model under an "_orig_mod" attribute, so strip that prefix here
if self.torch_compile and key.startswith("_orig_mod."):
    key = key.removeprefix("_orig_mod.")
# Convert parameters to the configured dtype
converted_params[key] = full_tensor.to(self.dtype, non_blocking=True)

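The prefix handling in the hunk above can be factored into a small pure function; a minimal sketch, where `strip_compile_prefix` is a hypothetical helper name, not part of the PR:

```python
def strip_compile_prefix(key: str) -> str:
    """Normalize a parameter name from a torch.compile-wrapped module
    back to the original module's naming."""
    # torch.compile exposes the wrapped module as "_orig_mod", so names
    # come out as e.g. "_orig_mod.layers.0.weight". str.removeprefix
    # (Python >= 3.9) is a no-op when the prefix is absent, so names
    # from an uncompiled model pass through unchanged.
    return key.removeprefix("_orig_mod.")


print(strip_compile_prefix("_orig_mod.layers.0.weight"))  # layers.0.weight
print(strip_compile_prefix("layers.0.weight"))            # layers.0.weight
```

Because `removeprefix` only strips an exact leading match, the helper is safe to apply unconditionally to every key, compiled model or not.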