Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 5 additions & 18 deletions nemo_reinforcer/models/policy/hf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,18 @@ def __init__(
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="cpu", # load weights onto CPU initially
torch_dtype=torch.bfloat16, # use half precision to save memory
torch_dtype=torch.float32, # use full precision until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
)
self.reference_model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="cpu", # load weights onto CPU initially
torch_dtype=torch.bfloat16, # use half precision to save memory
torch_dtype=torch.float32, # use full precision until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
)

self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# If no pad token is defined, you might need:
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token

# ------------------------------------------------
# 3) Move to GPU + Composable FSDP
Expand All @@ -99,23 +99,10 @@ def do_fsdp(model):
# Create a device mesh with 'world_size' GPUs in a 1D arrangement.
mesh = init_device_mesh("cuda", (world_size,))

# Mixed precision training
# https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision
param_dtype = torch.bfloat16 # use lower precision for model parameters
reduce_dtype = torch.float32 # use higher precision for gradient reduction
buffer_dtype = torch.float32 # use higher precision for optimizer states

mp = MixedPrecision(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
buffer_dtype=buffer_dtype,
)

return FullyShardedDataParallel(
model,
device_mesh=mesh,
auto_wrap_policy=size_based_auto_wrap_policy,
mixed_precision=mp,
)

self.model.to("cuda")
Expand Down