diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 0bf4e71477..b0d544aabb 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -496,9 +496,9 @@ def __init__( model_cfg.pipeline_dtype = dtype_map[self.cfg["megatron_cfg"]["pipeline_dtype"]] model_cfg.parallel_output = True if self.cfg["megatron_cfg"]["activation_checkpointing"]: - model_cfg.activations_checkpoint_granularity = "full" - model_cfg.activations_checkpoint_method = "uniform" - model_cfg.activations_checkpoint_num_layers = 1 + model_cfg.recompute_granularity = "full" + model_cfg.recompute_method = "uniform" + model_cfg.recompute_num_layers = 1 if not model_cfg.gated_linear_unit: assert model_cfg.activation_func is not None, ( "activation_func must be set if not using gated_linear_unit. This likely "