Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions nemo_rl/models/policy/dtensor_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def __init__(
device_map="cpu", # load weights onto CPU initially
torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/nemo-rl/issues/13 is fixed
)
# caching since this property is not always preserved after FSDP
self.num_tied_weights = len(find_tied_parameters(self.model))

self.tokenizer = tokenizer
# ------------------------------------------------
Expand Down Expand Up @@ -256,15 +258,14 @@ def train(
mbs: Optional[int] = None,
) -> Dict[str, Any]:
"""Train the policy on a batch of data with a given loss function."""
num_tied_weights = len(find_tied_parameters(self.model))
skip_tie_check = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK")
if (
num_tied_weights != 0
self.num_tied_weights != 0
and self.cfg["dtensor_cfg"]["tensor_parallel_size"] > 1
and not skip_tie_check
):
raise ValueError(
f"Using dtensor policy with tp size {self.cfg['dtensor_cfg']['tensor_parallel_size']} for model ({self.cfg['model_name']}) that has tied weights (num_tied_weights={num_tied_weights}) is not supported (https://github.com/NVIDIA/nemo-rl/issues/227). Please use dtensor policy with tensor parallel == 1 instead."
f"Using dtensor policy with tp size {self.cfg['dtensor_cfg']['tensor_parallel_size']} for model ({self.cfg['model_name']}) that has tied weights (num_tied_weights={self.num_tied_weights}) is not supported (https://github.com/NVIDIA/nemo-rl/issues/227). Please use dtensor policy with tensor parallel == 1 instead."
)

if gbs is None:
Expand Down
7 changes: 4 additions & 3 deletions nemo_rl/models/policy/fsdp1_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def __init__(
device_map="cpu", # load weights onto CPU initially
torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/nemo-rl/issues/13 is fixed
)
# caching since this property is not always preserved after FSDP
self.num_tied_weights = len(find_tied_parameters(self.model))

if init_reference_model:
self.reference_model = AutoModelForCausalLM.from_pretrained(
Expand Down Expand Up @@ -229,11 +231,10 @@ def train(
) -> Dict[str, Any]:
"""Train the policy on a batch of data with a given loss function."""
# Check if the model has tied weights
num_tied_weights = len(find_tied_parameters(self.model))
skip_tie_check = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK")
if num_tied_weights != 0 and not skip_tie_check:
if self.num_tied_weights != 0 and not skip_tie_check:
raise ValueError(
f"Using FSP1 with a model ({self.cfg['model_name']}) that has tied weights (num_tied_weights={num_tied_weights}) is not supported (https://github.com/NVIDIA/nemo-rl/issues/227). Please use dtensor policy with tensor parallel == 1 instead."
f"Using FSP1 with a model ({self.cfg['model_name']}) that has tied weights (num_tied_weights={self.num_tied_weights}) is not supported (https://github.com/NVIDIA/nemo-rl/issues/227). Please use dtensor policy with tensor parallel == 1 instead."
)

if gbs is None:
Expand Down