diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 684e24e7f5..f1884b9f75 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -2,7 +2,7 @@ grpo: num_prompts_per_step: 8 num_generations_per_prompt: 8 - max_num_steps: 100 + max_num_steps: 1000000 normalize_rewards: true use_leave_one_out_baseline: true val_period: 10 @@ -30,6 +30,7 @@ policy: learning_rate: 5.0e-6 logprob_batch_size: 4 max_total_sequence_length: 512 + precision: "bfloat16" scheduler: - name: "torch.optim.lr_scheduler.LinearLR" diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index 2415ee8cd8..69f28a41b2 100644 --- a/examples/configs/grpo_math_8B.yaml +++ b/examples/configs/grpo_math_8B.yaml @@ -9,6 +9,7 @@ policy: learning_rate: 5.0e-6 logprob_batch_size: 2 max_total_sequence_length: 4096 + precision: "bfloat16" scheduler: - name: "torch.optim.lr_scheduler.LinearLR" diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 38db492017..ef049bea88 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -21,6 +21,7 @@ policy: train_micro_batch_size: 2 learning_rate: 5.0e-6 max_total_sequence_length: 1024 + precision: "float32" scheduler: - name: "torch.optim.lr_scheduler.LinearLR" diff --git a/nemo_reinforcer/models/policy/__init__.py b/nemo_reinforcer/models/policy/__init__.py index d28d7a91c0..ee2bf2389e 100644 --- a/nemo_reinforcer/models/policy/__init__.py +++ b/nemo_reinforcer/models/policy/__init__.py @@ -24,3 +24,4 @@ class PolicyConfig(TypedDict): learning_rate: float logprob_batch_size: int generation: GenerationConfig + precision: str diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 6f44574c71..1a5d23232e 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -72,17 +72,23 @@ def __init__( rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() model_name = self.cfg["model_name"] + if self.cfg["precision"] == "float32": + dtype = torch.float32 + elif self.cfg["precision"] == "bfloat16": + dtype = torch.bfloat16 + else: + raise ValueError(f"Unknown precision: {self.cfg['precision']}") print(f"[Rank {rank}] Loading model {model_name} on CPU...") self.model = AutoModelForCausalLM.from_pretrained( model_name, device_map="cpu", # load weights onto CPU initially - torch_dtype=torch.float32, # use full precision until https://github.com/NVIDIA/reinforcer/issues/13 is fixed + torch_dtype=dtype, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed ) self.reference_model = AutoModelForCausalLM.from_pretrained( model_name, device_map="cpu", # load weights onto CPU initially - torch_dtype=torch.float32, # use full precision until https://github.com/NVIDIA/reinforcer/issues/13 is fixed + torch_dtype=dtype, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed ) self.tokenizer = AutoTokenizer.from_pretrained(model_name) diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index adf1aad824..8e013cc1ea 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -191,6 +191,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): "logprob_batch_size": 1, "max_new_tokens": 16, "do_sample": False, + "precision": "float32", } vllm_policy = None @@ -437,6 +438,7 @@ def test_vllm_policy_weight_update(cluster, tokenizer, tensor_parallel_size): "logprob_batch_size": 1, "max_new_tokens": 16, "do_sample": False, + "precision": "float32", } hf_policy = HfPolicy(cluster, hf_config) diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index 0f51fbf792..9b825e5302 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -32,6 +32,7 @@ "train_micro_batch_size": 1, "learning_rate": 5e-6, "logprob_batch_size": 1, + "precision": "float32", "generation": { "backend": "hf", "temperature": 1.0,