Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT / Trainer: Add adamw 4bit optimizer #31865

Merged
merged 13 commits into from
Aug 22, 2024
18 changes: 18 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_xla_available,
is_torchao_available,
logging,
strtobool,
)
Expand Down Expand Up @@ -1434,6 +1435,23 @@ def optimizer_hook(param):
optimizer_cls = Lomo

optimizer_kwargs.update({"model": model})
elif args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) <= version.parse(
"0.3.1"
):
raise ImportError(
"You need to have `torchao>0.3.1` in order to use torch 4-bit optimizers. "
"Install it with `pip install https://github.com/pytorch/ao.git`"
)
if not version.parse(importlib.metadata.version("torch")) < version.parse("2.3"):
raise ImportError(
"You need to have `torch>2.3` in order to use torch 4-bit optimizers. "
"Install it with `pip install --upgrade torch`"
)
from torchao.prototype.low_bit_optim import AdamW4bit

optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
Expand Down
1 change: 1 addition & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class OptimizerNames(ExplicitEnum):
ADAMW_APEX_FUSED = "adamw_apex_fused"
ADAFACTOR = "adafactor"
ADAMW_ANYPRECISION = "adamw_anyprecision"
ADAMW_TORCH_4BIT = "adamw_torch_4bit"
SGD = "sgd"
ADAGRAD = "adagrad"
ADAMW_BNB = "adamw_bnb_8bit"
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@
is_torch_tpu_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchao_available,
is_torchaudio_available,
is_torchdistx_available,
is_torchdynamo_available,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_eetq_available = _is_package_available("eetq")
_galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim")
_torchao_available = _is_package_available("torchao")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
Expand Down Expand Up @@ -337,6 +338,10 @@ def is_torchvision_available():
return _torchvision_available


def is_torchao_available():
return _torchao_available


def is_galore_torch_available():
return _galore_torch_available

Expand Down
11 changes: 11 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
is_apex_available,
is_bitsandbytes_available,
is_safetensors_available,
is_torchao_available,
is_torchdistx_available,
)
from transformers.utils.hp_naming import TrialShortNamer
Expand Down Expand Up @@ -4174,6 +4175,16 @@ def hp_name(trial):
dict(default_adam_kwargs, **default_anyprecision_kwargs),
)
)
if is_torchao_available():
import torchao

optim_test_params.append(
(
TrainingArguments(optim=OptimizerNames.ADAMW_TORCH_4BIT, output_dir="None"),
torchao.prototype.low_bit_optim.AdamW4bit,
default_adam_kwargs,
)
)


@require_torch
Expand Down
Loading