From c4fef7e37d0b0cea963f71e597c42374131854a9 Mon Sep 17 00:00:00 2001
From: bghira
Date: Fri, 8 Nov 2024 14:28:34 +0000
Subject: [PATCH 1/3] flux and sd3 could use uniform sampling instead of beta
 or sigmoid

---
 helpers/configuration/cmd_args.py | 9 +++++++++
 helpers/training/trainer.py       | 7 ++++++-
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py
index c51635df..14e1b470 100644
--- a/helpers/configuration/cmd_args.py
+++ b/helpers/configuration/cmd_args.py
@@ -149,6 +149,15 @@ def get_argument_parser():
             " which has improved results in short experiments. Thanks to @mhirki for the contribution."
         ),
     )
+    parser.add_argument(
+        "--flux_use_uniform_schedule",
+        action="store_true",
+        help=(
+            "Whether or not to use a uniform schedule with Flux instead of sigmoid."
+            " Using uniform sampling may help preserve more capabilities from the base model."
+            " Some tasks may not benefit from this."
+        ),
+    )
     parser.add_argument(
         "--flux_use_beta_schedule",
         action="store_true",
diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py
index a5e70d44..4a0eb031 100644
--- a/helpers/training/trainer.py
+++ b/helpers/training/trainer.py
@@ -2185,7 +2185,7 @@ def train(self):
                 if self.config.flow_matching:
                     if (
                         not self.config.flux_fast_schedule
-                        and not self.config.flux_use_beta_schedule
+                        and not any([self.config.flux_use_beta_schedule, self.config.flux_use_uniform_schedule])
                     ):
                         # imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF
                         # also used by: https://github.com/XLabs-AI/x-flux/tree/main
@@ -2197,6 +2197,11 @@
                         sigmas = apply_flux_schedule_shift(
                             self.config, self.noise_scheduler, sigmas, noise
                         )
+                    elif self.config.flux_use_uniform_schedule:
+                        sigmas = torch.rand((bsz,), device=self.accelerator.device)
+                        sigmas = apply_flux_schedule_shift(
+                            self.config, self.noise_scheduler, sigmas, noise
+                        )
                     elif self.config.flux_use_beta_schedule:
                         alpha = self.config.flux_beta_schedule_alpha
                         beta = self.config.flux_beta_schedule_beta

From e4e5097e555a4073c51342e4719aa8c3b9458ce1 Mon Sep 17 00:00:00 2001
From: bghira
Date: Fri, 8 Nov 2024 09:21:19 -0600
Subject: [PATCH 2/3] sd3: model card detail expansion

---
 helpers/publishing/metadata.py |  53 ++++++-
 tests/test_model_card.py       | 287 +++++++++++++++++++++++++++++++++
 2 files changed, 336 insertions(+), 4 deletions(-)
 create mode 100644 tests/test_model_card.py

diff --git a/helpers/publishing/metadata.py b/helpers/publishing/metadata.py
index f41600a2..b0031157 100644
--- a/helpers/publishing/metadata.py
+++ b/helpers/publishing/metadata.py
@@ -153,6 +153,15 @@ def _guidance_rescale(args):
     return f"\n guidance_rescale={args.validation_guidance_rescale},"
 
 
+def _skip_layers(args):
+    if (
+        args.model_family.lower() not in ["sd3"]
+        or args.validation_guidance_skip_layers is None
+    ):
+        return ""
+    return f"\n skip_guidance_layers={args.validation_guidance_skip_layers},"
+
+
 def _validation_resolution(args):
     if args.validation_resolution == "" or args.validation_resolution is None:
         return f"width=1024,\n" f" height=1024,"
@@ -185,7 +194,7 @@ def code_example(args, repo_id: str = None):
     num_inference_steps={args.validation_num_inference_steps},
     generator=torch.Generator(device={_torch_device()}).manual_seed(1641421826),
     {_validation_resolution(args)}
-    guidance_scale={args.validation_guidance},{_guidance_rescale(args)}
+    guidance_scale={args.validation_guidance},{_guidance_rescale(args)}{_skip_layers(args)}
 ).images[0]
image.save("output.png", format="PNG") ``` @@ -249,10 +258,38 @@ def flux_schedule_info(args): output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}") if args.flux_attention_masked_training: output_args.append("flux_attention_masked_training") - if args.model_type == "lora" and args.lora_type == "standard": + if ( + args.model_type == "lora" + and args.lora_type == "standard" + and args.flux_lora_target is not None + ): output_args.append(f"flux_lora_target={args.flux_lora_target}") output_str = ( - f" (flux parameters={output_args})" + f" (extra parameters={output_args})" + if output_args + else " (no special parameters set)" + ) + + return output_str + + +def sd3_schedule_info(args): + if args.model_family.lower() != "sd3": + return "" + output_args = [] + if args.flux_schedule_auto_shift: + output_args.append("flux_schedule_auto_shift") + if args.flux_schedule_shift is not None: + output_args.append(f"shift={args.flux_schedule_shift}") + if args.flux_use_beta_schedule: + output_args.append(f"flux_beta_schedule_alpha={args.flux_beta_schedule_alpha}") + output_args.append(f"flux_beta_schedule_beta={args.flux_beta_schedule_beta}") + if args.flux_use_uniform_schedule: + output_args.append(f"flux_use_uniform_schedule") + # if args.model_type == "lora" and args.lora_type == "standard": + # output_args.append(f"flux_lora_target={args.flux_lora_target}") + output_str = ( + f" (extra parameters={output_args})" if output_args else " (no special parameters set)" ) @@ -260,6 +297,13 @@ def flux_schedule_info(args): return output_str +def model_schedule_info(args): + if args.model_family == "flux": + return flux_schedule_info(args) + if args.model_family == "sd3": + return sd3_schedule_info(args) + + def save_model_card( repo_id: str, images=None, @@ -384,7 +428,7 @@ def save_model_card( - Micro-batch size: {StateTracker.get_args().train_batch_size} - Gradient accumulation steps: {StateTracker.get_args().gradient_accumulation_steps} - Number of GPUs: {StateTracker.get_accelerator().num_processes} -- Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{flux_schedule_info(args=StateTracker.get_args())} +- Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{model_schedule_info(args=StateTracker.get_args())} - Rescaled betas zero SNR: {StateTracker.get_args().rescale_betas_zero_snr} - Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''} - Precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'} diff --git a/tests/test_model_card.py b/tests/test_model_card.py new file mode 100644 index 00000000..0744e887 --- /dev/null +++ b/tests/test_model_card.py @@ -0,0 +1,287 @@ +import unittest +from unittest.mock import MagicMock, patch +import os +import json + +# Assuming the functions are in a module named 'metadata.py' +from helpers.publishing.metadata import ( + _negative_prompt, + _torch_device, + _model_imports, + _model_load, + _validation_resolution, + _skip_layers, + _guidance_rescale, +) +from helpers.publishing.metadata import * + +# For demonstration purposes, I'll redefine the functions here. +# In your actual test file, import them from your module as shown above. 
+ + +class TestMetadataFunctions(unittest.TestCase): + def setUp(self): + # Mock the args object + self.args = MagicMock() + self.args.lora_type = "standard" + self.args.model_type = "lora" + self.args.model_family = "sdxl" + self.args.validation_prompt = "A test prompt" + self.args.validation_negative_prompt = "A negative prompt" + self.args.validation_num_inference_steps = 50 + self.args.validation_guidance = 7.5 + self.args.validation_guidance_rescale = 0.7 + self.args.validation_resolution = "512x512" + self.args.pretrained_model_name_or_path = "test-model" + self.args.output_dir = "test-output" + self.args.lora_rank = 4 + self.args.lora_alpha = 1.0 + self.args.lora_dropout = 0.0 + self.args.lora_init_type = "kaiming_uniform" + self.args.model_card_note = "Test note" + self.args.validation_using_datasets = False + self.args.flow_matching_loss = "flow-matching" + self.args.flux_fast_schedule = False + self.args.flux_schedule_auto_shift = False + self.args.flux_schedule_shift = None + self.args.flux_guidance_value = None + self.args.flux_guidance_min = None + self.args.flux_guidance_max = None + self.args.flux_use_beta_schedule = False + self.args.flux_beta_schedule_alpha = None + self.args.flux_beta_schedule_beta = None + self.args.flux_attention_masked_training = False + self.args.flux_use_uniform_schedule = False + self.args.flux_lora_target = None + self.args.validation_guidance_skip_layers = None + self.args.validation_seed = 1234 + self.args.validation_noise_scheduler = "ddim" + self.args.model_card_safe_for_work = True + self.args.learning_rate = 1e-4 + self.args.max_grad_norm = 1.0 + self.args.train_batch_size = 4 + self.args.gradient_accumulation_steps = 1 + self.args.optimizer = "AdamW" + self.args.optimizer_config = "" + self.args.mixed_precision = "fp16" + self.args.base_model_precision = "no_change" + self.args.enable_xformers_memory_efficient_attention = False + + def test_model_imports(self): + self.args.lora_type = "standard" + self.args.model_type = "lora" + expected_output = "import torch\nfrom diffusers import DiffusionPipeline" + output = _model_imports(self.args) + self.assertEqual(output.strip(), expected_output.strip()) + + self.args.lora_type = "lycoris" + output = _model_imports(self.args) + self.assertIn("from lycoris import create_lycoris_from_weights", output) + + def test_model_load(self): + self.args.pretrained_model_name_or_path = "pretrained-model" + self.args.output_dir = "output-dir" + self.args.lora_type = "standard" + self.args.model_type = "lora" + + with patch( + "helpers.publishing.metadata.StateTracker.get_hf_username", + return_value="testuser", + ): + output = _model_load(self.args, repo_id="repo-id") + self.assertIn("pipeline.load_lora_weights", output) + self.assertIn("adapter_id = 'testuser/repo-id'", output) + + self.args.lora_type = "lycoris" + output = _model_load(self.args) + self.assertIn("pytorch_lora_weights.safetensors", output) + + def test_torch_device(self): + output = _torch_device() + expected_output = "'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'" + self.assertEqual(output.strip(), expected_output.strip()) + + def test_negative_prompt(self): + self.args.model_family = "sdxl" + output = _negative_prompt(self.args) + expected_output = "negative_prompt = 'A negative prompt'" + self.assertEqual(output.strip(), expected_output.strip()) + + output_in_call = _negative_prompt(self.args, in_call=True) + self.assertIn("negative_prompt=negative_prompt", output_in_call) + + def 
test_guidance_rescale(self): + self.args.model_family = "sdxl" + output = _guidance_rescale(self.args) + expected_output = "\n guidance_rescale=0.7," + self.assertEqual(output.strip(), expected_output.strip()) + + self.args.model_family = "flux" + output = _guidance_rescale(self.args) + self.assertEqual(output.strip(), "") + + def test_skip_layers(self): + self.args.model_family = "sd3" + self.args.validation_guidance_skip_layers = 2 + output = _skip_layers(self.args) + expected_output = "\n skip_guidance_layers=2," + self.assertEqual(output.strip(), expected_output.strip()) + + self.args.model_family = "sdxl" + output = _skip_layers(self.args) + self.assertEqual(output.strip(), "") + + def test_validation_resolution(self): + self.args.validation_resolution = "512x512" + output = _validation_resolution(self.args) + expected_output = "width=512,\n height=512," + self.assertEqual(output.strip(), expected_output.strip()) + + self.args.validation_resolution = "" + output = _validation_resolution(self.args) + expected_output = "width=1024,\n height=1024," + self.assertEqual(output.strip(), expected_output.strip()) + + def test_code_example(self): + with patch( + "helpers.publishing.metadata._model_imports", + return_value="import torch\nfrom diffusers import DiffusionPipeline", + ): + with patch( + "helpers.publishing.metadata._model_load", return_value="pipeline = ..." + ): + with patch( + "helpers.publishing.metadata._torch_device", return_value="'cuda'" + ): + with patch( + "helpers.publishing.metadata._negative_prompt", + return_value="negative_prompt = 'A negative prompt'", + ): + with patch( + "helpers.publishing.metadata._validation_resolution", + return_value="width=512,\n height=512,", + ): + output = code_example(self.args) + self.assertIn("import torch", output) + self.assertIn("pipeline = ...", output) + self.assertIn("pipeline.to('cuda')", output) + + def test_model_type(self): + self.args.model_type = "lora" + self.args.lora_type = "standard" + output = model_type(self.args) + self.assertEqual(output, "standard PEFT LoRA") + + self.args.lora_type = "lycoris" + output = model_type(self.args) + self.assertEqual(output, "LyCORIS adapter") + + self.args.model_type = "full" + output = model_type(self.args) + self.assertEqual(output, "full rank finetune") + + def test_lora_info(self): + self.args.model_type = "lora" + self.args.lora_type = "standard" + output = lora_info(self.args) + self.assertIn("LoRA Rank: 4", output) + + self.args.lora_type = "lycoris" + # Mocking the file reading + lycoris_config = {"key": "value"} + with patch( + "builtins.open", + unittest.mock.mock_open(read_data=json.dumps(lycoris_config)), + ): + output = lora_info(self.args) + self.assertIn('"key": "value"', output) + + def test_model_card_note(self): + output = model_card_note(self.args) + self.assertIn("Test note", output) + + self.args.model_card_note = "" + output = model_card_note(self.args) + self.assertEqual(output.strip(), "") + + def test_flux_schedule_info(self): + self.args.model_family = "flux" + output = flux_schedule_info(self.args) + self.assertIn("(no special parameters set)", output) + + self.args.flux_fast_schedule = True + output = flux_schedule_info(self.args) + self.assertIn("flux_fast_schedule", output) + + def test_sd3_schedule_info(self): + self.args.model_family = "sd3" + output = sd3_schedule_info(self.args) + self.assertIn("(no special parameters set)", output) + + self.args.flux_schedule_auto_shift = True + output = sd3_schedule_info(self.args) + 
self.assertIn("flux_schedule_auto_shift", output) + + def test_model_schedule_info(self): + with patch( + "helpers.publishing.metadata.flux_schedule_info", return_value="flux info" + ): + with patch( + "helpers.publishing.metadata.sd3_schedule_info", return_value="sd3 info" + ): + self.args.model_family = "flux" + output = model_schedule_info(self.args) + self.assertEqual(output, "flux info") + + self.args.model_family = "sd3" + output = model_schedule_info(self.args) + self.assertEqual(output, "sd3 info") + + def test_save_model_card(self): + # Mocking StateTracker methods + with patch( + "helpers.publishing.metadata.StateTracker.get_model_family", + return_value="sdxl", + ): + with patch( + "helpers.publishing.metadata.StateTracker.get_data_backends", + return_value={}, + ): + with patch( + "helpers.publishing.metadata.StateTracker.get_epoch", return_value=1 + ): + with patch( + "helpers.publishing.metadata.StateTracker.get_global_step", + return_value=1000, + ): + with patch( + "helpers.publishing.metadata.StateTracker.get_accelerator", + return_value=MagicMock(num_processes=1), + ): + with patch( + "helpers.publishing.metadata.code_example", + return_value="code example", + ): + with patch( + "builtins.open", unittest.mock.mock_open() + ) as mock_file: + save_model_card( + repo_id="test-repo", + images=None, + base_model="test-base-model", + train_text_encoder=True, + prompt="Test prompt", + validation_prompts=["Test prompt"], + validation_shortnames=["shortname"], + repo_folder="test-folder", + ) + # Ensure the README.md was written + mock_file.assert_called_with( + os.path.join("test-folder", "README.md"), + "w", + encoding="utf-8", + ) + + +if __name__ == "__main__": + unittest.main() From 73c344119e1a311a1d1438f580c574ff07115678 Mon Sep 17 00:00:00 2001 From: bghira Date: Fri, 8 Nov 2024 09:23:33 -0600 Subject: [PATCH 3/3] remove boilerplate template text --- tests/test_model_card.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_model_card.py b/tests/test_model_card.py index 0744e887..b9c07d33 100644 --- a/tests/test_model_card.py +++ b/tests/test_model_card.py @@ -3,7 +3,6 @@ import os import json -# Assuming the functions are in a module named 'metadata.py' from helpers.publishing.metadata import ( _negative_prompt, _torch_device, @@ -15,9 +14,6 @@ ) from helpers.publishing.metadata import * -# For demonstration purposes, I'll redefine the functions here. -# In your actual test file, import them from your module as shown above. - class TestMetadataFunctions(unittest.TestCase): def setUp(self):