From ea6dfb5e8f8a037f19920bcdd9c41a69c0154faf Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 9 Sep 2024 07:41:55 -0600 Subject: [PATCH 01/12] VAECache: improve startup speed for extremely large datasets --- helpers/caching/vae.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/helpers/caching/vae.py b/helpers/caching/vae.py index aff01acd..1dcc7eb6 100644 --- a/helpers/caching/vae.py +++ b/helpers/caching/vae.py @@ -357,7 +357,7 @@ def _list_cached_images(self): def discover_unprocessed_files(self, directory: str = None): """Identify files that haven't been processed yet.""" - all_image_files = StateTracker.get_image_files(data_backend_id=self.id) + all_image_files = set(StateTracker.get_image_files(data_backend_id=self.id)) existing_cache_files = set( StateTracker.get_vae_cache_files(data_backend_id=self.id) ) @@ -367,7 +367,6 @@ def discover_unprocessed_files(self, directory: str = None): try: n = self._image_filename_from_vaecache_filename(cache_file) already_cached_images.append(n) - # print(f"Mapping: {n} -> {cache_file}") except Exception as e: logger.error( f"Could not find image path for cache file {cache_file}: {e}" @@ -375,9 +374,9 @@ def discover_unprocessed_files(self, directory: str = None): continue # Identify unprocessed files - self.local_unprocessed_files = [ - file for file in all_image_files if file not in already_cached_images - ] + self.local_unprocessed_files = list( + set(all_image_files) - set(already_cached_images) + ) return self.local_unprocessed_files From 05a3c3707c6dc7e2a8031bb86752cb38e38b9f43 Mon Sep 17 00:00:00 2001 From: bghira Date: Mon, 9 Sep 2024 08:08:39 -0600 Subject: [PATCH 02/12] sine: update to use math.sin instead --- helpers/training/custom_schedule.py | 62 ++++++++++------------------- 1 file changed, 21 insertions(+), 41 deletions(-) diff --git a/helpers/training/custom_schedule.py b/helpers/training/custom_schedule.py index a2d6fe3b..b54a3473 100644 --- a/helpers/training/custom_schedule.py +++ b/helpers/training/custom_schedule.py @@ -432,36 +432,33 @@ def print_lr(self, is_verbose, group, lr, epoch=None): class Sine(LRScheduler): def __init__( - self, - optimizer, - T_0, - steps_per_epoch=-1, - T_mult=1, - eta_min=0, - last_step=-1, - verbose=False, + self, optimizer, T_0, T_mult=1, eta_min=0, last_step=-1, verbose=False ): if T_0 <= 0 or not isinstance(T_0, int): raise ValueError( - f"Sine learning rate expects to use warmup steps as its interval. Expected positive integer T_0, but got {T_0}" + f"Sine learning rate expects positive integer T_0, but got {T_0}" ) if T_mult < 1 or not isinstance(T_mult, int): raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}") + self.optimizer = optimizer self.T_0 = T_0 - self.steps_per_epoch = steps_per_epoch - self.T_i = T_0 self.T_mult = T_mult self.eta_min = eta_min + self.T_i = T_0 self.T_cur = last_step - super(Sine, self).__init__(optimizer, last_step, verbose) + self.last_epoch = last_step + self.base_lrs = [group["lr"] for group in optimizer.param_groups] + self.verbose = verbose + self._last_lr = self.base_lrs + self.total_steps = 0 # Track total steps for a continuous wave def get_lr(self): + # Calculate learning rates using a continuous sine function based on total steps lrs = [ self.eta_min + (base_lr - self.eta_min) - * (1 - math.cos(math.pi / 2 + math.pi * self.T_cur / self.T_i)) - / 2 + * (0.5 * (1 + math.sin(math.pi * self.total_steps / self.T_0))) for base_lr in self.base_lrs ] return lrs @@ -469,40 +466,23 @@ def get_lr(self): def step(self, step=None): if step is None: step = self.last_epoch + 1 - self.T_cur = step % self.T_i - - if step != 0 and step % self.T_i == 0: - self.T_i *= self.T_mult + self.total_steps = step # Use total steps instead of resetting per interval self.last_epoch = step - # This context manager ensures that the learning rate is updated correctly - with _enable_get_lr_call(self): - # Loop through each parameter group and its corresponding learning rate - for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())): - param_group, lr = data - # Update the learning rate for this parameter group - # We use math.floor to truncate the precision to avoid numerical issues - param_group["lr"] = math.floor(lr * 1e9) / 1e9 - # Print the updated learning rate if verbose mode is enabled - self.print_lr(self.verbose, i, lr, step) + for i, (param_group, lr) in enumerate( + zip(self.optimizer.param_groups, self.get_lr()) + ): + param_group["lr"] = math.floor(lr * 1e9) / 1e9 + self.print_lr(self.verbose, i, lr, step) - # Update the last learning rate values for each parameter group self._last_lr = [group["lr"] for group in self.optimizer.param_groups] def print_lr(self, is_verbose, group, lr, epoch=None): - """Display the current learning rate.""" if is_verbose: - if epoch is None: - print( - "Adjusting learning rate" - " of group {} to {:.8e}.".format(group, lr) - ) - else: - epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch - print( - "Epoch {}: adjusting learning rate" - " of group {} to {:.8e}.".format(epoch_str, group, lr) - ) + epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch + print( + f"Epoch {epoch_str}: adjusting learning rate of group {group} to {lr:.8e}." + ) from diffusers.optimization import get_scheduler From 6f482b99be9b21cbcc7a39b9882185d1a8bbb25f Mon Sep 17 00:00:00 2001 From: Ana Els <77629566+anae-git@users.noreply.github.com> Date: Mon, 9 Sep 2024 23:12:43 -0400 Subject: [PATCH 03/12] Update FLUX.md Fix unfortunate typo that will confuse people --- documentation/quickstart/FLUX.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index 38c8d74a..34a5bd23 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -305,7 +305,7 @@ Then, create a `datasets` directory: ```bash mkdir -p datasets pushd datasets - huggingface-cli download --repo_type=dataset bghira/pseudo-camera-10k --local-dir=pseudo-camera-10k + huggingface-cli download --repo-type=dataset bghira/pseudo-camera-10k --local-dir=pseudo-camera-10k mkdir dreambooth-subject # place your images into dreambooth-subject/ now popd From c37cec1beead3bff4cc54b3a02aeaff3a374a261 Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 10 Sep 2024 06:27:59 -0600 Subject: [PATCH 04/12] Fix #953 by updating stable_diffusion_3 to sd3 for v1.0 config changes --- configure.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/configure.py b/configure.py index 12591857..51b1c53e 100644 --- a/configure.py +++ b/configure.py @@ -16,10 +16,10 @@ "sdxl", "pixart_sigma", "kolors", - "stable_diffusion_3", + "sd3", "stable_diffusion_legacy", ], - "lora": ["flux", "sdxl", "kolors", "stable_diffusion_3", "stable_diffusion_legacy"], + "lora": ["flux", "sdxl", "kolors", "sd3", "stable_diffusion_legacy"], "controlnet": ["sdxl", "stable_diffusion_legacy"], } @@ -29,7 +29,7 @@ "pixart_sigma": "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", "kolors": "kwai-kolors/kolors-diffusers", "terminus": "ptx0/terminus-xl-velocity-v2", - "stable_diffusion_3": "stabilityai/stable-diffusion-3-medium-diffusers", + "sd3": "stabilityai/stable-diffusion-3-medium-diffusers", } default_cfg = { @@ -38,7 +38,7 @@ "pixart_sigma": 3.4, "kolors": 5.0, "terminus": 8.0, - "stable_diffusion_3": 5.0, + "sd3": 5.0, } model_labels = { From 28cbb34b349393c5f2ee3872c36e4101cb21018b Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 10 Sep 2024 07:29:58 -0600 Subject: [PATCH 05/12] Fix #961 by instantiating the batch fetcher with the step count and incrementing locally instead of globally --- helpers/data_backend/factory.py | 9 +++++---- helpers/training/trainer.py | 4 +++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/helpers/data_backend/factory.py b/helpers/data_backend/factory.py index 3ab365c8..b4d9f9fe 100644 --- a/helpers/data_backend/factory.py +++ b/helpers/data_backend/factory.py @@ -1299,10 +1299,11 @@ def random_dataloader_iterator(step, backends: dict): class BatchFetcher: - def __init__(self, max_size=10, datasets={}): + def __init__(self, step, max_size=10, datasets={}): self.queue = queue.Queue(max_size) self.datasets = datasets self.keep_running = True + self.step = step def start_fetching(self): thread = threading.Thread(target=self.fetch_responses) @@ -1310,14 +1311,13 @@ def start_fetching(self): return thread def fetch_responses(self): - global step prefetch_log_debug("Launching retrieval thread.") while self.keep_running: if self.queue.qsize() < self.queue.maxsize: prefetch_log_debug( f"Queue size: {self.queue.qsize()}. Fetching more data." ) - self.queue.put(random_dataloader_iterator(self.datasets)) + self.queue.put(random_dataloader_iterator(self.step, self.datasets)) if self.queue.qsize() >= self.queue.maxsize: prefetch_log_debug("Completed fetching data. Queue is full.") continue @@ -1325,7 +1325,8 @@ def fetch_responses(self): time.sleep(0.5) prefetch_log_debug("Exiting retrieval thread.") - def next_response(self): + def next_response(self, step: int): + self.step = step if self.queue.empty(): prefetch_log_debug("Queue is empty. Waiting for data.") while self.queue.empty(): diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index fd090e80..97f47538 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -512,7 +512,8 @@ def init_data_backend(self): message_level="critical", ) - return False + raise e + self.init_validation_prompts() # We calculate the number of steps per epoch by dividing the number of images by the effective batch divisor. # Gradient accumulation steps mean that we only update the model weights every /n/ steps. @@ -1585,6 +1586,7 @@ def train(self): self.bf = BatchFetcher( datasets=train_backends, max_size=self.config.dataloader_prefetch_qlen, + step=step, ) if fetch_thread is not None: fetch_thread.join() From 77e0bd5ad78933bfa7ee974172d682e028ddad7c Mon Sep 17 00:00:00 2001 From: bghira Date: Tue, 10 Sep 2024 08:32:45 -0600 Subject: [PATCH 06/12] (#962) initialise self.optimizer instead of optimizer for deepspeed --- helpers/training/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index 97f47538..dca93610 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -936,7 +936,7 @@ def init_optimizer(self): ) if self.config.use_deepspeed_optimizer: - optimizer = optimizer_class(self.params_to_optimize) + self.optimizer = optimizer_class(self.params_to_optimize) else: logger.info( f"Optimizer arguments, weight_decay={self.config.adam_weight_decay} eps={self.config.adam_epsilon}, extra_arguments={extra_optimizer_args}" From 74d319c0bc41660dc59de3cc31f7e564ed6d887e Mon Sep 17 00:00:00 2001 From: Alon Burg Date: Fri, 13 Sep 2024 13:23:02 +0300 Subject: [PATCH 07/12] fix bucket worker not waiting for all queue worker to finish --- helpers/metadata/backends/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/helpers/metadata/backends/base.py b/helpers/metadata/backends/base.py index 76e4984e..5b9bc0ce 100644 --- a/helpers/metadata/backends/base.py +++ b/helpers/metadata/backends/base.py @@ -257,7 +257,11 @@ def compute_aspect_ratio_bucket_indices(self, ignore_existing_cache: bool = Fals ncols=100, miniters=int(len(new_files) / 100), ) as pbar: - while any(worker.is_alive() for worker in workers): + while any(worker.is_alive() for worker in workers) or \ + not tqdm_queue.empty() or \ + not aspect_ratio_bucket_indices_queue.empty() or \ + not metadata_updates_queue.empty() or \ + not written_files_queue.empty(): current_time = time.time() while not tqdm_queue.empty(): pbar.update(tqdm_queue.get()) From 4a633a3e731154f7f9e90a62665a5f7c80dd0012 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 14 Sep 2024 12:36:18 -0600 Subject: [PATCH 08/12] fix #972 by unwrapping model --- helpers/training/trainer.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index dca93610..75d1a93d 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -822,26 +822,40 @@ def init_post_load_freeze(self): if self.unet is not None: logger.info("Applying BitFit freezing strategy to the U-net.") - self.unet = apply_bitfit_freezing(self.unet, self.config) + self.unet = apply_bitfit_freezing( + unwrap_model(self.accelerator, self.unet), self.config + ) if self.transformer is not None: logger.warning( "Training DiT models with BitFit is not yet tested, and unexpected results may occur." ) - self.transformer = apply_bitfit_freezing(self.transformer, self.config) + self.transformer = apply_bitfit_freezing( + unwrap_model(self.accelerator, self.transformer), self.config + ) if self.config.gradient_checkpointing: if self.unet is not None: - self.unet.enable_gradient_checkpointing() + unwrap_model( + self.accelerator, self.unet + ).enable_gradient_checkpointing() if self.transformer is not None and self.config.model_family != "smoldit": - self.transformer.enable_gradient_checkpointing() + unwrap_model( + self.accelerator, self.transformer + ).enable_gradient_checkpointing() if self.config.controlnet: - self.controlnet.enable_gradient_checkpointing() + unwrap_model( + self.accelerator, self.controlnet + ).enable_gradient_checkpointing() if ( hasattr(self.config, "train_text_encoder") and self.config.train_text_encoder ): - self.text_encoder_1.gradient_checkpointing_enable() - self.text_encoder_2.gradient_checkpointing_enable() + unwrap_model( + self.accelerator, self.text_encoder_1 + ).gradient_checkpointing_enable() + unwrap_model( + self.accelerator, self.text_encoder_2 + ).gradient_checkpointing_enable() def _recalculate_training_steps(self): # Scheduler and math around the number of training steps. From a997ec49c03f4aa731a79e072ade222890ab7c35 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 14 Sep 2024 12:38:44 -0600 Subject: [PATCH 09/12] fix nonetype reference when ctrl+c --- train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/train.py b/train.py index cb63532b..b3c72a18 100644 --- a/train.py +++ b/train.py @@ -8,8 +8,6 @@ logger.setLevel(environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) if __name__ == "__main__": - global bf - bf = None trainer = None try: import multiprocessing @@ -64,4 +62,4 @@ print(e) print(traceback.format_exc()) if trainer is not None and trainer.bf is not None: - bf.stop_fetching() + trainer.bf.stop_fetching() From be9a80d264bf4fdc338d47b834f95bb8dbb84903 Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 14 Sep 2024 14:03:18 -0600 Subject: [PATCH 10/12] tests: add trainer unit/regression tests --- helpers/metadata/backends/discovery.py | 4 +- helpers/training/save_hooks.py | 4 +- tests/test_metadata_backend.py | 2 +- tests/test_trainer.py | 383 +++++++++++++++++++++++++ 4 files changed, 388 insertions(+), 5 deletions(-) create mode 100644 tests/test_trainer.py diff --git a/helpers/metadata/backends/discovery.py b/helpers/metadata/backends/discovery.py index 99481ffa..7e3959bf 100644 --- a/helpers/metadata/backends/discovery.py +++ b/helpers/metadata/backends/discovery.py @@ -122,11 +122,11 @@ def reload_cache(self, set_config: bool = True): dict: The cache data. """ # Query our DataBackend to see whether the cache file exists. - logger.info(f"Checking for cache file: {self.cache_file}") + logger.debug(f"Checking for cache file: {self.cache_file}") if self.data_backend.exists(self.cache_file): try: # Use our DataBackend to actually read the cache file. - logger.info("Pulling cache file from storage") + logger.debug("Pulling cache file from storage") cache_data_raw = self.data_backend.read(self.cache_file) cache_data = json.loads(cache_data_raw) except Exception as e: diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py index 4819cfdc..35197fa5 100644 --- a/helpers/training/save_hooks.py +++ b/helpers/training/save_hooks.py @@ -168,8 +168,8 @@ def __init__( if args.controlnet: self.denoiser_class = ControlNetModel self.denoiser_subdir = "controlnet" - logger.info(f"Denoiser class set to: {self.denoiser_class.__name__}.") - logger.info(f"Pipeline class set to: {self.pipeline_class.__name__}.") + logger.debug(f"Denoiser class set to: {self.denoiser_class.__name__}.") + logger.debug(f"Pipeline class set to: {self.pipeline_class.__name__}.") self.ema_model_cls = None self.ema_model_subdir = None diff --git a/tests/test_metadata_backend.py b/tests/test_metadata_backend.py index 0de318d2..16584329 100644 --- a/tests/test_metadata_backend.py +++ b/tests/test_metadata_backend.py @@ -22,7 +22,7 @@ def setUp(self): self.image_path_str = "test_image.jpg" self.instance_data_dir = "/some/fake/path" - self.cache_file = "/some/fake/cache.json" + self.cache_file = "/some/fake/cache" self.metadata_file = "/some/fake/metadata.json" StateTracker.set_args(MagicMock()) # Overload cache file with json: diff --git a/tests/test_trainer.py b/tests/test_trainer.py new file mode 100644 index 00000000..5093a16a --- /dev/null +++ b/tests/test_trainer.py @@ -0,0 +1,383 @@ +# test_trainer.py + +import unittest +from unittest.mock import Mock, patch, MagicMock +import torch, os + +os.environ["SIMPLETUNER_LOG_LEVEL"] = "CRITICAL" +from helpers.training.trainer import Trainer + + +class TestTrainer(unittest.TestCase): + @patch("helpers.training.trainer.load_config") + @patch("helpers.training.trainer.safety_check") + @patch( + "helpers.training.trainer.load_scheduler_from_args", + return_value=(Mock(), None, Mock()), + ) + @patch("helpers.training.state_tracker.StateTracker") + @patch( + "helpers.training.state_tracker.StateTracker.set_model_family", + return_value=True, + ) + @patch("torch.set_num_threads") + @patch("helpers.training.trainer.Accelerator") + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + def test_config_to_obj( + self, + mock_misc_init, + mock_parse_args, + mock_accelerator, + mock_set_num_threads, + mock_set_model_family, + mock_state_tracker, + mock_load_scheduler_from_args, + mock_safety_check, + mock_load_config, + ): + trainer = Trainer() + config_dict = {"a": 1, "b": 2} + config_obj = trainer._config_to_obj(config_dict) + self.assertEqual(config_obj.a, 1) + self.assertEqual(config_obj.b, 2) + + config_none = trainer._config_to_obj(None) + self.assertIsNone(config_none) + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.set_seed") + def test_init_seed_with_value(self, mock_set_seed, mock_parse_args, mock_misc_init): + trainer = Trainer() + trainer.config = Mock(seed=42, seed_for_each_device=False) + trainer.init_seed() + mock_set_seed.assert_called_with(42, False) + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.set_seed") + def test_init_seed_none(self, mock_set_seed, mock_parse_args, mock_misc_init): + trainer = Trainer() + trainer.config = Mock(seed=None, seed_for_each_device=False) + trainer.init_seed() + mock_set_seed.assert_not_called() + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("torch.cuda.is_available", return_value=True) + @patch("torch.cuda.memory_allocated", return_value=1024**3) + def test_stats_memory_used_cuda( + self, mock_memory_allocated, mock_is_available, mock_parse_args, mock_misc_init + ): + trainer = Trainer() + memory_used = trainer.stats_memory_used() + self.assertEqual(memory_used, 1.0) + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("torch.cuda.is_available", return_value=False) + @patch("torch.backends.mps.is_available", return_value=True) + @patch("torch.mps.current_allocated_memory", return_value=1024**3) + def test_stats_memory_used_mps( + self, + mock_current_allocated_memory, + mock_mps_is_available, + mock_cuda_is_available, + mock_parse_args, + mock_misc_init, + ): + trainer = Trainer() + memory_used = trainer.stats_memory_used() + self.assertEqual(memory_used, 1.0) + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("torch.cuda.is_available", return_value=False) + @patch("torch.backends.mps.is_available", return_value=False) + @patch("helpers.training.trainer.logger") + def test_stats_memory_used_none( + self, + mock_logger, + mock_mps_is_available, + mock_cuda_is_available, + mock_parse_args, + mock_misc_init, + ): + trainer = Trainer() + memory_used = trainer.stats_memory_used() + self.assertEqual(memory_used, 0) + mock_logger.warning.assert_called_with( + "CUDA, ROCm, or Apple MPS not detected here. We cannot report VRAM reductions." + ) + + @patch("torch.set_num_threads") + @patch("helpers.training.state_tracker.StateTracker.set_global_step") + @patch("helpers.training.state_tracker.StateTracker.set_args") + @patch("helpers.training.state_tracker.StateTracker.set_weight_dtype") + @patch("helpers.training.trainer.Trainer.set_model_family") + @patch("helpers.training.trainer.Trainer.init_noise_schedule") + @patch( + "argparse.ArgumentParser.parse_args", + return_value=MagicMock( + torch_num_threads=2, + train_batch_size=1, + weight_dtype=torch.float32, + optimizer="adamw_bf16", + max_train_steps=2, + num_train_epochs=0, + timestep_bias_portion=0, + metadata_update_interval=100, + gradient_accumulation_steps=1, + report_to="none", + output_dir="output_dir", + ), + ) + def test_misc_init( + self, + mock_argparse, + mock_init_noise_schedule, + mock_set_model_family, + mock_set_weight_dtype, + mock_set_args, + mock_set_global_step, + mock_set_num_threads, + ): + trainer = Trainer() + trainer._misc_init() + mock_set_num_threads.assert_called_with(2) + self.assertEqual( + trainer.state, + {"lr": 0.0, "global_step": 0, "global_resume_step": 0, "first_epoch": 1}, + ) + self.assertEqual(trainer.timesteps_buffer, []) + self.assertEqual(trainer.guidance_values_list, []) + self.assertEqual(trainer.train_loss, 0.0) + self.assertIsNone(trainer.bf) + self.assertIsNone(trainer.grad_norm) + self.assertEqual(trainer.extra_lr_scheduler_kwargs, {}) + mock_set_global_step.assert_called_with(0) + mock_set_args.assert_called_with(trainer.config) + mock_set_weight_dtype.assert_called_with(trainer.config.weight_dtype) + mock_set_model_family.assert_called() + mock_init_noise_schedule.assert_called() + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch( + "helpers.training.trainer.load_scheduler_from_args", + return_value=(Mock(), "flow_matching_value", "noise_scheduler_value"), + ) + def test_init_noise_schedule( + self, mock_load_scheduler_from_args, mock_parse_args, mock_misc_init + ): + trainer = Trainer() + trainer.config = Mock() + trainer.init_noise_schedule() + self.assertEqual(trainer.config.flow_matching, "flow_matching_value") + self.assertEqual(trainer.noise_scheduler, "noise_scheduler_value") + self.assertEqual(trainer.lr, 0.0) + + @patch("helpers.training.trainer.logger") + @patch( + "helpers.training.trainer.model_classes", {"full": ["sdxl", "sd3", "legacy"]} + ) + @patch( + "helpers.training.trainer.model_labels", + {"sdxl": "SDXL", "sd3": "SD3", "legacy": "Legacy"}, + ) + @patch("helpers.training.state_tracker.StateTracker") + def test_set_model_family_default(self, mock_state_tracker, mock_logger): + with patch("helpers.training.trainer.Trainer._misc_init"): + with patch("helpers.training.trainer.Trainer.parse_arguments"): + trainer = Trainer() + trainer.config = Mock(model_family=None) + trainer.config.pretrained_model_name_or_path = "some/path" + trainer.config.pretrained_vae_model_name_or_path = None + trainer.config.vae_path = None + trainer.config.text_encoder_path = None + trainer.config.text_encoder_subfolder = None + trainer.config.model_family = None + + with patch.object(trainer, "_set_model_paths") as mock_set_model_paths: + with patch( + "helpers.training.state_tracker.StateTracker.is_sdxl_refiner", + return_value=False, + ): + trainer.set_model_family() + self.assertEqual(trainer.config.model_type_label, "SDXL") + mock_logger.warning.assert_called() + mock_set_model_paths.assert_called() + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + def test_set_model_family_invalid(self, mock_parse_args, mock_misc_init): + trainer = Trainer() + trainer.config = Mock(model_family="invalid_model_family") + trainer.config.pretrained_model_name_or_path = "some/path" + with self.assertRaises(ValueError) as context: + trainer.set_model_family() + self.assertIn( + "Invalid model family specified: invalid_model_family", + str(context.exception), + ) + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.logger") + @patch("helpers.training.state_tracker.StateTracker") + def test_epoch_rollover( + self, mock_state_tracker, mock_logger, mock_parse_args, mock_misc_init + ): + trainer = Trainer() + trainer.state = {"first_epoch": 1, "current_epoch": 1} + trainer.config = Mock( + num_train_epochs=5, + aspect_bucket_disable_rebuild=False, + lr_scheduler="cosine_with_restarts", + ) + trainer.extra_lr_scheduler_kwargs = {} + with patch( + "helpers.training.state_tracker.StateTracker.get_data_backends", + return_value={}, + ): + trainer._epoch_rollover(2) + self.assertEqual(trainer.state["current_epoch"], 2) + self.assertEqual(trainer.extra_lr_scheduler_kwargs["epoch"], 2) + + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + def test_epoch_rollover_same_epoch(self, mock_misc_init, mock_parse_args): + trainer = Trainer( + config={ + "--num_train_epochs": 0, + "--model_family": "pixart_sigma", + "--optimizer": "adamw_bf16", + "--pretrained_model_name_or_path": "some/path", + } + ) + trainer.state = {"first_epoch": 1, "current_epoch": 1} + trainer._epoch_rollover(1) + self.assertEqual(trainer.state["current_epoch"], 1) + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.os.makedirs") + @patch("helpers.training.state_tracker.StateTracker.delete_cache_files") + def test_init_clear_backend_cache_preserve( + self, mock_delete_cache_files, mock_makedirs, mock_parse_args, mock_misc_init + ): + trainer = Trainer() + trainer.config = Mock( + output_dir="/path/to/output", preserve_data_backend_cache=True + ) + trainer.init_clear_backend_cache() + mock_makedirs.assert_called_with("/path/to/output", exist_ok=True) + mock_delete_cache_files.assert_not_called() + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.os.makedirs") + @patch("helpers.training.state_tracker.StateTracker.delete_cache_files") + def test_init_clear_backend_cache_delete( + self, mock_delete_cache_files, mock_makedirs, mock_parse_args, mock_misc_init + ): + trainer = Trainer() + trainer.config = Mock( + output_dir="/path/to/output", preserve_data_backend_cache=False + ) + trainer.init_clear_backend_cache() + mock_makedirs.assert_called_with("/path/to/output", exist_ok=True) + mock_delete_cache_files.assert_called_with(preserve_data_backend_cache=False) + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.huggingface_hub") + @patch("helpers.training.trainer.HubManager") + @patch("helpers.training.state_tracker.StateTracker") + @patch("accelerate.logging.MultiProcessAdapter.log") + def test_init_huggingface_hub( + self, + mock_logger, + mock_state_tracker, + mock_hub_manager_class, + mock_hf_hub, + mock_parse_args, + mock_misc_init, + ): + trainer = Trainer() + trainer.config = Mock(push_to_hub=True, huggingface_token="fake_token") + trainer.accelerator = Mock(is_main_process=True) + mock_hf_hub.whoami = Mock(return_value={"id": "fake_id", "name": "foobar"}) + trainer.init_huggingface_hub(access_token="fake_token") + mock_hf_hub.login.assert_called_with(token="fake_token") + mock_hub_manager_class.assert_called_with(config=trainer.config) + mock_hf_hub.whoami.assert_called() + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.logger") + @patch("helpers.training.trainer.os.path.basename", return_value="checkpoint-100") + @patch( + "helpers.training.trainer.os.listdir", + return_value=["checkpoint-100", "checkpoint-200"], + ) + @patch( + "helpers.training.trainer.os.path.join", + side_effect=lambda *args: "/".join(args), + ) + @patch("helpers.training.trainer.os.path.exists", return_value=True) + @patch("helpers.training.trainer.Accelerator") + @patch("helpers.training.state_tracker.StateTracker") + def test_init_resume_checkpoint( + self, + mock_state_tracker, + mock_accelerator_class, + mock_path_exists, + mock_path_join, + mock_os_listdir, + mock_path_basename, + mock_logger, + mock_parse_args, + mock_misc_init, + ): + trainer = Trainer() + trainer.config = Mock( + output_dir="/path/to/output", + resume_from_checkpoint="latest", + total_steps_remaining_at_start=100, + global_resume_step=1, + num_train_epochs=0, + max_train_steps=100, + ) + trainer.accelerator = Mock(num_processes=1) + trainer.state = {"global_step": 0, "first_epoch": 1, "current_epoch": 1} + trainer.optimizer = Mock() + trainer.config.lr_scheduler = "constant" + trainer.config.learning_rate = 0.001 + trainer.config.is_schedulefree = False + trainer.config.overrode_max_train_steps = False + + # Mock lr_scheduler + lr_scheduler = Mock() + lr_scheduler.state_dict.return_value = {"base_lrs": [0.1], "_last_lr": [0.1]} + + with patch( + "helpers.training.state_tracker.StateTracker.get_data_backends", + return_value={}, + ): + with patch( + "helpers.training.state_tracker.StateTracker.get_global_step", + return_value=100, + ): + trainer.init_resume_checkpoint(lr_scheduler=lr_scheduler) + mock_logger.info.assert_called() + trainer.accelerator.load_state.assert_called_with( + "/path/to/output/checkpoint-200" + ) + + # Additional tests can be added for other methods as needed + + +if __name__ == "__main__": + unittest.main() From 558abad739139256e80e93b84ec512ad1d1847cc Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 14 Sep 2024 15:50:27 -0600 Subject: [PATCH 11/12] tests: do not raise error --- helpers/configuration/env_file.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helpers/configuration/env_file.py b/helpers/configuration/env_file.py index b34739f8..7a916a3a 100644 --- a/helpers/configuration/env_file.py +++ b/helpers/configuration/env_file.py @@ -133,7 +133,7 @@ def load_env(): print(f"[CONFIG.ENV] Loaded environment variables from {config_env_path}") else: - raise ValueError(f"Cannot find config file: {config_env_path}") + logger.error(f"Cannot find config file: {config_env_path}") return config_file_contents From 4dd81a2da3af4322e4639c6f0861a926805eccde Mon Sep 17 00:00:00 2001 From: bghira Date: Sat, 14 Sep 2024 15:54:40 -0600 Subject: [PATCH 12/12] tests: fix error with mixed_precision value on GH Actions --- tests/test_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 5093a16a..507844b9 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -129,6 +129,7 @@ def test_stats_memory_used_none( timestep_bias_portion=0, metadata_update_interval=100, gradient_accumulation_steps=1, + mixed_precision="bf16", report_to="none", output_dir="output_dir", ),