Commit
Allow different LRs for parameter groups
Ming Du committed Jun 6, 2024
1 parent cf8ed6f commit 45ff69e
Showing 2 changed files with 35 additions and 9 deletions.
generic_trainer/configs.py (15 changes: 12 additions & 3 deletions)
@@ -1,7 +1,7 @@
 import copy
 import collections
 import dataclasses
-from typing import Any, Callable, Optional, Union, Tuple
+from typing import *
 import json
 import os
 import re
@@ -322,12 +322,21 @@ class TrainingConfig(Config):
     or processes.
     """
 
-    optimizer: Any = torch.optim.Adam
-    """String of optimizer name or the handle of a subclass of torch.optim.Optimizer"""
+    optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
+    """The optimizer class. Should be given as the handle of a subclass of torch.optim.Optimizer."""
 
     optimizer_params: dict = dataclasses.field(default_factory=dict)
     """Optimizer parameters."""
 
+    multi_optimizer_param_dicts: Optional[Sequence[Dict]] = None
+    """
+    If provided, the optimizer uses different learning rates for different parameter
+    groups. It should be a list of dictionaries as described in
+    https://pytorch.org/docs/stable/optim.html#per-parameter-options. However, the
+    value of each "params" key should be given as a string of code that retrieves the
+    trainable parameters, in which the model object is referenced as `self.get_model_object()`.
+    """
+
     model_save_dir: str = '.'
     """Directory to save trained models."""
 
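For illustration, a configuration using the new field might look like the sketch below. It is not taken from the repository: the encoder/decoder attribute names and learning-rate values are assumptions, and any other TrainingConfig settings the project requires are omitted. Each "params" entry is a string of code that the trainer evaluates later, with the model reachable through self.get_model_object().

```python
# Hypothetical sketch of the new config field; attribute names and values are
# illustrative, and other TrainingConfig settings are omitted.
import torch
from generic_trainer.configs import TrainingConfig

config = TrainingConfig(
    optimizer=torch.optim.Adam,
    multi_optimizer_param_dicts=[
        # Each "params" value is a string of code that the trainer evaluates,
        # with the model reachable through self.get_model_object().
        {'params': 'self.get_model_object().encoder.parameters()', 'lr': 1e-4},
        # A group without an explicit 'lr' falls back to the global learning rate.
        {'params': 'self.get_model_object().decoder.parameters()'},
    ],
)
```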
generic_trainer/trainer.py (29 changes: 23 additions & 6 deletions)
@@ -814,19 +814,36 @@ def build_optimizer(self):
"one.".format(self.model.__class__))
else:
trainable_params = self.model.parameters()
if isinstance(self.configs.optimizer, str):
if self.configs.optimizer == 'adam':
self.optimizer = torch.optim.Adam(trainable_params, lr=self.learning_rate)
else:
if self.configs.multi_optimizer_param_dicts is None:
self.optimizer = self.configs.optimizer(trainable_params, lr=self.learning_rate,
**self.configs.optimizer_params)
else:
# Construct per-parameter dicts
perparam_dicts = []
for i, d in enumerate(self.configs.multi_optimizer_param_dicts):
d_copy = d.copy()
d_copy['params'] = eval(d['params'])
if 'lr' in d_copy.keys():
d_copy['lr'] = d_copy['lr'] * self.num_processes
perparam_dicts.append(d_copy)
self.optimizer = self.configs.optimizer(perparam_dicts, lr=self.learning_rate,
**self.configs.optimizer_params)

     def build_scheduler(self):
         self.iterations_per_epoch = len(self.training_dataset) / self.all_proc_batch_size
         self.iterations_per_epoch = np.ceil(self.iterations_per_epoch)
         step_size = 6 * self.iterations_per_epoch
-        self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=self.learning_rate / 10,
-                                                           max_lr=self.learning_rate, step_size_up=step_size,
+        if self.configs.multi_optimizer_param_dicts is None:
+            base_lr = self.learning_rate * 0.1
+            max_lr = self.learning_rate
+        else:
+            base_lr = []
+            max_lr = []
+            for d in self.optimizer.param_groups:
+                base_lr.append(d['lr'] * 0.1)
+                max_lr.append(self.learning_rate)
+        self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=base_lr,
+                                                           max_lr=max_lr, step_size_up=step_size,
                                                            cycle_momentum=False, mode='triangular2')
 
     def build_amp(self):
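As a rough standalone illustration of what the new branch in build_optimizer does, the sketch below resolves the "params" strings with eval, scales each explicit per-group learning rate by the process count, and passes the resulting dicts to the optimizer class. The two-layer model, the group layout, and num_processes are assumptions made for the example; inside the trainer they correspond to the configured model and self.num_processes.

```python
# Standalone sketch (not the trainer itself): resolve per-parameter dicts whose
# "params" entries are code strings, scale per-group LRs by the process count,
# and hand the result to the optimizer class. Model and values are illustrative.
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.Linear(16, 2))
num_processes = 4          # stands in for self.num_processes
global_lr = 1e-3           # stands in for self.learning_rate

multi_optimizer_param_dicts = [
    {'params': 'model[0].parameters()', 'lr': 1e-4},  # code string, as in the config field
    {'params': 'model[1].parameters()'},              # no 'lr': uses the global learning rate
]

perparam_dicts = []
for d in multi_optimizer_param_dicts:
    d_copy = d.copy()
    d_copy['params'] = eval(d['params'])              # turn the string into an iterable of parameters
    if 'lr' in d_copy:
        d_copy['lr'] = d_copy['lr'] * num_processes   # linear scaling of the per-group LR
    perparam_dicts.append(d_copy)

optimizer = torch.optim.Adam(perparam_dicts, lr=global_lr)
print([g['lr'] for g in optimizer.param_groups])      # e.g. [0.0004, 0.001]
```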

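Continuing that sketch (it reuses optimizer and global_lr from above), the scheduler side mirrors the new build_scheduler branch: CyclicLR accepts one base_lr/max_lr entry per parameter group, so each group cycles between a tenth of its own starting rate and the shared maximum. The step_size_up value is an arbitrary placeholder.

```python
# Continuation of the sketch above: per-group base/max learning rates for CyclicLR.
base_lr = [g['lr'] * 0.1 for g in optimizer.param_groups]   # each group bottoms out at 10% of its own LR
max_lr = [global_lr for _ in optimizer.param_groups]        # every group shares the global ceiling
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=100,
    cycle_momentum=False, mode='triangular2')
```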