diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py
index 47714fb0f5..fbe728a840 100644
--- a/nemo_rl/models/policy/__init__.py
+++ b/nemo_rl/models/policy/__init__.py
@@ -37,7 +37,7 @@ class PolicyConfig(TypedDict):
     train_micro_batch_size: int
     learning_rate: float
     logprob_batch_size: int
-    generation: GenerationConfig
+    generation: Optional[GenerationConfig]
     precision: str
     dtensor_cfg: DTensorConfig
     make_sequence_length_divisible_by: int
diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index 82c74cbf2b..c99110d7e7 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -335,6 +335,10 @@ def train(
                 else:
                     logits = outputs.logits
 
+                # Divide logits by temperature
+                if "generation" in self.cfg and self.cfg["generation"] is not None:
+                    logits.div_(self.cfg["generation"]["temperature"])
+
                 loss, loss_metrics = loss_fn(logits, mb)
                 num_valid_samples = loss_metrics["num_valid_samples"]
                 loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"]
diff --git a/nemo_rl/models/policy/fsdp1_policy_worker.py b/nemo_rl/models/policy/fsdp1_policy_worker.py
index 6eb07d2a84..53ec5944f9 100644
--- a/nemo_rl/models/policy/fsdp1_policy_worker.py
+++ b/nemo_rl/models/policy/fsdp1_policy_worker.py
@@ -290,6 +290,10 @@ def train(
                 else:
                     logits = outputs.logits
 
+                # Divide logits by temperature
+                if "generation" in self.cfg and self.cfg["generation"] is not None:
+                    logits.div_(self.cfg["generation"]["temperature"])
+
                 loss, loss_metrics = loss_fn(logits, mb)
                 num_valid_samples = loss_metrics["num_valid_samples"]
                 loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"]
diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py
index 552ea3dae2..08f34defe2 100644
--- a/tests/unit/models/generation/test_vllm_generation.py
+++ b/tests/unit/models/generation/test_vllm_generation.py
@@ -36,7 +36,7 @@
     },
     "dtype": "bfloat16",
    "max_new_tokens": 10,
-    "temperature": 1.0,
+    "temperature": 0.8,
     "top_p": 1.0,
     "top_k": None,
     "stop_token_ids": None,
@@ -85,6 +85,9 @@ def get_basic_hf_test_config(enable_dtensor: bool = False) -> PolicyConfig:
         },
         "max_grad_norm": 1.0,
         "make_sequence_length_divisible_by": 1,
+        "generation": {
+            "temperature": 0.8,
+        },
     }
 
 
diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py
index 938ff84a05..f751c5c47e 100755
--- a/tests/unit/utils/test_native_checkpoint.py
+++ b/tests/unit/utils/test_native_checkpoint.py
@@ -61,6 +61,9 @@
         "tensor_parallel_size": 1,
     },
     "max_grad_norm": 1.0,
+    "generation": {
+        "temperature": 1.0,
+    },
 }
 
 