diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 39e0dd4974..27dfb4ac77 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -142,6 +142,8 @@ def __init__( device_map="cpu", # load weights onto CPU initially torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/nemo-rl/issues/13 is fixed ) + # caching since this property is not always preserved after FSDP + self.num_tied_weights = len(find_tied_parameters(self.model)) self.tokenizer = tokenizer # ------------------------------------------------ @@ -256,15 +258,14 @@ def train( mbs: Optional[int] = None, ) -> Dict[str, Any]: """Train the policy on a batch of data with a given loss function.""" - num_tied_weights = len(find_tied_parameters(self.model)) skip_tie_check = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK") if ( - num_tied_weights != 0 + self.num_tied_weights != 0 and self.cfg["dtensor_cfg"]["tensor_parallel_size"] > 1 and not skip_tie_check ): raise ValueError( - f"Using dtensor policy with tp size {self.cfg['dtensor_cfg']['tensor_parallel_size']} for model ({self.cfg['model_name']}) that has tied weights (num_tied_weights={num_tied_weights}) is not supported (https://github.com/NVIDIA/nemo-rl/issues/227). Please use dtensor policy with tensor parallel == 1 instead." + f"Using dtensor policy with tp size {self.cfg['dtensor_cfg']['tensor_parallel_size']} for model ({self.cfg['model_name']}) that has tied weights (num_tied_weights={self.num_tied_weights}) is not supported (https://github.com/NVIDIA/nemo-rl/issues/227). Please use dtensor policy with tensor parallel == 1 instead." ) if gbs is None: diff --git a/nemo_rl/models/policy/fsdp1_policy_worker.py b/nemo_rl/models/policy/fsdp1_policy_worker.py index b49c2748cf..b25e930b6f 100644 --- a/nemo_rl/models/policy/fsdp1_policy_worker.py +++ b/nemo_rl/models/policy/fsdp1_policy_worker.py @@ -94,6 +94,8 @@ def __init__( device_map="cpu", # load weights onto CPU initially torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/nemo-rl/issues/13 is fixed ) + # caching since this property is not always preserved after FSDP + self.num_tied_weights = len(find_tied_parameters(self.model)) if init_reference_model: self.reference_model = AutoModelForCausalLM.from_pretrained( @@ -229,11 +231,10 @@ def train( ) -> Dict[str, Any]: """Train the policy on a batch of data with a given loss function.""" # Check if the model has tied weights - num_tied_weights = len(find_tied_parameters(self.model)) skip_tie_check = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK") - if num_tied_weights != 0 and not skip_tie_check: + if self.num_tied_weights != 0 and not skip_tie_check: raise ValueError( - f"Using FSP1 with a model ({self.cfg['model_name']}) that has tied weights (num_tied_weights={num_tied_weights}) is not supported (https://github.com/NVIDIA/nemo-rl/issues/227). Please use dtensor policy with tensor parallel == 1 instead." + f"Using FSP1 with a model ({self.cfg['model_name']}) that has tied weights (num_tied_weights={self.num_tied_weights}) is not supported (https://github.com/NVIDIA/nemo-rl/issues/227). Please use dtensor policy with tensor parallel == 1 instead." ) if gbs is None: