Support for complex LR scheduler configuration (#125)
* Support for complex LR scheduler configuration

* Pass in optimizer only if cls exists

* Done

* Template config change

* Reformat
hrukalive authored Aug 4, 2023
1 parent 3bd250d commit 7f4f515
Showing 5 changed files with 61 additions and 35 deletions.
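The core of the change is the new build_lr_scheduler_from_config helper in utils/__init__.py (diff below): any dict inside lr_scheduler_args that carries a cls key is recursively instantiated as an LR scheduler bound to the same optimizer, so compound schedulers can be described entirely in configuration. As a hedged sketch only, this is the kind of nested configuration the helper can resolve, written as the Python dict it receives from hparams['lr_scheduler_args']; the scheduler choices and numbers are illustrative, not taken from the repository configs, and they assume the repository's filter_kwargs drops keys the target scheduler class does not accept:

lr_scheduler_args = {
    'scheduler_cls': 'torch.optim.lr_scheduler.SequentialLR',
    # Child entries use 'cls'; the helper injects the shared optimizer into each one.
    'schedulers': [
        {'cls': 'torch.optim.lr_scheduler.LinearLR', 'start_factor': 0.1, 'total_iters': 2000},
        {'cls': 'torch.optim.lr_scheduler.StepLR', 'step_size': 50000, 'gamma': 0.5},
    ],
    'milestones': [2000],
}

Note that the helper explicitly rejects a ChainedScheduler entry nested under a top-level SequentialLR.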
38 changes: 15 additions & 23 deletions basics/base_task.py
@@ -289,27 +289,24 @@ def on_validation_epoch_end(self):

     # noinspection PyMethodMayBeStatic
     def build_scheduler(self, optimizer):
-        from utils import build_object_from_config
+        from utils import build_lr_scheduler_from_config

         scheduler_args = hparams['lr_scheduler_args']
         assert scheduler_args['scheduler_cls'] != ''
-        scheduler = build_object_from_config(
-            scheduler_args['scheduler_cls'],
-            optimizer,
-            **scheduler_args
-        )
+        scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args)
         return scheduler

     # noinspection PyMethodMayBeStatic
     def build_optimizer(self, model):
-        from utils import build_object_from_config
+        from utils import build_object_from_class_name

         optimizer_args = hparams['optimizer_args']
         assert optimizer_args['optimizer_cls'] != ''
         if 'beta1' in optimizer_args and 'beta2' in optimizer_args and 'betas' not in optimizer_args:
             optimizer_args['betas'] = (optimizer_args['beta1'], optimizer_args['beta2'])
-        optimizer = build_object_from_config(
+        optimizer = build_object_from_class_name(
             optimizer_args['optimizer_cls'],
+            torch.optim.Optimizer,
             model.parameters(),
             **optimizer_args
         )
@@ -497,27 +494,22 @@ def on_load_checkpoint(self, checkpoint):
                         param_group[k] = v
                 if 'initial_lr' in param_group and param_group['initial_lr'] != optimizer_args['lr']:
                     rank_zero_info(
-                        f'| Overriding optimizer parameter initial_lr from checkpoint: {param_group["initial_lr"]} -> {optimizer_args["lr"]}')
+                        f'| Overriding optimizer parameter initial_lr from checkpoint: {param_group["initial_lr"]} -> {optimizer_args["lr"]}'
+                    )
                     param_group['initial_lr'] = optimizer_args['lr']

         if checkpoint.get('lr_schedulers', None):
             assert checkpoint.get('optimizer_states', False)
-            schedulers = checkpoint['lr_schedulers']
-            assert len(schedulers) == 1  # only support one scheduler
-            scheduler = schedulers[0]
-            for k, v in scheduler_args.items():
-                if k in scheduler and scheduler[k] != v:
-                    rank_zero_info(f'| Overriding scheduler parameter {k} from checkpoint: {scheduler[k]} -> {v}')
-                    scheduler[k] = v
-            scheduler['base_lrs'] = [group['initial_lr'] for group in checkpoint['optimizer_states'][0]['param_groups']]
-            new_lrs = simulate_lr_scheduler(
+            assert len(checkpoint['lr_schedulers']) == 1  # only support one scheduler
+            checkpoint['lr_schedulers'][0] = simulate_lr_scheduler(
                 optimizer_args, scheduler_args,
-                last_epoch=scheduler['last_epoch'],
+                step_count=checkpoint['global_step'],
                 num_param_groups=len(checkpoint['optimizer_states'][0]['param_groups'])
             )
-            scheduler['_last_lr'] = new_lrs
-            for param_group, new_lr in zip(checkpoint['optimizer_states'][0]['param_groups'], new_lrs):
+            for param_group, new_lr in zip(
+                checkpoint['optimizer_states'][0]['param_groups'],
+                checkpoint['lr_schedulers'][0]['_last_lr'],
+            ):
                 if param_group['lr'] != new_lr:
-                    rank_zero_info(
-                        f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}')
+                    rank_zero_info(f'| Overriding optimizer parameter lr from checkpoint: {param_group["lr"]} -> {new_lr}')
                     param_group['lr'] = new_lr
1 change: 0 additions & 1 deletion configs/base.yaml
@@ -55,7 +55,6 @@ optimizer_args:
   weight_decay: 0
 lr_scheduler_args:
   scheduler_cls: torch.optim.lr_scheduler.StepLR
-  warmup_steps: 2000
   step_size: 50000
   gamma: 0.5
 clip_grad_norm: 1
1 change: 1 addition & 0 deletions configs/templates/config_variance.yaml
@@ -67,6 +67,7 @@ lambda_var_loss: 1.0
 optimizer_args:
   lr: 0.0006
 lr_scheduler_args:
+  scheduler_cls: torch.optim.lr_scheduler.StepLR
   step_size: 12000
   gamma: 0.75
 max_batch_frames: 80000
52 changes: 42 additions & 10 deletions utils/__init__.py
@@ -260,29 +260,61 @@ def num_params(model, print_out=True, model_name="model"):
     return parameters


-def build_object_from_config(cls_str, *args, **kwargs):
+def build_object_from_class_name(cls_str, parent_cls, *args, **kwargs):
     import importlib

     pkg = ".".join(cls_str.split(".")[:-1])
     cls_name = cls_str.split(".")[-1]
     cls_type = getattr(importlib.import_module(pkg), cls_name)
+    if parent_cls is not None:
+        assert issubclass(cls_type, parent_cls), f'| {cls_type} is not subclass of {parent_cls}.'

     return cls_type(*args, **filter_kwargs(kwargs, cls_type))


-def simulate_lr_scheduler(optimizer_args, scheduler_args, last_epoch=-1, num_param_groups=1):
-    optimizer = build_object_from_config(
+def build_lr_scheduler_from_config(optimizer, scheduler_args):
+    def helper(params):
+        if isinstance(params, list):
+            return [helper(s) for s in params]
+        elif isinstance(params, dict):
+            resolved = {k: helper(v) for k, v in params.items()}
+            if 'cls' in resolved:
+                if (
+                    resolved["cls"] == "torch.optim.lr_scheduler.ChainedScheduler"
+                    and scheduler_args["scheduler_cls"] == "torch.optim.lr_scheduler.SequentialLR"
+                ):
+                    raise ValueError(f"ChainedScheduler cannot be part of a SequentialLR.")
+                resolved['optimizer'] = optimizer
+                obj = build_object_from_class_name(
+                    resolved['cls'],
+                    torch.optim.lr_scheduler.LRScheduler,
+                    **resolved
+                )
+                return obj
+            return resolved
+        else:
+            return params
+    resolved = helper(scheduler_args)
+    resolved['optimizer'] = optimizer
+    return build_object_from_class_name(
+        scheduler_args['scheduler_cls'],
+        torch.optim.lr_scheduler.LRScheduler,
+        **resolved
+    )
+
+
+def simulate_lr_scheduler(optimizer_args, scheduler_args, step_count, num_param_groups=1):
+    optimizer = build_object_from_class_name(
         optimizer_args['optimizer_cls'],
+        torch.optim.Optimizer,
         [{'params': torch.nn.Parameter(), 'initial_lr': optimizer_args['lr']} for _ in range(num_param_groups)],
         **optimizer_args
     )
-    scheduler = build_object_from_config(scheduler_args['scheduler_cls'], optimizer, last_epoch=last_epoch,
-                                         **scheduler_args)
-
-    if hasattr(scheduler, '_get_closed_form_lr'):
-        return scheduler._get_closed_form_lr()
-    else:
-        return scheduler.get_lr()
+    scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args)
+    scheduler.optimizer._step_count = 1
+    for _ in range(step_count):
+        scheduler.step()
+    return scheduler.state_dict()


 def remove_suffix(string: str, suffix: str):
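A minimal usage sketch of the two new helpers, assuming a plain StepLR configuration; the model, hyperparameter values, and the 30000-step fast-forward are illustrative, and the calls rely on the repository's filter_kwargs discarding keys such as optimizer_cls/scheduler_cls that the target classes do not accept (the same way build_optimizer already does):

import torch
from utils import build_object_from_class_name, build_lr_scheduler_from_config, simulate_lr_scheduler

optimizer_args = {'optimizer_cls': 'torch.optim.AdamW', 'lr': 0.0006}
scheduler_args = {'scheduler_cls': 'torch.optim.lr_scheduler.StepLR', 'step_size': 12000, 'gamma': 0.75}

model = torch.nn.Linear(4, 4)  # stand-in model, illustrative only
optimizer = build_object_from_class_name(
    optimizer_args['optimizer_cls'], torch.optim.Optimizer,
    model.parameters(), **optimizer_args
)
scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args)

# Fast-forward a throwaway scheduler to a given global step, e.g. to reconcile a
# checkpoint's scheduler state with edited hparams, as on_load_checkpoint now does.
state = simulate_lr_scheduler(optimizer_args, scheduler_args, step_count=30000)
print(state['_last_lr'])  # learning rate(s) the scheduler would report after 30000 steps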
4 changes: 3 additions & 1 deletion utils/training_utils.py
@@ -316,7 +316,9 @@ def get_metrics(self, trainer, model):
         items['steps'] = str(trainer.global_step)
         for k, v in items.items():
             if isinstance(v, float):
-                if 0.001 <= v < 10:
+                if np.isnan(v):
+                    items[k] = 'nan'
+                elif 0.001 <= v < 10:
                     items[k] = np.format_float_positional(v, unique=True, precision=5, trim='-')
                 elif 0.00001 <= v < 0.001:
                     if len(np.format_float_positional(v, unique=True, precision=8, trim='-')) > 8:
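The progress-bar fix guards against NaN metrics before the range-based formatting. An illustrative check (not from the repository) of why the explicit branch is needed: every comparison with NaN is False, so a NaN value would otherwise skip all range branches and fall through to whichever catch-all formatting follows.

import numpy as np

v = float('nan')
print(0.001 <= v < 10)        # False: range checks never match NaN
print(0.00001 <= v < 0.001)   # False as well
print(np.isnan(v))            # True: the new branch catches it and renders 'nan'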
