flux and sd3 could use uniform sampling instead of beta or sigmoid #1129

Merged: 4 commits, Nov 9, 2024
9 changes: 9 additions & 0 deletions helpers/configuration/cmd_args.py
@@ -149,6 +149,15 @@ def get_argument_parser():
" which has improved results in short experiments. Thanks to @mhirki for the contribution."
),
)
parser.add_argument(
"--flux_use_uniform_schedule",
action="store_true",
help=(
"Whether or not to use a uniform schedule with Flux instead of sigmoid."
" Using uniform sampling may help preserve more capabilities from the base model."
" Some tasks may not benefit from this."
),
)
parser.add_argument(
"--flux_use_beta_schedule",
action="store_true",
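The new `--flux_use_uniform_schedule` flag swaps the default sigmoid-shaped draw of noise levels for a plain uniform draw. A minimal sketch of the three sampling strategies these flags select between, assuming representative alpha/beta values (the constants below are illustrative, not taken from this PR):

```python
# Sketch of the three sigma-sampling strategies referenced by the flags above.
# Batch size and the alpha/beta values are assumptions for illustration only.
import torch

bsz = 4

# Default: sigmoid of a normal draw, which concentrates sigmas around 0.5.
sigmas_sigmoid = torch.sigmoid(torch.randn((bsz,)))

# --flux_use_beta_schedule: Beta(alpha, beta) draw shaped by the companion flags.
alpha, beta = 2.0, 2.0
sigmas_beta = torch.distributions.Beta(alpha, beta).sample((bsz,))

# --flux_use_uniform_schedule: plain uniform draw over [0, 1).
sigmas_uniform = torch.rand((bsz,))
```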
52 changes: 48 additions & 4 deletions helpers/publishing/metadata.py
@@ -153,6 +153,15 @@ def _guidance_rescale(args):
return f"\n guidance_rescale={args.validation_guidance_rescale},"


def _skip_layers(args):
if (
args.model_family.lower() not in ["sd3"]
or args.validation_guidance_skip_layers is None
):
return ""
return f"\n skip_guidance_layers={args.validation_guidance_skip_layers},"


def _validation_resolution(args):
if args.validation_resolution == "" or args.validation_resolution is None:
return f"width=1024,\n" f" height=1024,"
Expand Down Expand Up @@ -185,7 +194,7 @@ def code_example(args, repo_id: str = None):
num_inference_steps={args.validation_num_inference_steps},
generator=torch.Generator(device={_torch_device()}).manual_seed(1641421826),
{_validation_resolution(args)}
guidance_scale={args.validation_guidance},{_guidance_rescale(args)}
guidance_scale={args.validation_guidance},{_guidance_rescale(args)}{_skip_layers(args)}
).images[0]
image.save("output.png", format="PNG")
```
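The `_skip_layers()` helper above threads `--validation_guidance_skip_layers` into the generated SD3 snippet as `skip_guidance_layers`. As a rough illustration, a rendered example could look like the following (model id, prompt, and all values are placeholders, not taken from this PR):

```python
# Hypothetical output of code_example() for an SD3 run; values are assumed.
import torch
from diffusers import StableDiffusion3Pipeline

pipeline = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",  # placeholder model id
    torch_dtype=torch.bfloat16,
)
image = pipeline(
    prompt="an example prompt",
    num_inference_steps=28,
    generator=torch.Generator(device="cuda").manual_seed(1641421826),
    width=1024,
    height=1024,
    guidance_scale=5.0,
    skip_guidance_layers=[7, 8, 9],  # injected by the new _skip_layers() helper
).images[0]
image.save("output.png", format="PNG")
```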
@@ -249,17 +258,52 @@ def flux_schedule_info(args):
output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}")
if args.flux_attention_masked_training:
output_args.append("flux_attention_masked_training")
if args.model_type == "lora" and args.lora_type == "standard":
if (
args.model_type == "lora"
and args.lora_type == "standard"
and args.flux_lora_target is not None
):
output_args.append(f"flux_lora_target={args.flux_lora_target}")
output_str = (
f" (flux parameters={output_args})"
f" (extra parameters={output_args})"
if output_args
else " (no special parameters set)"
)

return output_str


def sd3_schedule_info(args):
if args.model_family.lower() != "sd3":
return ""
output_args = []
if args.flux_schedule_auto_shift:
output_args.append("flux_schedule_auto_shift")
if args.flux_schedule_shift is not None:
output_args.append(f"shift={args.flux_schedule_shift}")
if args.flux_use_beta_schedule:
output_args.append(f"flux_beta_schedule_alpha={args.flux_beta_schedule_alpha}")
output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}")
if args.flux_use_uniform_schedule:
output_args.append("flux_use_uniform_schedule")
# if args.model_type == "lora" and args.lora_type == "standard":
# output_args.append(f"flux_lora_target={args.flux_lora_target}")
output_str = (
f" (extra parameters={output_args})"
if output_args
else " (no special parameters set)"
)

return output_str


def model_schedule_info(args):
if args.model_family == "flux":
return flux_schedule_info(args)
if args.model_family == "sd3":
return sd3_schedule_info(args)
# default to an empty string so non flow-matching model cards do not render "None"
return ""


def save_model_card(
repo_id: str,
images=None,
Expand Down Expand Up @@ -384,7 +428,7 @@ def save_model_card(
- Micro-batch size: {StateTracker.get_args().train_batch_size}
- Gradient accumulation steps: {StateTracker.get_args().gradient_accumulation_steps}
- Number of GPUs: {StateTracker.get_accelerator().num_processes}
- Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{flux_schedule_info(args=StateTracker.get_args())}
- Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{model_schedule_info(args=StateTracker.get_args())}
- Rescaled betas zero SNR: {StateTracker.get_args().rescale_betas_zero_snr}
- Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''}
- Precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'}
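The new `sd3_schedule_info()` is dispatched through `model_schedule_info()` so SD3 runs report their schedule settings on the card's "Prediction type" line, as Flux runs already did. A rough sketch of what it would return for a hypothetical SD3 configuration (argument values are assumed, not defaults):

```python
# Illustrative input for sd3_schedule_info(); these values are assumptions.
from types import SimpleNamespace

args = SimpleNamespace(
    model_family="sd3",
    flux_schedule_auto_shift=False,
    flux_schedule_shift=3.0,
    flux_use_beta_schedule=False,
    flux_beta_schedule_alpha=2.0,
    flux_beta_schedule_beta=2.0,
    flux_use_uniform_schedule=True,
)
# sd3_schedule_info(args) would return:
#   " (extra parameters=['shift=3.0', 'flux_use_uniform_schedule'])"
# and the card line reads:
#   "Prediction type: flow-matching (extra parameters=[...])"
```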
7 changes: 6 additions & 1 deletion helpers/training/trainer.py
@@ -2185,7 +2185,7 @@ def train(self):
if self.config.flow_matching:
if (
not self.config.flux_fast_schedule
and not self.config.flux_use_beta_schedule
and not any([self.config.flux_use_beta_schedule, self.config.flux_use_uniform_schedule])
):
# imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF
# also used by: https://github.com/XLabs-AI/x-flux/tree/main
@@ -2197,6 +2197,11 @@ def train(self):
sigmas = apply_flux_schedule_shift(
self.config, self.noise_scheduler, sigmas, noise
)
elif self.config.flux_use_uniform_schedule:
sigmas = torch.rand((bsz,), device=self.accelerator.device)
sigmas = apply_flux_schedule_shift(
self.config, self.noise_scheduler, sigmas, noise
)
elif self.config.flux_use_beta_schedule:
alpha = self.config.flux_beta_schedule_alpha
beta = self.config.flux_beta_schedule_beta
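In the new uniform branch above, sigmas are drawn with `torch.rand` and then passed through `apply_flux_schedule_shift()`, whose body is outside this diff. As an assumption about its behaviour, the sketch below uses the shift mapping commonly applied in SD3/Flux-style flow matching:

```python
# Assumed shift mapping; apply_flux_schedule_shift() may differ in detail.
import torch

def shifted(sigmas: torch.Tensor, shift: float = 3.0) -> torch.Tensor:
    # Larger shift values push uniform draws toward higher noise levels.
    return (shift * sigmas) / (1.0 + (shift - 1.0) * sigmas)

bsz = 4
sigmas = torch.rand((bsz,))          # matches the new uniform branch
sigmas = shifted(sigmas, shift=3.0)  # stands in for apply_flux_schedule_shift(...)
```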