Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix learning rate schedulers in PyTorch 2.0 (closes #3202) #3207

Merged
merged 1 commit into from
May 17, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions pyro/optim/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,16 @@ def __call__(self, params: Union[List, ValuesView], *args, **kwargs) -> None:
if self.grad_clip[p] is not None:
self.grad_clip[p](p)

if isinstance(
self.optim_objs[p], torch.optim.lr_scheduler._LRScheduler
) or isinstance(
self.optim_objs[p], torch.optim.lr_scheduler.ReduceLROnPlateau
if (
hasattr(torch.optim.lr_scheduler, "_LRScheduler")
and isinstance(
self.optim_objs[p], torch.optim.lr_scheduler._LRScheduler
)
or hasattr(torch.optim.lr_scheduler, "LRScheduler")
and isinstance(self.optim_objs[p], torch.optim.lr_scheduler.LRScheduler)
or isinstance(
self.optim_objs[p], torch.optim.lr_scheduler.ReduceLROnPlateau
)
):
# if optim object was a scheduler, perform an optimizer step
self.optim_objs[p].optimizer.step(*args, **kwargs)
Expand Down