From c58334f5a106cd5f2e15c3f7fdfaba45df787292 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Fri, 21 Mar 2025 04:14:09 -0700 Subject: [PATCH 1/2] Make precision configurable Signed-off-by: Sahil Jain --- examples/configs/grpo_math_1B.yaml | 3 ++- examples/configs/grpo_math_8B.yaml | 1 + examples/configs/sft.yaml | 1 + examples/run_grpo_math.py | 4 +++- nemo_reinforcer/algorithms/grpo.py | 4 +++- nemo_reinforcer/models/policy/hf_policy.py | 10 ++++++++-- 6 files changed, 18 insertions(+), 5 deletions(-) diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 684e24e7f5..f1884b9f75 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -2,7 +2,7 @@ grpo: num_prompts_per_step: 8 num_generations_per_prompt: 8 - max_num_steps: 100 + max_num_steps: 1000000 normalize_rewards: true use_leave_one_out_baseline: true val_period: 10 @@ -30,6 +30,7 @@ policy: learning_rate: 5.0e-6 logprob_batch_size: 4 max_total_sequence_length: 512 + precision: "bfloat16" scheduler: - name: "torch.optim.lr_scheduler.LinearLR" diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index 2415ee8cd8..69f28a41b2 100644 --- a/examples/configs/grpo_math_8B.yaml +++ b/examples/configs/grpo_math_8B.yaml @@ -9,6 +9,7 @@ policy: learning_rate: 5.0e-6 logprob_batch_size: 2 max_total_sequence_length: 4096 + precision: "bfloat16" scheduler: - name: "torch.optim.lr_scheduler.LinearLR" diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 38db492017..ef049bea88 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -21,6 +21,7 @@ policy: train_micro_batch_size: 2 learning_rate: 5.0e-6 max_total_sequence_length: 1024 + precision: "float32" scheduler: - name: "torch.optim.lr_scheduler.LinearLR" diff --git a/examples/run_grpo_math.py b/examples/run_grpo_math.py index 4e7732def1..1c30f8ea94 100644 --- a/examples/run_grpo_math.py +++ b/examples/run_grpo_math.py @@ -170,7 +170,9 @@ def main(): args, overrides = parse_args() if not args.config: - args.config = os.path.join(os.path.dirname(__file__), "configs", "grpo_math_1B.yaml") + args.config = os.path.join( + os.path.dirname(__file__), "configs", "grpo_math_1B.yaml" + ) config = load_config(args.config) print(f"Loaded configuration from: {args.config}") diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 94b9f29a2b..bc4f96099a 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -445,7 +445,9 @@ def grpo_train( # Run grpo training (single-turn) for batch in dataloader: - print(f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}") + print( + f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}" + ) with timer.time("total_step_time"): # Prepare batch diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 6f44574c71..1a5d23232e 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -72,17 +72,23 @@ def __init__( rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() model_name = self.cfg["model_name"] + if self.cfg["precision"] == "float32": + dtype = torch.float32 + elif self.cfg["precision"] == "bfloat16": + dtype = torch.bfloat16 + else: + raise ValueError(f"Unknown precision: {self.cfg['precision']}") print(f"[Rank {rank}] Loading model {model_name} on CPU...") self.model = AutoModelForCausalLM.from_pretrained( model_name, device_map="cpu", # load weights onto CPU initially - torch_dtype=torch.float32, # use full precision until https://github.com/NVIDIA/reinforcer/issues/13 is fixed + torch_dtype=dtype, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed ) self.reference_model = AutoModelForCausalLM.from_pretrained( model_name, device_map="cpu", # load weights onto CPU initially - torch_dtype=torch.float32, # use full precision until https://github.com/NVIDIA/reinforcer/issues/13 is fixed + torch_dtype=dtype, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed ) self.tokenizer = AutoTokenizer.from_pretrained(model_name) From 18137f70975fcb2145a227faca3afc89cb5c1658 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Fri, 21 Mar 2025 04:32:47 -0700 Subject: [PATCH 2/2] configure policy tests to use float32 Signed-off-by: Sahil Jain --- nemo_reinforcer/models/policy/__init__.py | 1 + tests/unit/models/generation/test_vllm_generation.py | 2 ++ tests/unit/models/policy/test_hf_ray_policy.py | 1 + 3 files changed, 4 insertions(+) diff --git a/nemo_reinforcer/models/policy/__init__.py b/nemo_reinforcer/models/policy/__init__.py index d28d7a91c0..ee2bf2389e 100644 --- a/nemo_reinforcer/models/policy/__init__.py +++ b/nemo_reinforcer/models/policy/__init__.py @@ -24,3 +24,4 @@ class PolicyConfig(TypedDict): learning_rate: float logprob_batch_size: int generation: GenerationConfig + precision: str diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index adf1aad824..8e013cc1ea 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -191,6 +191,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): "logprob_batch_size": 1, "max_new_tokens": 16, "do_sample": False, + "precision": "float32", } vllm_policy = None @@ -437,6 +438,7 @@ def test_vllm_policy_weight_update(cluster, tokenizer, tensor_parallel_size): "logprob_batch_size": 1, "max_new_tokens": 16, "do_sample": False, + "precision": "float32", } hf_policy = HfPolicy(cluster, hf_config) diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index 0f51fbf792..9b825e5302 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -32,6 +32,7 @@ "train_micro_batch_size": 1, "learning_rate": 5e-6, "logprob_batch_size": 1, + "precision": "float32", "generation": { "backend": "hf", "temperature": 1.0,