diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index 58867c1890..2c3e08a618 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -905,14 +905,14 @@ def main(args: FlatArguments, tc: TokenizerConfig):
                 avg_loss = sum_loss / total_fwd_passes
                 metrics_to_log["train_loss"] = avg_loss
             else:
-                avg_loss_per_total_tok = sum_loss / total_tokens_this_log_period
+                avg_loss = sum_loss / total_tokens_this_log_period
                 # The loss per pred tok is the closest analogue to what we report as the
                 # avg_loss in the "mean" case
                 avg_loss_per_pred_tok = sum_loss / pred_tokens_this_log_period
                 total_optim_steps = args.logging_steps * accelerator.num_processes
                 avg_sum_loss = sum_loss / total_optim_steps
                 metrics_to_log["train_sum_loss"] = avg_sum_loss
-                metrics_to_log["train_loss_per_total_tok"] = avg_loss_per_total_tok
+                metrics_to_log["train_loss_per_total_tok"] = avg_loss
                 metrics_to_log["train_loss_per_pred_tok"] = avg_loss_per_pred_tok
                 if args.verbose:
                     sec_per_step = (time.time() - start_time) / (completed_steps - resume_step)
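For context, here is a minimal standalone sketch of how the three metrics in this hunk relate, assuming `sum_loss` is the summed (not mean-reduced) cross-entropy accumulated over the logging period. All concrete values and the `logging_steps`/`num_processes` stand-ins are hypothetical; in `finetune.py` they come from `args` and the `accelerator` inside the training loop.

```python
# Hypothetical loop state accumulated since the last logging step.
sum_loss = 5248.0                      # summed per-token losses (not averaged)
total_tokens_this_log_period = 4096    # every token seen, incl. masked/prompt tokens
pred_tokens_this_log_period = 2048     # only tokens that contribute to the loss
logging_steps = 8                      # stand-in for args.logging_steps
num_processes = 2                      # stand-in for accelerator.num_processes

# Loss normalized by every token in the period (the variable renamed to
# avg_loss in the diff, logged as train_loss_per_total_tok).
avg_loss = sum_loss / total_tokens_this_log_period

# Loss normalized by predicted tokens only; per the diff's comment, the
# closest analogue to train_loss in the "mean" reduction case.
avg_loss_per_pred_tok = sum_loss / pred_tokens_this_log_period

# Summed loss averaged over optimizer steps across all processes.
total_optim_steps = logging_steps * num_processes
avg_sum_loss = sum_loss / total_optim_steps

print(f"train_loss_per_total_tok={avg_loss:.4f}")
print(f"train_loss_per_pred_tok={avg_loss_per_pred_tok:.4f}")
print(f"train_sum_loss={avg_sum_loss:.4f}")
```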