Merge pull request #1129 from bghira/feature/flow-matching-uniform-sampling

flux and sd3 could use uniform sampling instead of beta or sigmoid
bghira authored Nov 9, 2024
2 parents 25cad1c + 0b34f24 commit 7146c5e
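
The option samples flow-matching sigmas from a uniform distribution rather than the sigmoid-of-normal default or the optional beta distribution. Below is a minimal sketch of the three strategies for comparison; the batch size and beta parameters are placeholders, the sigmoid form follows the minRF-style default referenced in the trainer, and the schedule shift applied afterwards by `apply_flux_schedule_shift()` is omitted.

```python
import torch

# Illustrative comparison of the three sigma-sampling strategies.
bsz = 4

# Sigmoid of a normal draw (the existing default, per the minRF-style trainer):
sigmas_sigmoid = torch.sigmoid(torch.randn((bsz,)))

# Beta-distributed draw (--flux_use_beta_schedule); alpha/beta values are placeholders:
alpha, beta = 2.0, 2.0
sigmas_beta = torch.distributions.Beta(alpha, beta).sample((bsz,))

# Uniform draw (--flux_use_uniform_schedule, added in this commit):
sigmas_uniform = torch.rand((bsz,))
```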
Showing 4 changed files with 346 additions and 5 deletions.
9 changes: 9 additions & 0 deletions helpers/configuration/cmd_args.py
@@ -149,6 +149,15 @@ def get_argument_parser():
" which has improved results in short experiments. Thanks to @mhirki for the contribution."
),
)
parser.add_argument(
"--flux_use_uniform_schedule",
action="store_true",
help=(
"Whether or not to use a uniform schedule with Flux instead of sigmoid."
" Using uniform sampling may help preserve more capabilities from the base model."
" Some tasks may not benefit from this."
),
)
parser.add_argument(
"--flux_use_beta_schedule",
action="store_true",
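The new flag is a plain `store_true` option, so it defaults to off and is independent of the beta-schedule flag at the parser level. A standalone sketch of that behaviour (the real parser in `helpers/configuration/cmd_args.py` defines many more options):

```python
import argparse

# Minimal mirror of the two store_true flags added/kept by this diff.
parser = argparse.ArgumentParser()
parser.add_argument("--flux_use_uniform_schedule", action="store_true")
parser.add_argument("--flux_use_beta_schedule", action="store_true")

args = parser.parse_args(["--flux_use_uniform_schedule"])
assert args.flux_use_uniform_schedule is True
assert args.flux_use_beta_schedule is False
```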
52 changes: 48 additions & 4 deletions helpers/publishing/metadata.py
@@ -153,6 +153,15 @@ def _guidance_rescale(args):
return f"\n guidance_rescale={args.validation_guidance_rescale},"


def _skip_layers(args):
if (
args.model_family.lower() not in ["sd3"]
or args.validation_guidance_skip_layers is None
):
return ""
return f"\n skip_guidance_layers={args.validation_guidance_skip_layers},"


def _validation_resolution(args):
if args.validation_resolution == "" or args.validation_resolution is None:
return f"width=1024,\n" f" height=1024,"
@@ -185,7 +194,7 @@ def code_example(args, repo_id: str = None):
num_inference_steps={args.validation_num_inference_steps},
generator=torch.Generator(device={_torch_device()}).manual_seed(1641421826),
{_validation_resolution(args)}
guidance_scale={args.validation_guidance},{_guidance_rescale(args)}
guidance_scale={args.validation_guidance},{_guidance_rescale(args)},{_skip_layers(args)}
).images[0]
image.save("output.png", format="PNG")
```
@@ -249,17 +258,52 @@ def flux_schedule_info(args):
output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}")
if args.flux_attention_masked_training:
output_args.append("flux_attention_masked_training")
if args.model_type == "lora" and args.lora_type == "standard":
if (
args.model_type == "lora"
and args.lora_type == "standard"
and args.flux_lora_target is not None
):
output_args.append(f"flux_lora_target={args.flux_lora_target}")
output_str = (
f" (flux parameters={output_args})"
f" (extra parameters={output_args})"
if output_args
else " (no special parameters set)"
)

return output_str


def sd3_schedule_info(args):
if args.model_family.lower() != "sd3":
return ""
output_args = []
if args.flux_schedule_auto_shift:
output_args.append("flux_schedule_auto_shift")
if args.flux_schedule_shift is not None:
output_args.append(f"shift={args.flux_schedule_shift}")
if args.flux_use_beta_schedule:
output_args.append(f"flux_beta_schedule_alpha={args.flux_beta_schedule_alpha}")
output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}")
if args.flux_use_uniform_schedule:
output_args.append(f"flux_use_uniform_schedule")
# if args.model_type == "lora" and args.lora_type == "standard":
# output_args.append(f"flux_lora_target={args.flux_lora_target}")
output_str = (
f" (extra parameters={output_args})"
if output_args
else " (no special parameters set)"
)

return output_str


def model_schedule_info(args):
if args.model_family == "flux":
return flux_schedule_info(args)
if args.model_family == "sd3":
return sd3_schedule_info(args)


def save_model_card(
repo_id: str,
images=None,
@@ -384,7 +428,7 @@ def save_model_card(
- Micro-batch size: {StateTracker.get_args().train_batch_size}
- Gradient accumulation steps: {StateTracker.get_args().gradient_accumulation_steps}
- Number of GPUs: {StateTracker.get_accelerator().num_processes}
- Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{flux_schedule_info(args=StateTracker.get_args())}
- Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{model_schedule_info(args=StateTracker.get_args())}
- Rescaled betas zero SNR: {StateTracker.get_args().rescale_betas_zero_snr}
- Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''}
- Precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'}
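For the model card, the new `sd3_schedule_info()` renders the schedule-related flags into a short parenthetical string. A self-contained example of its output, using a copy of the function's core logic from the diff above and a hypothetical args object whose field names mirror the ones it reads:

```python
from types import SimpleNamespace

# Copy of sd3_schedule_info()'s core logic from helpers/publishing/metadata.py.
def sd3_schedule_info(args):
    if args.model_family.lower() != "sd3":
        return ""
    output_args = []
    if args.flux_schedule_auto_shift:
        output_args.append("flux_schedule_auto_shift")
    if args.flux_schedule_shift is not None:
        output_args.append(f"shift={args.flux_schedule_shift}")
    if args.flux_use_beta_schedule:
        output_args.append(f"flux_beta_schedule_alpha={args.flux_beta_schedule_alpha}")
        output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}")
    if args.flux_use_uniform_schedule:
        output_args.append("flux_use_uniform_schedule")
    return (
        f" (extra parameters={output_args})"
        if output_args
        else " (no special parameters set)"
    )

# Hypothetical args object; only the fields the function reads are set.
args = SimpleNamespace(
    model_family="sd3",
    flux_schedule_auto_shift=False,
    flux_schedule_shift=3.0,
    flux_use_beta_schedule=False,
    flux_use_uniform_schedule=True,
)
print(sd3_schedule_info(args))
# -> " (extra parameters=['shift=3.0', 'flux_use_uniform_schedule'])"
```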
7 changes: 6 additions & 1 deletion helpers/training/trainer.py
@@ -2185,7 +2185,7 @@ def train(self):
if self.config.flow_matching:
if (
not self.config.flux_fast_schedule
and not self.config.flux_use_beta_schedule
and not any([self.config.flux_use_beta_schedule, self.config.flux_use_uniform_schedule])
):
# imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF
# also used by: https://github.com/XLabs-AI/x-flux/tree/main
@@ -2197,6 +2197,11 @@
sigmas = apply_flux_schedule_shift(
self.config, self.noise_scheduler, sigmas, noise
)
elif self.config.flux_use_uniform_schedule:
sigmas = torch.rand((bsz,), device=self.accelerator.device)
sigmas = apply_flux_schedule_shift(
self.config, self.noise_scheduler, sigmas, noise
)
elif self.config.flux_use_beta_schedule:
alpha = self.config.flux_beta_schedule_alpha
beta = self.config.flux_beta_schedule_beta
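Whichever branch runs, the sampled (and possibly shifted) sigmas then drive the usual flow-matching interpolation between clean latents and noise. A hedged sketch of that step using the standard rectified-flow formulation; shapes are illustrative and the code is not copied from trainer.py:

```python
import torch

# How uniformly sampled sigmas typically enter flow-matching training.
bsz, c, h, w = 4, 16, 64, 64
latents = torch.randn((bsz, c, h, w))
noise = torch.randn_like(latents)

sigmas = torch.rand((bsz,))              # --flux_use_uniform_schedule branch
sigmas_b = sigmas.view(-1, 1, 1, 1)      # broadcast over latent dimensions

noisy_latents = (1.0 - sigmas_b) * latents + sigmas_b * noise
target = noise - latents                 # flow-matching velocity target
```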
(Diff for the fourth changed file was not loaded on this page.)
