Commit 48ec36e

Add support for gradient clipping (#1331)
implement optional gradient clipping
1 parent 53a830a commit 48ec36e

File tree

1 file changed: +14 -3 lines changed

nerfstudio/engine/optimizers.py

+14 -3
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Any, Dict, List, Type
+from typing import Any, Dict, List, Optional, Type
 
 import torch
 from torch.cuda.amp.grad_scaler import GradScaler
@@ -36,13 +36,15 @@ class OptimizerConfig(base_config.PrintableConfig):
     _target: Type = torch.optim.Adam
     lr: float = 0.0005
     eps: float = 1e-08
+    max_norm: Optional[float] = None
 
     # TODO: somehow make this more generic. i dont like the idea of overriding the setup function
     # but also not sure how to go about passing things into predefined torch objects.
     def setup(self, params) -> Any:
         """Returns the instantiated object using the config."""
         kwargs = vars(self).copy()
         kwargs.pop("_target")
+        kwargs.pop("max_norm")
         return self._target(params, **kwargs)
 
 
@@ -73,9 +75,11 @@ def __init__(self, config: Dict[str, Any], param_groups: Dict[str, List[Paramete
         self.config = config
         self.optimizers = {}
         self.schedulers = {}
+        self.parameters = {}
         for param_group_name, params in param_groups.items():
             lr_init = config[param_group_name]["optimizer"].lr
             self.optimizers[param_group_name] = config[param_group_name]["optimizer"].setup(params=params)
+            self.parameters[param_group_name] = params
             if config[param_group_name]["scheduler"]:
                 self.schedulers[param_group_name] = config[param_group_name]["scheduler"].setup(
                     optimizer=self.optimizers[param_group_name], lr_init=lr_init
@@ -109,13 +113,20 @@ def optimizer_scaler_step_all(self, grad_scaler: GradScaler) -> None:
         Args:
             grad_scaler: GradScaler to use
         """
-        for _, optimizer in self.optimizers.items():
+        for param_group, optimizer in self.optimizers.items():
+            max_norm = self.config[param_group]["optimizer"].max_norm
+            if max_norm is not None:
+                grad_scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(self.parameters[param_group], max_norm)
             grad_scaler.step(optimizer)
 
     def optimizer_step_all(self):
         """Run step for all optimizers."""
-        for _, optimizer in self.optimizers.items():
+        for param_group, optimizer in self.optimizers.items():
             # note that they key is the parameter name
+            max_norm = self.config[param_group]["optimizer"].max_norm
+            if max_norm is not None:
+                torch.nn.utils.clip_grad_norm_(self.parameters[param_group], max_norm)
             optimizer.step()
 
     def scheduler_step_all(self, step: int) -> None:
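
With this change, gradient clipping is enabled per parameter group by setting max_norm on that group's optimizer config; leaving it at None keeps the previous behaviour. A minimal usage sketch of the new field, with a placeholder parameter list standing in for a model's param group (OptimizerConfig and its setup() are as shown in the diff above):

import torch
from nerfstudio.engine.optimizers import OptimizerConfig

# Placeholder parameters standing in for a real model's param group.
params = [torch.nn.Parameter(torch.zeros(3))]

# max_norm=None (the default) means no clipping; a float clips the
# group's total gradient norm to that value each step.
config = OptimizerConfig(lr=5e-4, eps=1e-8, max_norm=1.0)

# setup() copies the dataclass fields, pops the non-torch ones
# ("_target" and now "max_norm"), and instantiates torch.optim.Adam.
optimizer = config.setup(params=params)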

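The clipping added to optimizer_scaler_step_all follows PyTorch's usual mixed-precision recipe: gradients are unscaled before their norm is measured, and the step still goes through the GradScaler so inf/nan gradients skip the update. A self-contained sketch of that pattern, with a placeholder model, data, and max_norm value:

import torch
from torch.cuda.amp.grad_scaler import GradScaler

model = torch.nn.Linear(4, 1)                        # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
grad_scaler = GradScaler(enabled=False)              # enabled=True requires CUDA
max_norm = 1.0                                       # placeholder clip value

loss = model(torch.randn(8, 4)).square().mean()
grad_scaler.scale(loss).backward()

if max_norm is not None:
    grad_scaler.unscale_(optimizer)                  # undo loss scaling on the grads first
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
grad_scaler.step(optimizer)                          # skips the step on inf/nan grads
grad_scaler.update()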
0 commit comments
