Skip to content

Commit 1e87865

Browse files
Added perplexity computation.
1 parent 2281ce5 commit 1e87865

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

train.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# @ Description:
77
"""
88

9-
import os
9+
import argparse
1010

1111
import torch
1212
from torch import GradScaler, autocast
@@ -105,20 +105,24 @@ def run(model_type):
105105
labels = labels.view(-1).to(torch.long)
106106

107107
# We would take mean across all sequence length and all batches.
108-
loss = loss_fn(logits, labels) * batch_size
108+
loss = loss_fn(logits, labels)
109+
ppl = torch.exp(loss)
109110
scaler.scale(loss).backward()
110111
scaler.step(optimizer)
111112
scaler.update()
112113

113114
lr = lr_scheduler.step(g_step, optimizer)
114115

115116
logger.info(
116-
f"Epoch: {eps_num+1}/{train_config.num_epochs}, Batch: {batch_idx}/{len(train_loader)}, Batch Size: {batch_size}, Loss: {loss:.4f}, LR: {lr:.4f}"
117+
f"Epoch: {eps_num+1}/{train_config.num_epochs}, Batch: {batch_idx}/{len(train_loader)}, "
118+
f"Batch Size: {batch_size}, Loss: {loss:.4f}, "
119+
f"PPL: {ppl:.4f}, LR: {lr:.4f}"
117120
)
118121
metrics = {
119122
"Epoch": eps_num + 1,
120123
"Batch": batch_idx + 1,
121124
"Loss": loss,
125+
"Perplexity": ppl,
122126
"LR": lr,
123127
}
124128
if train_config.use_wandb:
@@ -127,6 +131,7 @@ def run(model_type):
127131

128132
model.eval()
129133
total_eval_loss = 0
134+
total_eval_ppl = 0
130135
with torch.no_grad():
131136
for input_ids, attn_mask, labels in tqdm(valid_loader):
132137
input_ids = input_ids.to(cuda, non_blocking=True)
@@ -140,11 +145,17 @@ def run(model_type):
140145
labels = labels.view(-1).to(torch.long)
141146

142147
# We would take mean across all sequence length and all batches.
143-
loss = loss_fn(logits, labels) * batch_size
148+
loss = loss_fn(logits, labels)
149+
ppl = torch.exp(loss)
144150
total_eval_loss += loss.item()
151+
total_eval_ppl += ppl.item()
145152

146153
avg_eval_loss = total_eval_loss / len(valid_loader)
147-
logger.info(f"Epoch {eps_num+1}, Evaluation Loss: {avg_eval_loss:.4f}")
154+
avg_eval_ppl = total_eval_ppl / len(valid_loader)
155+
logger.info(
156+
f"Epoch {eps_num+1}, Evaluation Loss: {avg_eval_loss:.4f}, "
157+
f"Evaluation Perplexity: {avg_eval_ppl:.4f}"
158+
)
148159
if train_config.use_wandb:
149160
metrics = {"Test Loss": loss}
150161
wandb.log(metrics, step=g_step)
@@ -156,6 +167,7 @@ def run(model_type):
156167
"epoch": eps_num,
157168
"global_step": g_step,
158169
"test_loss": avg_eval_loss,
170+
"test_ppl": avg_eval_ppl,
159171
"model": model.state_dict(),
160172
"optimizer": optimizer.state_dict(),
161173
"scaler": (
@@ -169,4 +181,9 @@ def run(model_type):
169181

170182

171183
if __name__ == "__main__":
172-
run(model_type="gpt")
184+
parser = argparse.ArgumentParser(description="TinyLLM Training help")
185+
parser.add_argument(
186+
"-m", "--model_type", type=str, help="Type of the model", required=True
187+
)
188+
args = parser.parse_args()
189+
run(model_type=args.model_type)

0 commit comments

Comments
 (0)