diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index 1a185300c5..eb4d58bc81 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -48,6 +48,7 @@ policy: dtensor_cfg: enabled: true cpu_offload: False + torch_compile: False sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 diff --git a/examples/configs/grpo-deepscaler-1.5b-8K.yaml b/examples/configs/grpo-deepscaler-1.5b-8K.yaml index eef1c4e205..2252479d34 100644 --- a/examples/configs/grpo-deepscaler-1.5b-8K.yaml +++ b/examples/configs/grpo-deepscaler-1.5b-8K.yaml @@ -46,6 +46,7 @@ policy: dtensor_cfg: enabled: true cpu_offload: False + torch_compile: False sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 2f9b6535e4..53fe421a2b 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -46,6 +46,7 @@ policy: dtensor_cfg: enabled: true cpu_offload: False + torch_compile: False sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index b8f01ce3e6..17d46b2751 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -35,6 +35,7 @@ policy: dtensor_cfg: enabled: true cpu_offload: False + torch_compile: False sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 1 diff --git a/examples/configs/sft_openmathinstruct2.yaml b/examples/configs/sft_openmathinstruct2.yaml index b885f7388b..133562d719 100644 --- a/examples/configs/sft_openmathinstruct2.yaml +++ b/examples/configs/sft_openmathinstruct2.yaml @@ -31,6 +31,7 @@ policy: dtensor_cfg: enabled: true cpu_offload: False + torch_compile: False sequence_parallel: false activation_checkpointing: false tensor_parallel_size: 4 diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py 
b/nemo_rl/models/policy/dtensor_policy_worker.py index d13fc23c67..767d91a0d6 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -133,6 +133,7 @@ def __init__( model_name = self.cfg["model_name"] self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"] + self.torch_compile = self.cfg["dtensor_cfg"]["torch_compile"] self.max_grad_norm = self.cfg["max_grad_norm"] if self.cfg["precision"] == "float32": @@ -195,6 +196,9 @@ def __init__( custom_parallel_plan=self.cfg["dtensor_cfg"]["custom_parallel_plan"], ) + if self.torch_compile: + self.model = torch.compile(self.model) + if self.cpu_offload: self.model = self.move_buffer_to_device(self.model, "cpu") @@ -736,6 +740,9 @@ def get_weights_ipc_handles(self, keys: Iterable[str]) -> dict[str, Any]: full_tensor = tensor.full_tensor() else: full_tensor = tensor + # torch.compile wraps the model as "_orig_mod", so remove the prefix here + if self.torch_compile and key.startswith("_orig_mod."): + key = key.removeprefix("_orig_mod.") # Convert parameters to the configured dtype converted_params[key] = full_tensor.to(self.dtype, non_blocking=True)