
Commit e8969ce

chore: fix pre-commit failures in #353
Signed-off-by: Huamin Chen <[email protected]>
1 parent 7e3c54d commit e8969ce
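
Note: the diff below appears to be whitespace-only. Each -/+ pair removes trailing spaces from an otherwise unchanged line, and the final deletion drops a blank line at end-of-file, which is the kind of normalization pre-commit's standard hygiene hooks (trailing-whitespace, end-of-file-fixer) apply. A rough Python equivalent of that cleanup, for illustration only, not the actual hook code:

    from pathlib import Path

    def normalize_whitespace(path: str) -> None:
        # Strip trailing whitespace from each line, drop trailing blank
        # lines, and end the file with exactly one newline, matching the
        # effect of the whitespace-only hunks in this commit.
        p = Path(path)
        lines = p.read_text().splitlines()
        text = "\n".join(line.rstrip() for line in lines).rstrip("\n") + "\n"
        p.write_text(text)

    normalize_whitespace("src/training/multitask_wiki_classifier_training.py")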

File tree

1 file changed: +13 -14 lines changed

src/training/multitask_wiki_classifier_training.py

Lines changed: 13 additions & 14 deletions
@@ -207,7 +207,7 @@ def _load_jailbreak_dataset(self):
             return samples
         except Exception as e:
             import traceback; logger.error(f"Failed to load jailbreak dataset: {e}"); traceback.print_exc(); return []
-
+
     def _save_checkpoint(self, epoch, global_step, optimizer, scheduler, checkpoint_dir, label_mappings):
         os.makedirs(checkpoint_dir, exist_ok=True)
         latest_checkpoint_path = os.path.join(checkpoint_dir, 'latest_checkpoint.pt')
@@ -223,7 +223,7 @@ def _save_checkpoint(self, epoch, global_step, optimizer, scheduler, checkpoint_
         torch.save(state, latest_checkpoint_path)
         logger.info(f"Checkpoint saved for step {global_step} at {latest_checkpoint_path}")
 
-    def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_size=16, learning_rate=2e-5,
+    def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_size=16, learning_rate=2e-5,
               checkpoint_dir='checkpoints', resume=False, save_steps=500, checkpoint_to_load=None):
         train_dataset = MultitaskDataset(train_samples, self.tokenizer)
         train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
@@ -252,7 +252,7 @@ def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_
             pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}")
             for step, batch in pbar:
                 if steps_to_skip > 0 and step < steps_to_skip: continue
-
+
                 optimizer.zero_grad()
                 outputs = self.model(
                     input_ids=batch["input_ids"].to(self.device),
@@ -269,7 +269,7 @@ def train(self, train_samples, val_samples, label_mappings, num_epochs=3, batch_
                 optimizer.step(); scheduler.step(); global_step += 1
                 if global_step > 0 and global_step % save_steps == 0:
                     self._save_checkpoint(epoch, global_step, optimizer, scheduler, checkpoint_dir, label_mappings)
-
+
             if val_loader: self.evaluate(val_loader)
             self._save_checkpoint(epoch + 1, global_step, optimizer, scheduler, checkpoint_dir, label_mappings)
             steps_to_skip = 0
@@ -325,13 +325,13 @@ def main():
     logger.info("--- Starting Model Training ---")
 
     task_configs, label_mappings, checkpoint_to_load = {}, {}, None
-
+
     if args.resume:
         latest_checkpoint_path = os.path.join(args.checkpoint_dir, 'latest_checkpoint.pt')
         if os.path.exists(latest_checkpoint_path):
             logger.info(f"Resuming training from checkpoint: {latest_checkpoint_path}")
             checkpoint_to_load = torch.load(latest_checkpoint_path, map_location=device)
-
+
             task_configs = checkpoint_to_load.get('task_configs')
             label_mappings = checkpoint_to_load.get('label_mappings')
 
@@ -340,10 +340,10 @@ def main():
                 logger.warning("Cannot safely resume. Starting a fresh training run.")
                 args.resume = False
                 checkpoint_to_load = None
-                task_configs = {}
+                task_configs = {}
             else:
                 logger.info("Loaded model configuration from checkpoint.")
-
+
         else:
             logger.warning(f"Resume flag is set, but no checkpoint found in '{args.checkpoint_dir}'. Starting fresh run.")
             args.resume = False
@@ -358,12 +358,12 @@ def main():
         task_configs["pii"] = {"num_classes": len(label_mappings["pii"]["label_mapping"]["label_to_idx"]), "weight": 3.0}
     if "jailbreak" in label_mappings:
         task_configs["jailbreak"] = {"num_classes": len(label_mappings["jailbreak"]["label_mapping"]["label_to_idx"]), "weight": 2.0}
-
+
     if not task_configs:
         logger.error("No tasks configured. Exiting."); return
 
     logger.info(f"Final task configurations: {task_configs}")
-
+
     model = MultitaskBertModel(base_model_name, task_configs)
 
     if args.resume and checkpoint_to_load:
@@ -379,9 +379,9 @@ def main():
     active_label_mappings = label_mappings if (args.resume and label_mappings) else final_label_mappings
 
     trainer = MultitaskTrainer(model, tokenizer, task_configs, device)
-
+
     logger.info(f"Total training samples: {len(train_samples)}")
-
+
     trainer.train(
         train_samples, val_samples, active_label_mappings,
         num_epochs=10, batch_size=16,
@@ -390,12 +390,11 @@ def main():
         save_steps=args.save_steps,
         checkpoint_to_load=checkpoint_to_load
     )
-
+
     trainer.save_model(output_path)
     with open(os.path.join(output_path, "label_mappings.json"), "w") as f:
         json.dump(active_label_mappings, f, indent=2)
     logger.info("Multitask training completed!")
 
 if __name__ == "__main__":
     main()
-
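Aside: the checkpoint logic this diff brushes past is the standard torch.save/torch.load round trip. A minimal, self-contained sketch of that pattern is below; the 'task_configs' and 'label_mappings' keys mirror what main() reads back via checkpoint_to_load.get(...), while the remaining key names and the helper functions are hypothetical simplifications, not the project's actual code:

    import os
    import torch

    def save_checkpoint(path, model, optimizer, scheduler, epoch, global_step,
                        task_configs, label_mappings):
        # Persist everything needed to resume a run: model/optimizer/scheduler
        # state plus the task metadata that main() restores on --resume.
        torch.save({
            "model_state_dict": model.state_dict(),          # hypothetical key
            "optimizer_state_dict": optimizer.state_dict(),  # hypothetical key
            "scheduler_state_dict": scheduler.state_dict(),  # hypothetical key
            "epoch": epoch,
            "global_step": global_step,
            "task_configs": task_configs,        # key read back in the diff
            "label_mappings": label_mappings,    # key read back in the diff
        }, path)

    def load_checkpoint(path, device):
        # map_location=device matches the resume branch shown above.
        if not os.path.exists(path):
            return None
        return torch.load(path, map_location=device)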
