diff --git a/nemo_reinforcer/models/policy/dtensor_policy_worker.py b/nemo_reinforcer/models/policy/dtensor_policy_worker.py index cf0f06bbfa..ac94d49120 100644 --- a/nemo_reinforcer/models/policy/dtensor_policy_worker.py +++ b/nemo_reinforcer/models/policy/dtensor_policy_worker.py @@ -25,7 +25,7 @@ FSDPModule, ) from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.modeling_utils import _get_tied_weight_keys +from transformers.integrations.accelerate import find_tied_parameters from nemo_reinforcer.models.dtensor.parallelize import _parallelize_model from nemo_reinforcer.algorithms.interfaces import LossFunction @@ -256,7 +256,7 @@ def train( mbs: Optional[int] = None, ) -> Dict[str, Any]: """Train the policy on a batch of data with a given loss function.""" - num_tied_weights = len(_get_tied_weight_keys(self.model)) + num_tied_weights = len(find_tied_parameters(self.model)) skip_tie_check = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK") if ( num_tied_weights != 0 diff --git a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py index 89b46fd6ac..5e8a8f6bc5 100644 --- a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py +++ b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py @@ -39,7 +39,7 @@ ) from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.modeling_utils import _get_tied_weight_keys +from transformers.integrations.accelerate import find_tied_parameters from nemo_reinforcer.models.policy import PolicyConfig from nemo_reinforcer.models.policy.utils import import_class_from_path from nemo_reinforcer.distributed.virtual_cluster import ( @@ -229,7 +229,7 @@ def train( ) -> Dict[str, Any]: """Train the policy on a batch of data with a given loss function.""" # Check if the model has tied weights - num_tied_weights = len(_get_tied_weight_keys(self.model)) + num_tied_weights = len(find_tied_parameters(self.model)) skip_tie_check = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK") if num_tied_weights != 0 and not skip_tie_check: raise ValueError(