Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions nemo_rl/models/policy/dtensor_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,9 @@ def train(
else:
logits = outputs.logits

# Divide logits by temperature
logits.div_(self.cfg["generation"]["temperature"])

loss, loss_metrics = loss_fn(logits, mb)
num_valid_samples = loss_metrics["num_valid_samples"]
loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"]
Expand Down
3 changes: 3 additions & 0 deletions nemo_rl/models/policy/fsdp1_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ def train(
else:
logits = outputs.logits

# Divide logits by temperature
logits.div_(self.cfg["generation"]["temperature"])

loss, loss_metrics = loss_fn(logits, mb)
num_valid_samples = loss_metrics["num_valid_samples"]
loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"]
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/models/generation/test_vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
},
"dtype": "bfloat16",
"max_new_tokens": 10,
"temperature": 1.0,
"temperature": 0.8,
"top_p": 1.0,
"top_k": None,
"stop_token_ids": None,
Expand Down Expand Up @@ -85,6 +85,9 @@ def get_basic_hf_test_config(enable_dtensor: bool = False) -> PolicyConfig:
},
"max_grad_norm": 1.0,
"make_sequence_length_divisible_by": 1,
"generation": {
"temperature": 0.8,
},
}


Expand Down
Loading