[training] fixes to the quantization training script and add AdEMAMix optimizer as an option #9806
Conversation
```diff
@@ -1059,7 +1076,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     )

     # handle guidance
-    if transformer.config.guidance_embeds:
+    if unwrap_model(transformer).config.guidance_embeds:
```
So that things are compatible with DeepSpeed.
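For context, a minimal sketch of why the unwrap matters under DeepSpeed. The `unwrap_model` helper below mirrors the pattern used in the diffusers training scripts (it assumes an `accelerator` object in scope) and is illustrative rather than copied from this diff:

```python
# Under accelerate + DeepSpeed (or DDP), `transformer` is a wrapped engine/module,
# so attributes like `.config` live on the inner model. Unwrap before reading them.
def unwrap_model(model):
    model = accelerator.unwrap_model(model)  # strips DDP/DeepSpeed wrappers
    if hasattr(model, "_orig_mod"):          # torch.compile adds another wrapper layer
        model = model._orig_mod
    return model
```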
```diff
 vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)

 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
     model_input.shape[0],
-    model_input.shape[2],
-    model_input.shape[3],
+    model_input.shape[2] // 2,
+    model_input.shape[3] // 2,
     accelerator.device,
     weight_dtype,
 )
```
Follows what we do in the Flux LoRA scripts.
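As a side note, a rough sketch of why the `// 2` is needed. The `_pack_latents` call below follows how the Flux LoRA scripts pack latents and is shown here only for illustration:

```python
# Flux packs latents into 2x2 spatial patches before the transformer, so the
# token grid that `_prepare_latent_image_ids` indexes is half the latent
# resolution per spatial dim -- hence passing `shape[2] // 2` and `shape[3] // 2`.
packed = FluxPipeline._pack_latents(
    model_input,
    batch_size=model_input.shape[0],
    num_channels_latents=model_input.shape[1],
    height=model_input.shape[2],
    width=model_input.shape[3],
)
# packed.shape -> (batch, (height // 2) * (width // 2), num_channels * 4)
```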
```python
height=model_input.shape[2] * vae_scale_factor,
width=model_input.shape[3] * vae_scale_factor,
```
Same as above.
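For context, these keyword arguments presumably belong to the `_unpack_latents` call, as in the Flux LoRA scripts; the surrounding lines below are a reconstruction, not part of this diff:

```python
model_pred = FluxPipeline._unpack_latents(
    model_pred,
    height=model_input.shape[2] * vae_scale_factor,
    width=model_input.shape[3] * vae_scale_factor,
    vae_scale_factor=vae_scale_factor,
)
```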
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Cool! Did you get a chance to play with AdEMAMix? Should we consider adding it to the other scripts as well?
examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
Testing the memory requirement as we speak. Will report back.
```python
else:
    optimizer_class = bnb.optim.AdEMAMix

optimizer = optimizer_class(params_to_optimize)
```
Should we support `betas` and `weight_decay` here? We could use the existing args like we did for Prodigy, i.e.

```python
optimizer = optimizer_class(
    params_to_optimize,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
)
```
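For reference, this is roughly how the Prodigy branch is wired in the canonical Flux LoRA script (reproduced from memory, so argument names may differ slightly):

```python
optimizer = optimizer_class(
    params_to_optimize,
    betas=(args.adam_beta1, args.adam_beta2),
    beta3=args.prodigy_beta3,
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
    decouple=args.prodigy_decouple,
    use_bias_correction=args.prodigy_use_bias_correction,
    safeguard_warmup=args.prodigy_safeguard_warmup,
)
```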
Umm, I deliberately didn't, to keep the separation of concerns very clear. We could maybe revisit this if the community finds the optimizer worth pursuing?
@linoytsaban seems identical to me. Let's maybe wait a bit before we propagate this to the canonical scripts? Meanwhile, can I merge this PR? Internal thread: https://huggingface.slack.com/archives/C04NNCRFYUQ/p1730299681912129
With ...

```python
image = pipeline(
    "a puppy in a pond, yarn art style", num_inference_steps=28, guidance_scale=3.5, height=768
).images[0]
```
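For completeness, a hypothetical setup for the snippet above (the model ID, dtype, and LoRA weights path are assumptions, not taken from this thread):

```python
import torch
from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
# Load the LoRA produced by the training script (path is hypothetical).
pipeline.load_lora_weights("path/to/trained-flux-lora")
```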
@sayakpaul yeah sounds good to me :) let's 🛳️
What does this PR do?

Fixes a couple of things in the quantization training script and adds AdEMAMix as an optimizer option: https://hf.co/papers/2409.03137.
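For readers new to the optimizer, here is a rough sketch of the AdEMAMix update rule as described in the paper. This is illustrative only (the script uses `bnb.optim.AdEMAMix`), it omits the `alpha`/`beta3` warm-up schedulers from the paper, and the function name and default hyperparameters are assumptions:

```python
import torch

def ademamix_step(p, g, state, step, lr=1e-4, betas=(0.9, 0.999, 0.9999),
                  alpha=5.0, eps=1e-8, weight_decay=0.0):
    """One AdEMAMix update for parameter `p` with gradient `g` (sketch only)."""
    beta1, beta2, beta3 = betas
    m1, m2, v = state["m1"], state["m2"], state["v"]  # all initialised to zeros_like(p)
    m1.mul_(beta1).add_(g, alpha=1 - beta1)           # fast EMA of gradients (Adam's m)
    m2.mul_(beta3).add_(g, alpha=1 - beta3)           # slow EMA that remembers "older" gradients
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)     # second-moment EMA (Adam's v)
    m1_hat = m1 / (1 - beta1 ** step)                 # bias-correct the fast EMA only
    v_hat = v / (1 - beta2 ** step)
    update = (m1_hat + alpha * m2) / (v_hat.sqrt() + eps)
    p.add_(update + weight_decay * p, alpha=-lr)      # decoupled weight decay, as in AdamW
```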