
Commit

[training] fixes to the quantization training script and add AdEMAMix optimizer as an option (#9806)

* fixes

* more fixes.
sayakpaul committed Dec 23, 2024
1 parent bbbd1c0 commit dfbe972
Showing 1 changed file with 31 additions and 14 deletions.
@@ -349,14 +349,19 @@ def parse_args(input_args=None):
         "--optimizer",
         type=str,
         default="AdamW",
-        help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+        choices=["AdamW", "Prodigy", "AdEMAMix"],
     )
 
     parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
     )
+    parser.add_argument(
+        "--use_8bit_ademamix",
+        action="store_true",
+        help="Whether or not to use 8-bit AdEMAMix from bitsandbytes.",
+    )
 
     parser.add_argument(
         "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
@@ -820,16 +825,15 @@ def load_model_hook(models, input_dir):
     params_to_optimize = [transformer_parameters_with_lr]
 
     # Optimizer creation
-    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
         logger.warning(
-            f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
-            "Defaulting to adamW"
+            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"set to {args.optimizer.lower()}"
         )
-        args.optimizer = "adamw"
 
-    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+    if args.use_8bit_ademamix and not args.optimizer.lower() == "ademamix":
         logger.warning(
-            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. Optimizer was "
             f"set to {args.optimizer.lower()}"
         )
 
@@ -853,6 +857,20 @@ def load_model_hook(models, input_dir):
             eps=args.adam_epsilon,
         )
 
+    elif args.optimizer.lower() == "ademamix":
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`."
+            )
+        if args.use_8bit_ademamix:
+            optimizer_class = bnb.optim.AdEMAMix8bit
+        else:
+            optimizer_class = bnb.optim.AdEMAMix
+
+        optimizer = optimizer_class(params_to_optimize)
+
     if args.optimizer.lower() == "prodigy":
         try:
             import prodigyopt
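
Note: unlike the AdamW and Prodigy branches, the new AdEMAMix branch passes only params_to_optimize and leaves every other hyperparameter at the bitsandbytes defaults; judging by the transformer_parameters_with_lr name above, the learning rate still reaches the optimizer through the param-group dict. A minimal sketch of what the branch resolves to, assuming a bitsandbytes release that ships AdEMAMix:

    import bitsandbytes as bnb
    import torch

    # hypothetical stand-in for the LoRA parameters the script collects
    lora_layer = torch.nn.Linear(64, 64)
    params_to_optimize = [{"params": lora_layer.parameters(), "lr": 1e-4}]

    # mirrors the --use_8bit_ademamix switch; bnb.optim.AdEMAMix is the full-precision variant
    optimizer = bnb.optim.AdEMAMix8bit(params_to_optimize)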
@@ -868,7 +886,6 @@ def load_model_hook(models, input_dir):
 
         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
@@ -1020,12 +1037,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
-                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
 
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
-                    model_input.shape[2],
-                    model_input.shape[3],
+                    model_input.shape[2] // 2,
+                    model_input.shape[3] // 2,
                     accelerator.device,
                     weight_dtype,
                 )
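
Note: the two changes in this hunk go together. The VAE encoder halves the resolution len(block_out_channels) - 1 times, so four blocks give a spatial factor of 8 (the old formula yielded 16), and Flux packs 2x2 latent patches into one token, hence the extra // 2 when preparing the latent image IDs. A worked example under assumed values (1024 px training resolution, a typical four-entry Flux VAE config):

    vae_config_block_out_channels = [128, 256, 512, 512]              # assumed, typical Flux VAE
    pixel_size = 1024

    vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)  # 8 (was 16 before the fix)
    latent_size = pixel_size // vae_scale_factor                      # 128
    packed_size = latent_size // 2                                    # 64 tokens per side after 2x2 packing
    print(vae_scale_factor, latent_size, packed_size)                 # 8 128 64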
@@ -1059,7 +1076,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
 
                 # handle guidance
-                if transformer.config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -1082,8 +1099,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )[0]
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2] * vae_scale_factor / 2),
-                    width=int(model_input.shape[3] * vae_scale_factor / 2),
+                    height=model_input.shape[2] * vae_scale_factor,
+                    width=model_input.shape[3] * vae_scale_factor,
                     vae_scale_factor=vae_scale_factor,
                 )
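
Note: numerically the height/width arguments are unchanged here, since int(latent * 16 / 2) and latent * 8 both evaluate to the pixel-space size; the substantive change is that _unpack_latents now also receives the corrected vae_scale_factor of 8. Continuing the worked example above:

    latent_size, vae_scale_factor = 128, 8
    old_height = int(latent_size * (2 ** 4) / 2)   # 1024, with the old factor of 16 baked in
    new_height = latent_size * vae_scale_factor    # 1024, same value written directly
    assert old_height == new_height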
