implement optional gradient clipping
hturki committed Feb 1, 2023
1 parent e979f89 commit bc3eed0
Showing 1 changed file with 14 additions and 3 deletions.
nerfstudio/engine/optimizers.py (17 changes: 14 additions & 3 deletions)
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Any, Dict, List, Type
+from typing import Any, Dict, List, Type, Optional
 
 import torch
 from torch.cuda.amp.grad_scaler import GradScaler
@@ -36,13 +36,15 @@ class OptimizerConfig(base_config.PrintableConfig):
     _target: Type = torch.optim.Adam
     lr: float = 0.0005
     eps: float = 1e-08
+    max_norm: Optional[float] = None
 
     # TODO: somehow make this more generic. i dont like the idea of overriding the setup function
     # but also not sure how to go about passing things into predefined torch objects.
     def setup(self, params) -> Any:
         """Returns the instantiated object using the config."""
         kwargs = vars(self).copy()
         kwargs.pop("_target")
+        kwargs.pop("max_norm")
         return self._target(params, **kwargs)
 
 
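The new max_norm field rides along on the config dataclass but is not an argument torch.optim.Adam accepts, which is why setup() now pops it (alongside _target) before instantiating the optimizer. A minimal sketch of using the config on its own; the values here are illustrative, not taken from the commit:

    import torch
    from nerfstudio.engine.optimizers import OptimizerConfig

    # A toy parameter list to optimize.
    params = [torch.nn.Parameter(torch.zeros(10))]

    # max_norm only marks the group for clipping; it is stripped from the kwargs
    # before the underlying Adam optimizer is constructed.
    config = OptimizerConfig(lr=1e-3, eps=1e-8, max_norm=1.0)
    optimizer = config.setup(params=params)  # equivalent to torch.optim.Adam(params, lr=1e-3, eps=1e-8)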
@@ -73,9 +75,11 @@ def __init__(self, config: Dict[str, Any], param_groups: Dict[str, List[Paramete
         self.config = config
         self.optimizers = {}
         self.schedulers = {}
+        self.parameters = {}
         for param_group_name, params in param_groups.items():
             lr_init = config[param_group_name]["optimizer"].lr
             self.optimizers[param_group_name] = config[param_group_name]["optimizer"].setup(params=params)
+            self.parameters[param_group_name] = params
             if config[param_group_name]["scheduler"]:
                 self.schedulers[param_group_name] = config[param_group_name]["scheduler"].setup(
                     optimizer=self.optimizers[param_group_name], lr_init=lr_init
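The constructor now also records each group's parameter list in self.parameters so the clipping calls below can reach the tensors that belong to a given optimizer. A rough sketch of the inputs this implies, inferred only from how config and param_groups are indexed above; the group name "fields" and the enclosing class name Optimizers are assumed here, not confirmed by the diff:

    import torch
    from torch.nn import Parameter
    from nerfstudio.engine.optimizers import OptimizerConfig, Optimizers

    # One named parameter group; the name is hypothetical.
    param_groups = {"fields": [Parameter(torch.zeros(8))]}

    # Per-group optimizer (with clipping enabled) and scheduler configs.
    config = {
        "fields": {
            "optimizer": OptimizerConfig(lr=1e-2, max_norm=0.1),  # clip this group's grad norm to 0.1
            "scheduler": None,  # falsy, so no scheduler is set up for this group
        }
    }
    optimizers = Optimizers(config, param_groups)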
@@ -109,13 +113,20 @@ def optimizer_scaler_step_all(self, grad_scaler: GradScaler) -> None:
         Args:
             grad_scaler: GradScaler to use
         """
-        for _, optimizer in self.optimizers.items():
+        for param_group, optimizer in self.optimizers.items():
+            max_norm = self.config[param_group]["optimizer"].max_norm
+            if max_norm is not None:
+                grad_scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(self.parameters[param_group], max_norm)
             grad_scaler.step(optimizer)
 
     def optimizer_step_all(self):
         """Run step for all optimizers."""
-        for _, optimizer in self.optimizers.items():
+        for param_group, optimizer in self.optimizers.items():
+            # note that the key is the parameter group name
+            max_norm = self.config[param_group]["optimizer"].max_norm
+            if max_norm is not None:
+                torch.nn.utils.clip_grad_norm_(self.parameters[param_group], max_norm)
             optimizer.step()
 
     def scheduler_step_all(self, step: int) -> None:
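In the grad-scaler path the gradients are unscaled before their norm is measured, matching the standard PyTorch AMP recipe: clipping scaled gradients would compare the threshold against values inflated by the loss scale. A standalone sketch of that recipe (the model, loss, and CUDA device are placeholders, not part of nerfstudio):

    import torch
    from torch.cuda.amp import GradScaler, autocast

    model = torch.nn.Linear(4, 1).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    scaler = GradScaler()
    max_norm = 1.0

    for _ in range(10):
        optimizer.zero_grad()
        with autocast():
            loss = model(torch.randn(16, 4, device="cuda")).pow(2).mean()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # bring the gradients back to their true scale
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)  # clip the unscaled gradients
        scaler.step(optimizer)  # step() skips the update if grads contain inf/nan
        scaler.update()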
