Add AdamW to CPUOffloadOptimizer default (#742)
add default
gau-nernst authored Aug 24, 2024
1 parent eb47c93 commit af68031
Showing 1 changed file with 16 additions and 3 deletions.

torchao/prototype/low_bit_optim/cpu_offload.py
@@ -1,20 +1,33 @@
 from typing import Type

 import torch
-from torch.optim.optimizer import Optimizer
+from torch.optim.optimizer import Optimizer, ParamsT

 from torchao.utils import TORCH_VERSION_AT_LEAST_2_4


 class CPUOffloadOptimizer:
-    def __init__(self, params, optimizer_class: Type[Optimizer], *, offload_gradients: bool = False, **kwargs) -> None:
+    def __init__(
+        self,
+        params: ParamsT,
+        optimizer_class: Type[Optimizer] = torch.optim.AdamW,
+        *,
+        offload_gradients: bool = False,
+        **kwargs,
+    ) -> None:
         """Offload optimizer to CPU for single-GPU training. This will reduce GPU memory by the size of optimizer state.
         Optimizer step will be done on CPU.
         Args
             params: a list of parameters or parameter groups.
-            optimizer_class: constructor of the base optimizer.
+            optimizer_class: constructor of the base optimizer. Defaults to :class:`torch.optim.AdamW`.
             offload_gradients: free GPU gradients once they are moved to CPU. Not compatible with gradient accumulation.
             kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`.
         """
+        # default to fused CPU AdamW
+        if optimizer_class is torch.optim.AdamW and TORCH_VERSION_AT_LEAST_2_4 and "fused" not in kwargs:
+            kwargs.update(fused=True)
+
         param_groups = list(params)
         if len(param_groups) == 0:
             raise ValueError("optimizer got an empty parameter list")
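For readers skimming this commit, here is a minimal usage sketch of the new default behavior (not part of the diff; the toy model, tensor shapes, and hyperparameters are illustrative placeholders, and the import path simply follows the file touched above):

    import torch
    from torchao.prototype.low_bit_optim.cpu_offload import CPUOffloadOptimizer

    model = torch.nn.Linear(1024, 1024).cuda()  # placeholder model

    # After this change, optimizer_class can be omitted: it falls back to torch.optim.AdamW,
    # and on PyTorch >= 2.4 the wrapper sets fused=True unless the caller passed it explicitly.
    optim = CPUOffloadOptimizer(model.parameters(), lr=1e-3, weight_decay=0.01)

    # Explicit equivalent (the only form accepted before this commit):
    # optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, lr=1e-3, weight_decay=0.01)

    loss = model(torch.randn(8, 1024, device="cuda")).sum()
    loss.backward()
    optim.step()       # per the docstring, the optimizer step runs on CPU
    optim.zero_grad()

With the default in place, CPUOffloadOptimizer(model.parameters(), lr=...) mirrors the ergonomics of constructing torch.optim.AdamW directly, while still keeping optimizer state in CPU memory.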
