@@ -6,7 +6,7 @@
 # @ Description:
 """
 
-import os
+import argparse
 
 import torch
 from torch import GradScaler, autocast
@@ -105,20 +105,24 @@ def run(model_type):
             labels = labels.view(-1).to(torch.long)
 
             # We take the mean across all sequence positions and all batches.
-            loss = loss_fn(logits, labels) * batch_size
+            loss = loss_fn(logits, labels)
+            ppl = torch.exp(loss)
 
             scaler.scale(loss).backward()
             scaler.step(optimizer)
             scaler.update()
 
             lr = lr_scheduler.step(g_step, optimizer)
 
             logger.info(
-                f"Epoch: {eps_num + 1}/{train_config.num_epochs}, Batch: {batch_idx}/{len(train_loader)}, Batch Size: {batch_size}, Loss: {loss:.4f}, LR: {lr:.4f}"
+                f"Epoch: {eps_num + 1}/{train_config.num_epochs}, Batch: {batch_idx}/{len(train_loader)}, "
+                f"Batch Size: {batch_size}, Loss: {loss:.4f}, "
+                f"PPL: {ppl:.4f}, LR: {lr:.4f}"
             )
             metrics = {
                 "Epoch": eps_num + 1,
                 "Batch": batch_idx + 1,
                 "Loss": loss,
+                "Perplexity": ppl,
                 "LR": lr,
             }
             if train_config.use_wandb:
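Dropping the `* batch_size` factor is the right call: with its default `reduction="mean"`, `nn.CrossEntropyLoss` already averages over every token in the flattened `(batch * seq_len,)` view, so the old scaling only inflated the logged value. Perplexity then falls out as the exponential of that mean cross-entropy. A minimal sketch of the relationship, with illustrative shapes (the real `loss_fn` and dimensions come from the repo's config, not shown in this hunk):

```python
import torch
from torch import nn

# Default reduction="mean": one scalar averaged over all B*T token positions.
loss_fn = nn.CrossEntropyLoss()

batch_size, seq_len, vocab_size = 4, 8, 100             # illustrative only
logits = torch.randn(batch_size * seq_len, vocab_size)  # flattened, as in the diff
labels = torch.randint(0, vocab_size, (batch_size * seq_len,))

loss = loss_fn(logits, labels)  # mean negative log-likelihood per token
ppl = torch.exp(loss)           # perplexity = exp(mean cross-entropy)
print(f"Loss: {loss:.4f}, PPL: {ppl:.4f}")
```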
@@ -127,6 +131,7 @@ def run(model_type):
 
         model.eval()
         total_eval_loss = 0
+        total_eval_ppl = 0
         with torch.no_grad():
             for input_ids, attn_mask, labels in tqdm(valid_loader):
                 input_ids = input_ids.to(cuda, non_blocking=True)
@@ -140,11 +145,17 @@ def run(model_type):
                 labels = labels.view(-1).to(torch.long)
 
                 # We take the mean across all sequence positions and all batches.
-                loss = loss_fn(logits, labels) * batch_size
+                loss = loss_fn(logits, labels)
+                ppl = torch.exp(loss)
                 total_eval_loss += loss.item()
+                total_eval_ppl += ppl.item()
 
         avg_eval_loss = total_eval_loss / len(valid_loader)
-        logger.info(f"Epoch {eps_num + 1}, Evaluation Loss: {avg_eval_loss:.4f}")
+        avg_eval_ppl = total_eval_ppl / len(valid_loader)
+        logger.info(
+            f"Epoch {eps_num + 1}, Evaluation Loss: {avg_eval_loss:.4f}, "
+            f"Evaluation Perplexity: {avg_eval_ppl:.4f}"
+        )
         if train_config.use_wandb:
             metrics = {"Test Loss": loss}
             wandb.log(metrics, step=g_step)
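One design note on the averaging: accumulating per-batch perplexities and dividing by `len(valid_loader)` gives a mean of exponentials, which by Jensen's inequality is at least `exp(avg_eval_loss)`, the corpus-level perplexity usually reported. The gap is small when batch losses are similar, but the two are not interchangeable. A quick sketch with made-up batch losses:

```python
import torch

# Hypothetical per-batch eval losses (mean cross-entropy per token).
batch_losses = torch.tensor([2.0, 3.5, 1.2])

mean_of_ppls = torch.exp(batch_losses).mean()  # what this diff accumulates
ppl_of_mean = torch.exp(batch_losses.mean())   # exp(avg loss), corpus-style
print(f"{mean_of_ppls:.2f} vs {ppl_of_mean:.2f}")  # ~14.61 vs ~9.33
```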
@@ -156,6 +167,7 @@ def run(model_type):
             "epoch": eps_num,
             "global_step": g_step,
             "test_loss": avg_eval_loss,
+            "test_ppl": avg_eval_ppl,
             "model": model.state_dict(),
             "optimizer": optimizer.state_dict(),
             "scaler": (
@@ -169,4 +181,9 @@ def run(model_type):
 
 
 if __name__ == "__main__":
-    run(model_type="gpt")
+    parser = argparse.ArgumentParser(description="TinyLLM Training help")
+    parser.add_argument(
+        "-m", "--model_type", type=str, help="Type of the model", required=True
+    )
+    args = parser.parse_args()
+    run(model_type=args.model_type)
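With the hard-coded `run(model_type="gpt")` replaced by argparse, the model type is now chosen at launch, e.g. `python train.py -m gpt` or `python train.py --model_type gpt` (the filename `train.py` is assumed for illustration; this hunk doesn't show it). Omitting the flag fails fast with an argparse usage error instead of silently training the default model.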