@@ -207,7 +207,7 @@ def _load_jailbreak_dataset(self):
             return samples
         except Exception as e:
             import traceback; logger.error(f"Failed to load jailbreak dataset: {e}"); traceback.print_exc(); return []
-
+
     def _save_checkpoint(self, epoch, global_step, optimizer, scheduler, checkpoint_dir, label_mappings):
         os.makedirs(checkpoint_dir, exist_ok=True)
         latest_checkpoint_path = os.path.join(checkpoint_dir, 'latest_checkpoint.pt')
@@ -223,7 +223,7 @@ def _save_checkpoint(self, epoch, global_step, optimizer, scheduler, checkpoint_
         torch.save(state, latest_checkpoint_path)
         logger.info(f"Checkpoint saved for step {global_step} at {latest_checkpoint_path}")
 
-    def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_size=16, learning_rate=2e-5, 
+    def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_size=16, learning_rate=2e-5,
               checkpoint_dir='checkpoints', resume=False, save_steps=500, checkpoint_to_load=None):
         train_dataset = MultitaskDataset(train_samples, self.tokenizer)
         train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
@@ -252,7 +252,7 @@ def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_
             pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}")
             for step, batch in pbar:
                 if steps_to_skip > 0 and step < steps_to_skip: continue
-
+
                 optimizer.zero_grad()
                 outputs = self.model(
                     input_ids=batch["input_ids"].to(self.device),
@@ -269,7 +269,7 @@ def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_
                 optimizer.step(); scheduler.step(); global_step += 1
                 if global_step > 0 and global_step % save_steps == 0:
                     self._save_checkpoint(epoch, global_step, optimizer, scheduler, checkpoint_dir, label_mappings)
-
+
             if val_loader: self.evaluate(val_loader)
             self._save_checkpoint(epoch + 1, global_step, optimizer, scheduler, checkpoint_dir, label_mappings)
             steps_to_skip = 0
@@ -325,13 +325,13 @@ def main():
     logger.info("--- Starting Model Training ---")
 
     task_configs, label_mappings, checkpoint_to_load = {}, {}, None
-
+
     if args.resume:
         latest_checkpoint_path = os.path.join(args.checkpoint_dir, 'latest_checkpoint.pt')
         if os.path.exists(latest_checkpoint_path):
             logger.info(f"Resuming training from checkpoint: {latest_checkpoint_path}")
             checkpoint_to_load = torch.load(latest_checkpoint_path, map_location=device)
-
+
             task_configs = checkpoint_to_load.get('task_configs')
             label_mappings = checkpoint_to_load.get('label_mappings')
 
@@ -340,10 +340,10 @@ def main():
                 logger.warning("Cannot safely resume. Starting a fresh training run.")
                 args.resume = False
                 checkpoint_to_load = None
-                task_configs = {} 
+                task_configs = {}
             else:
                 logger.info("Loaded model configuration from checkpoint.")
-
+
         else:
             logger.warning(f"Resume flag is set, but no checkpoint found in '{args.checkpoint_dir}'. Starting fresh run.")
             args.resume = False
@@ -358,12 +358,12 @@ def main():
         task_configs["pii"] = {"num_classes": len(label_mappings["pii"]["label_mapping"]["label_to_idx"]), "weight": 3.0}
     if "jailbreak" in label_mappings:
         task_configs["jailbreak"] = {"num_classes": len(label_mappings["jailbreak"]["label_mapping"]["label_to_idx"]), "weight": 2.0}
-
+
     if not task_configs:
         logger.error("No tasks configured. Exiting."); return
 
     logger.info(f"Final task configurations: {task_configs}")
-
+
     model = MultitaskBertModel(base_model_name, task_configs)
 
     if args.resume and checkpoint_to_load:
@@ -379,9 +379,9 @@ def main():
     active_label_mappings = label_mappings if (args.resume and label_mappings) else final_label_mappings
 
     trainer = MultitaskTrainer(model, tokenizer, task_configs, device)
-
+
     logger.info(f"Total training samples: {len(train_samples)}")
-
+
     trainer.train(
         train_samples, val_samples, active_label_mappings,
         num_epochs=10, batch_size=16,
@@ -390,12 +390,11 @@ def main():
         save_steps=args.save_steps,
         checkpoint_to_load=checkpoint_to_load
     )
-
+
     trainer.save_model(output_path)
     with open(os.path.join(output_path, "label_mappings.json"), "w") as f:
         json.dump(active_label_mappings, f, indent=2)
     logger.info("Multitask training completed!")
 
 if __name__ == "__main__":
     main()
-
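For reference, a minimal sketch of inspecting the checkpoint that _save_checkpoint writes and --resume reads back. The 'task_configs' and 'label_mappings' keys are the ones main() relies on above; the 'epoch' and 'global_step' keys are assumed here for illustration and may differ from the trainer's actual state dict.

import torch

# Hypothetical inspection of the latest checkpoint written by _save_checkpoint().
# main() reads 'task_configs' and 'label_mappings' back when --resume is set;
# 'epoch' and 'global_step' are assumed keys, shown only for illustration.
checkpoint = torch.load("checkpoints/latest_checkpoint.pt", map_location="cpu")

print("task configs:   ", checkpoint.get("task_configs"))
print("label mappings: ", list(checkpoint.get("label_mappings", {}).keys()))
print("resume position:", checkpoint.get("epoch"), checkpoint.get("global_step"))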