[training] fixes to the quantization training script and add AdEMAMix optimizer as an option #9806

Merged: 4 commits, Oct 31, 2024. Changes shown from 2 commits.
@@ -349,14 +349,19 @@ def parse_args(input_args=None):
"--optimizer",
type=str,
default="AdamW",
help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
choices=["AdamW", "Prodigy", "AdEMAMix"],
)

parser.add_argument(
"--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
)
parser.add_argument(
"--use_8bit_ademamix",
action="store_true",
help="Whether or not to use 8-bit AdEMAMix from bitsandbytes.",
)

parser.add_argument(
"--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
@@ -820,16 +825,15 @@ def load_model_hook(models, input_dir):
    params_to_optimize = [transformer_parameters_with_lr]

    # Optimizer creation
-    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
        logger.warning(
-            f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
-            "Defaulting to adamW"
+            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"set to {args.optimizer.lower()}"
        )
-        args.optimizer = "adamw"

-    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+    if args.use_8bit_ademamix and not args.optimizer.lower() == "ademamix":
        logger.warning(
-            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. Optimizer was "
            f"set to {args.optimizer.lower()}"
        )

@@ -853,6 +857,20 @@ def load_model_hook(models, input_dir):
            eps=args.adam_epsilon,
        )

+    elif args.optimizer.lower() == "ademamix":
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`."
+            )
+        if args.use_8bit_ademamix:
+            optimizer_class = bnb.optim.AdEMAMix8bit
+        else:
+            optimizer_class = bnb.optim.AdEMAMix
+
+        optimizer = optimizer_class(params_to_optimize)
Collaborator:
should we support betas and weight_decay here?
We could use the existing args like we did for prodigy, i.e.

optimizer = optimizer_class(params_to_optimize,
                            betas=(args.adam_beta1, args.adam_beta2),
                            weight_decay=args.adam_weight_decay)

Member Author:
Umm, I didn't want to, actually, to keep the separation of concerns very clear. We could maybe revisit if the community finds the optimizer worth the go?
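
As context for the thread above, a minimal standalone sanity check of the bitsandbytes classes the new branch relies on (illustrative only; it assumes bitsandbytes >= 0.44.0, which ships AdEMAMix, and a CUDA device, since bitsandbytes optimizers run their update kernels on GPU):

import torch
import bitsandbytes as bnb  # assumed >= 0.44.0

model = torch.nn.Linear(8, 8).cuda()
# Same construction as in the diff: only the parameters are passed, defaults for everything else.
optimizer = bnb.optim.AdEMAMix(model.parameters())  # or bnb.optim.AdEMAMix8bit(model.parameters())

loss = model(torch.randn(2, 8, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()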


    if args.optimizer.lower() == "prodigy":
        try:
            import prodigyopt
@@ -868,7 +886,6 @@ def load_model_hook(models, input_dir):

        optimizer = optimizer_class(
            params_to_optimize,
-            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            beta3=args.prodigy_beta3,
            weight_decay=args.adam_weight_decay,
@@ -1020,12 +1037,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                model_input = model_input.to(dtype=weight_dtype)

-                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)

                latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                    model_input.shape[0],
-                    model_input.shape[2],
-                    model_input.shape[3],
+                    model_input.shape[2] // 2,
+                    model_input.shape[3] // 2,
                    accelerator.device,
                    weight_dtype,
Member Author (on lines +1040 to 1047):

Follows what we do in the Flux LoRA scripts.

                )
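
To make these two changes concrete, a small worked example (assuming a 1024x1024 training resolution; the block_out_channels values are illustrative, but the Flux VAE config has four entries):

vae_config_block_out_channels = (128, 256, 512, 512)  # length 4, as in the Flux VAE config

old_scale = 2 ** len(vae_config_block_out_channels)        # 16 (previous value)
new_scale = 2 ** (len(vae_config_block_out_channels) - 1)  # 8, the VAE's actual spatial compression

pixel_height = 1024                          # assumed training resolution
latent_height = pixel_height // new_scale    # 128 == model_input.shape[2]
id_grid_height = latent_height // 2          # 64: Flux packs 2x2 latent patches, hence the `// 2`
assert (new_scale, latent_height, id_grid_height) == (8, 128, 64)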
@@ -1059,7 +1076,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                )

                # handle guidance
-                if transformer.config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
Member Author:
So that things are compatible with DeepSpeed.

                    guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                    guidance = guidance.expand(model_input.shape[0])
                else:
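
For context, a sketch of the unwrap_model helper this call assumes; these training scripts typically define it roughly as below (shown as an assumption for illustration, not part of this diff; accelerator is the script's Accelerator instance):

from diffusers.utils.torch_utils import is_compiled_module


def unwrap_model(model):
    # Strip the accelerate wrapper (e.g. the DDP module or DeepSpeed engine) so attributes
    # like .config are reachable again.
    model = accelerator.unwrap_model(model)
    # Also strip the torch.compile wrapper if the model was compiled.
    model = model._orig_mod if is_compiled_module(model) else model
    return model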
@@ -1082,8 +1099,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                )[0]
                model_pred = FluxPipeline._unpack_latents(
                    model_pred,
-                    height=int(model_input.shape[2] * vae_scale_factor / 2),
-                    width=int(model_input.shape[3] * vae_scale_factor / 2),
+                    height=model_input.shape[2] * vae_scale_factor,
+                    width=model_input.shape[3] * vae_scale_factor,
Member Author (on lines +1102 to +1103):

Same as above.

                    vae_scale_factor=vae_scale_factor,
                )
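
A quick arithmetic check that the new height/width expressions reproduce what the old ones computed, now that vae_scale_factor is 8 instead of 16 (assuming 1024px inputs, i.e. latent sides of 128):

latent_side = 128                       # model_input.shape[2] for a 1024px image
old_value = int(latent_side * 16 / 2)   # old expression with the old scale factor
new_value = latent_side * 8             # new expression with the corrected scale factor
assert old_value == new_value == 1024   # both recover the original pixel height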
