Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
grpo:
num_prompts_per_step: 8
num_generations_per_prompt: 8
-max_num_steps: 100
+max_num_steps: 1000000
normalize_rewards: true
use_leave_one_out_baseline: true
val_period: 10
Expand Down Expand Up @@ -30,6 +30,7 @@ policy:
learning_rate: 5.0e-6
logprob_batch_size: 4
max_total_sequence_length: 512
precision: "bfloat16"

scheduler:
- name: "torch.optim.lr_scheduler.LinearLR"
Expand Down
1 change: 1 addition & 0 deletions examples/configs/grpo_math_8B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ policy:
learning_rate: 5.0e-6
logprob_batch_size: 2
max_total_sequence_length: 4096
precision: "bfloat16"

scheduler:
- name: "torch.optim.lr_scheduler.LinearLR"
Expand Down
1 change: 1 addition & 0 deletions examples/configs/sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ policy:
train_micro_batch_size: 2
learning_rate: 5.0e-6
max_total_sequence_length: 1024
precision: "float32"

scheduler:
- name: "torch.optim.lr_scheduler.LinearLR"
Expand Down
1 change: 1 addition & 0 deletions nemo_reinforcer/models/policy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ class PolicyConfig(TypedDict):
learning_rate: float
logprob_batch_size: int
generation: GenerationConfig
precision: str
10 changes: 8 additions & 2 deletions nemo_reinforcer/models/policy/hf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,23 @@ def __init__(
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
model_name = self.cfg["model_name"]
if self.cfg["precision"] == "float32":
dtype = torch.float32
elif self.cfg["precision"] == "bfloat16":
dtype = torch.bfloat16
else:
raise ValueError(f"Unknown precision: {self.cfg['precision']}")

print(f"[Rank {rank}] Loading model {model_name} on CPU...")
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="cpu", # load weights onto CPU initially
-torch_dtype=torch.float32, # use full precision until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
+torch_dtype=dtype, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
)
self.reference_model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="cpu", # load weights onto CPU initially
-torch_dtype=torch.float32, # use full precision until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
+torch_dtype=dtype, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
)

self.tokenizer = AutoTokenizer.from_pretrained(model_name)
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/models/generation/test_vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
"logprob_batch_size": 1,
"max_new_tokens": 16,
"do_sample": False,
"precision": "float32",
}

vllm_policy = None
Expand Down Expand Up @@ -437,6 +438,7 @@ def test_vllm_policy_weight_update(cluster, tokenizer, tensor_parallel_size):
"logprob_batch_size": 1,
"max_new_tokens": 16,
"do_sample": False,
"precision": "float32",
}

hf_policy = HfPolicy(cluster, hf_config)
Expand Down
1 change: 1 addition & 0 deletions tests/unit/models/policy/test_hf_ray_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"train_micro_batch_size": 1,
"learning_rate": 5e-6,
"logprob_batch_size": 1,
"precision": "float32",
"generation": {
"backend": "hf",
"temperature": 1.0,
Expand Down