
torchao: add int8; quanto: add NF4; torch compile fixes + ability to compile optim #986


Merged: 30 commits, Sep 28, 2024. The diff below reflects changes from 23 of the 30 commits.

Commits (30)
c38c109 torchao: fp8/autoquant (Sep 24, 2024)
721ce10 update deps (Sep 24, 2024)
0d8b7cf gc_collect should be called before clear on torch and after for mps (Sep 24, 2024)
44c82ff mps: disable gpu quantisation since it does not work (Sep 24, 2024)
6118768 add int8-torchao level for mps support (Sep 24, 2024)
9fcd748 update to use the newer ao api and move cuda restriction to fp8 (Sep 24, 2024)
ed35338 allow training the full model in a quantised state (Sep 24, 2024)
296e3dd return the modified model (Sep 25, 2024)
50f59d6 Merge branch 'main' into feature/torchao (bghira, Sep 26, 2024)
26ee784 torchao: low-precision optims need fp32 gradients (Sep 26, 2024)
de4b563 torchao: cpu optimiser offload, which also does not work (Sep 26, 2024)
979d298 torchao: fix int8 training by monkeypatching the broken method (Sep 27, 2024)
b989ff3 update with int8 nvidia fix (Sep 27, 2024)
cfb6e62 nvidia lock file update (Sep 27, 2024)
a1413a1 fix torch compile validation arg (Sep 28, 2024)
6124933 update error msg (Sep 28, 2024)
a4d3e6b fix int8 again, as we cannot use filter_fn on the whole model (Sep 28, 2024)
4f28ebb remove fp8 and auto (Sep 28, 2024)
8bf7107 update message for loading module (Sep 28, 2024)
f40d67c torchao: rename quantoise -> quantise_model (Sep 28, 2024)
3def7c5 quanto: add nf4 support (Sep 28, 2024)
3aa1925 update optimum-quanto for nf4 support (Sep 28, 2024)
1017e8a update options doc contents, adding quantisation notes (Sep 28, 2024)
f6d770c Update helpers/training/custom_schedule.py (bghira, Sep 28, 2024)
284af19 update quanto fp8 for marlin gemm kernel and auto switch from fp8 to … (Sep 28, 2024)
509a25e reformat files that were missed earlier (Sep 28, 2024)
34ad1fd reorganise options doc (Sep 28, 2024)
e97183c disable cpu offloaded optim (Sep 28, 2024)
a43ff90 dynamo optimisation for flux transformer, always use fp32 rope (Sep 28, 2024)
9bf32b8 remove old optimiser init (Sep 28, 2024)
121 changes: 103 additions & 18 deletions OPTIONS.md

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions helpers/caching/memory.py
@@ -3,11 +3,11 @@ def reclaim_memory():
     import torch
 
     if torch.cuda.is_available():
+        gc.collect()
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
 
     if torch.backends.mps.is_available():
         torch.mps.empty_cache()
         torch.mps.synchronize()
-
-    gc.collect()
+        gc.collect()
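The ordering matters: CUDA's caching allocator can only return blocks whose tensors have already been garbage-collected, so `gc.collect()` runs before `torch.cuda.empty_cache()`, while the commit message notes MPS wants collection after the cache is cleared. A minimal usage sketch (the tensor size and device fallback are illustrative, not part of the PR):

```python
import torch

from helpers.caching.memory import reclaim_memory

device = "cuda" if torch.cuda.is_available() else "cpu"
big = torch.empty(4096, 4096, device=device)  # ~64 MiB of fp32

del big           # drop the last reference so gc.collect() can free it
reclaim_memory()  # gc first, then empty_cache(), so the block leaves the pool
```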
68 changes: 58 additions & 10 deletions helpers/configuration/cmd_args.py
@@ -13,7 +13,7 @@
 from helpers.training import quantised_precision_levels
 from helpers.training.optimizer_param import (
     is_optimizer_deprecated,
-    is_optimizer_bf16,
+    is_optimizer_grad_fp32,
     map_deprecated_optimizer_parameter,
     optimizer_choices,
 )
@@ -1148,6 +1148,39 @@ def get_argument_parser():
             " For example, `--optimizer_config=decouple_lr=True,weight_decay=0.01`."
         ),
     )
+    parser.add_argument(
+        "--optimizer_cpu_offload_method",
+        choices=["none", "torchao"],
+        default="none",
+        help=(
+            "When loading an optimiser, a CPU offload mechanism can be used. Currently, no offload is used by default, and only torchao is supported."
+        ),
+    )
+    parser.add_argument(
+        "--optimizer_offload_gradients",
+        action="store_true",
+        default=False,
+        help=(
+            "When creating a CPU-offloaded optimiser, the gradients can be offloaded to the CPU to save more memory."
+        ),
+    )
+    parser.add_argument(
+        "--fuse_optimizer",
+        action="store_true",
+        default=False,
+        help=(
+            "When creating a CPU-offloaded optimiser, the fused optimiser can be used to save memory, while running slightly slower."
+        ),
+    )
+    parser.add_argument(
+        "--optimizer_torch_compile",
+        action="store_true",
+        default=False,
+        help=(
+            "When using a CPU-offloaded optimiser, we can torch.compile() it and save some time using a compiled graph."
+            " This option will not work on Apple MPS devices, and may not work on all systems."
+        ),
+    )
     parser.add_argument(
         "--optimizer_beta1",
         type=float,
@@ -1282,8 +1315,8 @@
     )
     parser.add_argument(
         "--validation_torch_compile",
-        type=str,
-        default="false",
+        action="store_true",
+        default=False,
         help=(
             "Supply `--validation_torch_compile` to enable the use of torch.compile() on the validation pipeline."
             " For some setups, torch.compile() may error out. This is dependent on PyTorch version, phase of the moon,"
@@ -1984,6 +2017,21 @@ def parse_cmdline_args(input_args=None):
         raise ValueError(
             f"Model is not using bf16 precision, but the optimizer {chosen_optimizer} requires it."
         )
+    if is_optimizer_grad_fp32(args.optimizer):
+        print(
+            "[WARNING] Using a low-precision optimizer that requires fp32 gradients. Training will run more slowly."
+        )
+        if args.gradient_precision != "fp32":
+            print(
+                f"[WARNING] Overriding gradient_precision to 'fp32' for {args.optimizer} optimizer."
+            )
+            args.gradient_precision = "fp32"
+    else:
+        if args.gradient_precision == "fp32":
+            print(
+                f"[WARNING] Overriding gradient_precision to 'unmodified' for {args.optimizer} optimizer, as fp32 gradients are not required."
+            )
+            args.gradient_precision = "unmodified"
 
     if torch.backends.mps.is_available():
         if (
@@ -2001,6 +2049,12 @@
             )
             sys.exit(1)
 
+        if args.quantize_via == "accelerator":
+            error_log(
+                "MPS does not benefit from models being quantized on the accelerator device. Overriding --quantize_via to 'cpu'."
+            )
+            args.quantize_via = "cpu"
+
     if (
         args.max_train_steps is not None
         and args.max_train_steps > 0
@@ -2091,10 +2145,6 @@
 
     if args.metadata_update_interval < 60:
         raise ValueError("Metadata update interval must be at least 60 seconds.")
-    if args.validation_torch_compile == "true":
-        args.validation_torch_compile = True
-    else:
-        args.validation_torch_compile = False
 
     if args.model_family == "sd3":
         args.pretrained_vae_model_name_or_path = None
@@ -2247,9 +2297,7 @@
     )
     args.disable_accelerator = os.environ.get("SIMPLETUNER_DISABLE_ACCELERATOR", False)
 
-    if "lora" not in args.model_type:
-        args.base_model_precision = "no_change"
-    elif "lycoris" == args.lora_type.lower():
+    if "lycoris" == args.lora_type.lower():
         from lycoris import create_lycoris
 
         if args.lycoris_config is None:
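To see how the new optimiser flags behave without invoking SimpleTuner itself, here is a self-contained argparse mirror of the four options added above (an illustration only, not the project's actual parser); note that `--validation_torch_compile` is now a presence-only `store_true` flag in the same way:

```python
import argparse

# Mirror of the four options added in this PR, for experimentation.
parser = argparse.ArgumentParser()
parser.add_argument("--optimizer_cpu_offload_method", choices=["none", "torchao"], default="none")
parser.add_argument("--optimizer_offload_gradients", action="store_true", default=False)
parser.add_argument("--fuse_optimizer", action="store_true", default=False)
parser.add_argument("--optimizer_torch_compile", action="store_true", default=False)

args = parser.parse_args(
    ["--optimizer_cpu_offload_method=torchao", "--optimizer_offload_gradients"]
)
assert args.optimizer_cpu_offload_method == "torchao"
assert args.optimizer_offload_gradients is True
assert args.fuse_optimizer is False
```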
4 changes: 4 additions & 0 deletions helpers/training/__init__.py
@@ -3,9 +3,13 @@
     # "fp4-bnb",
     # "fp8-bnb",
     "fp8-quanto",
+    "nf4-quanto",
     "int8-quanto",
     "int4-quanto",
     "int2-quanto",
+    # currently does not work.
+    # "fp8-torchao",
+    "int8-torchao",
 ]
 
 image_file_extensions = set(["jpg", "jpeg", "png", "webp", "bmp", "tiff", "tif"])
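As a rough sketch of what the quanto level strings correspond to at runtime, the mapping and helper below are illustrative (they are not SimpleTuner's `quantise_model`); the newly added `nf4-quanto` level is omitted because it needs the optimum-quanto release referenced in this PR:

```python
import torch
from optimum.quanto import freeze, qfloat8, qint2, qint4, qint8, quantize

# Illustrative mapping from level strings to quanto qtypes.
QUANTO_LEVELS = {
    "fp8-quanto": qfloat8,
    "int8-quanto": qint8,
    "int4-quanto": qint4,
    "int2-quanto": qint2,
}

def quantise_with_quanto(model: torch.nn.Module, level: str) -> torch.nn.Module:
    quantize(model, weights=QUANTO_LEVELS[level])  # swap weights for qtensors
    freeze(model)                                  # materialise quantised weights
    return model

model = quantise_with_quanto(torch.nn.Sequential(torch.nn.Linear(8, 8)), "int8-quanto")
```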
1 change: 1 addition & 0 deletions helpers/training/custom_schedule.py
@@ -156,6 +156,7 @@ def get_polynomial_decay_schedule_with_warmup(
 
     """
 
+    print(f"Optimizer: {optimizer}")
     lr_init = optimizer.defaults["lr"]
     if not (float(lr_init) > float(lr_end)):
         raise ValueError(
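For context, the guard after the new print enforces that `lr_end` stays strictly below the optimiser's initial learning rate. A hedged usage sketch, assuming the scheduler keeps the transformers-style signature:

```python
import torch

from helpers.training.custom_schedule import get_polynomial_decay_schedule_with_warmup

model = torch.nn.Linear(4, 4)
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)

# lr_end must be strictly below lr (1e-4 here), or the ValueError above fires.
sched = get_polynomial_decay_schedule_with_warmup(
    optim, num_warmup_steps=100, num_training_steps=1000, lr_end=1e-6
)
```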
91 changes: 91 additions & 0 deletions helpers/training/optimizer_param.py
@@ -17,6 +17,23 @@
 except:
     pass
 
+try:
+    from torchao.prototype.low_bit_optim import (
+        AdamW8bit as AOAdamW8Bit,
+        Adam4bit as AOAdamW4Bit,
+        AdamFp8 as AOAdamFp8,
+        AdamWFp8 as AOAdamWFp8,
+        CPUOffloadOptimizer as AOCPUOffloadOptimizer,
+    )
+
+    if torch.backends.mps.is_available():
+        import torch._dynamo
+
+        torch._dynamo.config.suppress_errors = True
+except Exception as e:
+    print("You need torchao installed for its low-precision optimizers.")
+    raise e
+
 try:
     import optimi
 
@@ -36,6 +53,46 @@
         },
         "class": AdamWBF16,
     },
+    "ao-adamw8bit": {
+        "gradient_precision": "bf16",
+        "precision": "any",
+        "default_settings": {
+            "betas": (0.9, 0.999),
+            "weight_decay": 1e-2,
+            "eps": 1e-6,
+        },
+        "class": AOAdamW8Bit,
+    },
+    "ao-adamw4bit": {
+        "gradient_precision": "bf16",
+        "precision": "any",
+        "default_settings": {
+            "betas": (0.9, 0.999),
+            "weight_decay": 1e-2,
+            "eps": 1e-6,
+        },
+        "class": AOAdamW4Bit,
+    },
+    "ao-adamfp8": {
+        "gradient_precision": "bf16",
+        "precision": "any",
+        "default_settings": {
+            "betas": (0.9, 0.999),
+            "weight_decay": 1e-2,
+            "eps": 1e-6,
+        },
+        "class": AOAdamFp8,
+    },
+    "ao-adamwfp8": {
+        "gradient_precision": "bf16",
+        "precision": "any",
+        "default_settings": {
+            "betas": (0.9, 0.999),
+            "weight_decay": 1e-2,
+            "eps": 1e-6,
+        },
+        "class": AOAdamWFp8,
+    },
     "adamw_schedulefree": {
         "precision": "any",
         "override_lr_scheduler": True,
@@ -276,6 +333,40 @@ def is_optimizer_bf16(optimizer: str) -> bool:
     return False
 
 
+def is_optimizer_grad_fp32(optimizer: str) -> bool:
+    optimizer_precision = optimizer_choices.get(optimizer, {}).get(
+        "gradient_precision", None
+    )
+    if optimizer_precision == "fp32":
+        return True
+    return False
+
+
+def cpu_offload_optimizer(
+    params_to_optimize,
+    optimizer_cls,
+    optimizer_parameters: dict,
+    offload_gradients: bool = True,
+    fused: bool = True,
+    offload_mechanism: str = None,
+):
+    if not offload_mechanism or offload_mechanism == "none":
+        return optimizer_cls(params_to_optimize, **optimizer_parameters)
+    if offload_mechanism != "torchao":
+        raise ValueError(
+            f"Unknown CPU optimiser offload mechanism: {offload_mechanism}"
+        )
+
+    if offload_gradients:
+        optimizer_parameters["offload_gradients"] = offload_gradients
+    if fused:
+        optimizer_parameters["fused"] = fused
+
+    optimizer_parameters["optimizer_class"] = optimizer_cls
+
+    return AOCPUOffloadOptimizer(params_to_optimize, **optimizer_parameters)
+
+
 def determine_optimizer_class_with_config(
     args, use_deepspeed_optimizer, is_quantized, enable_adamw_bf16
 ) -> tuple:
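A usage sketch of `cpu_offload_optimizer` from the hunk above. The model is a stand-in; torchao's `CPUOffloadOptimizer` keeps optimiser state on the CPU and runs the step there, and `fused=True` is forwarded to the inner optimiser, so plain `torch.optim.AdamW` (whose fused CPU path needs a recent PyTorch) is used here rather than a low-bit class:

```python
import torch

from helpers.training.optimizer_param import cpu_offload_optimizer

# Stand-in model; torchao's CPU offload expects CUDA parameters in practice.
model = torch.nn.Linear(64, 64, device="cuda" if torch.cuda.is_available() else "cpu")

optim = cpu_offload_optimizer(
    params_to_optimize=list(model.parameters()),
    optimizer_cls=torch.optim.AdamW,  # inner optimiser; its step runs on CPU
    optimizer_parameters={"lr": 1e-4, "weight_decay": 1e-2},
    offload_gradients=True,  # stream gradients to CPU as they are produced
    fused=True,              # forwarded to AdamW; needs a recent PyTorch
    offload_mechanism="torchao",
)

loss = model(torch.randn(8, 64, device=model.weight.device)).sum()
loss.backward()
optim.step()
optim.zero_grad()
```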