diff --git a/helpers/configuration/env_file.py b/helpers/configuration/env_file.py index b34739f8..7a916a3a 100644 --- a/helpers/configuration/env_file.py +++ b/helpers/configuration/env_file.py @@ -133,7 +133,7 @@ def load_env(): print(f"[CONFIG.ENV] Loaded environment variables from {config_env_path}") else: - raise ValueError(f"Cannot find config file: {config_env_path}") + logger.error(f"Cannot find config file: {config_env_path}") return config_file_contents diff --git a/helpers/metadata/backends/base.py b/helpers/metadata/backends/base.py index 1665d6b0..e34eb8fb 100644 --- a/helpers/metadata/backends/base.py +++ b/helpers/metadata/backends/base.py @@ -261,7 +261,13 @@ def compute_aspect_ratio_bucket_indices(self, ignore_existing_cache: bool = Fals if self.should_abort: logger.info("Aborting aspect bucket update.") return - while any(worker.is_alive() for worker in workers): + while ( + any(worker.is_alive() for worker in workers) + or not tqdm_queue.empty() + or not aspect_ratio_bucket_indices_queue.empty() + or not metadata_updates_queue.empty() + or not written_files_queue.empty() + ): current_time = time.time() while not tqdm_queue.empty(): pbar.update(tqdm_queue.get()) diff --git a/helpers/metadata/backends/discovery.py b/helpers/metadata/backends/discovery.py index 99481ffa..7e3959bf 100644 --- a/helpers/metadata/backends/discovery.py +++ b/helpers/metadata/backends/discovery.py @@ -122,11 +122,11 @@ def reload_cache(self, set_config: bool = True): dict: The cache data. """ # Query our DataBackend to see whether the cache file exists. - logger.info(f"Checking for cache file: {self.cache_file}") + logger.debug(f"Checking for cache file: {self.cache_file}") if self.data_backend.exists(self.cache_file): try: # Use our DataBackend to actually read the cache file. - logger.info("Pulling cache file from storage") + logger.debug("Pulling cache file from storage") cache_data_raw = self.data_backend.read(self.cache_file) cache_data = json.loads(cache_data_raw) except Exception as e: diff --git a/helpers/training/save_hooks.py b/helpers/training/save_hooks.py index 4819cfdc..35197fa5 100644 --- a/helpers/training/save_hooks.py +++ b/helpers/training/save_hooks.py @@ -168,8 +168,8 @@ def __init__( if args.controlnet: self.denoiser_class = ControlNetModel self.denoiser_subdir = "controlnet" - logger.info(f"Denoiser class set to: {self.denoiser_class.__name__}.") - logger.info(f"Pipeline class set to: {self.pipeline_class.__name__}.") + logger.debug(f"Denoiser class set to: {self.denoiser_class.__name__}.") + logger.debug(f"Pipeline class set to: {self.pipeline_class.__name__}.") self.ema_model_cls = None self.ema_model_subdir = None diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index c6afca53..424dec8f 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -158,6 +158,7 @@ def __init__(self, config: dict = None): self.lycoris_wrapped_network = None self.lycoris_config = None self.lr_scheduler = None + self.webhook_handler = None self.should_abort = False def _config_to_obj(self, config): @@ -257,7 +258,6 @@ def run(self): raise e - def _initialize_components_with_signal_check(self, initializers): """ Runs a list of initializer functions with signal checks after each. 
@@ -921,26 +921,40 @@ def init_post_load_freeze(self): if self.unet is not None: logger.info("Applying BitFit freezing strategy to the U-net.") - self.unet = apply_bitfit_freezing(self.unet, self.config) + self.unet = apply_bitfit_freezing( + unwrap_model(self.accelerator, self.unet), self.config + ) if self.transformer is not None: logger.warning( "Training DiT models with BitFit is not yet tested, and unexpected results may occur." ) - self.transformer = apply_bitfit_freezing(self.transformer, self.config) + self.transformer = apply_bitfit_freezing( + unwrap_model(self.accelerator, self.transformer), self.config + ) if self.config.gradient_checkpointing: if self.unet is not None: - self.unet.enable_gradient_checkpointing() + unwrap_model( + self.accelerator, self.unet + ).enable_gradient_checkpointing() if self.transformer is not None and self.config.model_family != "smoldit": - self.transformer.enable_gradient_checkpointing() + unwrap_model( + self.accelerator, self.transformer + ).enable_gradient_checkpointing() if self.config.controlnet: - self.controlnet.enable_gradient_checkpointing() + unwrap_model( + self.accelerator, self.controlnet + ).enable_gradient_checkpointing() if ( hasattr(self.config, "train_text_encoder") and self.config.train_text_encoder ): - self.text_encoder_1.gradient_checkpointing_enable() - self.text_encoder_2.gradient_checkpointing_enable() + unwrap_model( + self.accelerator, self.text_encoder_1 + ).gradient_checkpointing_enable() + unwrap_model( + self.accelerator, self.text_encoder_2 + ).gradient_checkpointing_enable() def _recalculate_training_steps(self): # Scheduler and math around the number of training steps. diff --git a/tests/test_metadata_backend.py b/tests/test_metadata_backend.py index 0de318d2..16584329 100644 --- a/tests/test_metadata_backend.py +++ b/tests/test_metadata_backend.py @@ -22,7 +22,7 @@ def setUp(self): self.image_path_str = "test_image.jpg" self.instance_data_dir = "/some/fake/path" - self.cache_file = "/some/fake/cache.json" + self.cache_file = "/some/fake/cache" self.metadata_file = "/some/fake/metadata.json" StateTracker.set_args(MagicMock()) # Overload cache file with json: diff --git a/tests/test_trainer.py b/tests/test_trainer.py new file mode 100644 index 00000000..507844b9 --- /dev/null +++ b/tests/test_trainer.py @@ -0,0 +1,384 @@ +# test_trainer.py + +import unittest +from unittest.mock import Mock, patch, MagicMock +import torch, os + +os.environ["SIMPLETUNER_LOG_LEVEL"] = "CRITICAL" +from helpers.training.trainer import Trainer + + +class TestTrainer(unittest.TestCase): + @patch("helpers.training.trainer.load_config") + @patch("helpers.training.trainer.safety_check") + @patch( + "helpers.training.trainer.load_scheduler_from_args", + return_value=(Mock(), None, Mock()), + ) + @patch("helpers.training.state_tracker.StateTracker") + @patch( + "helpers.training.state_tracker.StateTracker.set_model_family", + return_value=True, + ) + @patch("torch.set_num_threads") + @patch("helpers.training.trainer.Accelerator") + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + def test_config_to_obj( + self, + mock_misc_init, + mock_parse_args, + mock_accelerator, + mock_set_num_threads, + mock_set_model_family, + mock_state_tracker, + mock_load_scheduler_from_args, + mock_safety_check, + mock_load_config, + ): + trainer = Trainer() + config_dict = {"a": 1, "b": 2} + config_obj = trainer._config_to_obj(config_dict) + 
self.assertEqual(config_obj.a, 1) + self.assertEqual(config_obj.b, 2) + + config_none = trainer._config_to_obj(None) + self.assertIsNone(config_none) + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.set_seed") + def test_init_seed_with_value(self, mock_set_seed, mock_parse_args, mock_misc_init): + trainer = Trainer() + trainer.config = Mock(seed=42, seed_for_each_device=False) + trainer.init_seed() + mock_set_seed.assert_called_with(42, False) + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.set_seed") + def test_init_seed_none(self, mock_set_seed, mock_parse_args, mock_misc_init): + trainer = Trainer() + trainer.config = Mock(seed=None, seed_for_each_device=False) + trainer.init_seed() + mock_set_seed.assert_not_called() + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("torch.cuda.is_available", return_value=True) + @patch("torch.cuda.memory_allocated", return_value=1024**3) + def test_stats_memory_used_cuda( + self, mock_memory_allocated, mock_is_available, mock_parse_args, mock_misc_init + ): + trainer = Trainer() + memory_used = trainer.stats_memory_used() + self.assertEqual(memory_used, 1.0) + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("torch.cuda.is_available", return_value=False) + @patch("torch.backends.mps.is_available", return_value=True) + @patch("torch.mps.current_allocated_memory", return_value=1024**3) + def test_stats_memory_used_mps( + self, + mock_current_allocated_memory, + mock_mps_is_available, + mock_cuda_is_available, + mock_parse_args, + mock_misc_init, + ): + trainer = Trainer() + memory_used = trainer.stats_memory_used() + self.assertEqual(memory_used, 1.0) + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("torch.cuda.is_available", return_value=False) + @patch("torch.backends.mps.is_available", return_value=False) + @patch("helpers.training.trainer.logger") + def test_stats_memory_used_none( + self, + mock_logger, + mock_mps_is_available, + mock_cuda_is_available, + mock_parse_args, + mock_misc_init, + ): + trainer = Trainer() + memory_used = trainer.stats_memory_used() + self.assertEqual(memory_used, 0) + mock_logger.warning.assert_called_with( + "CUDA, ROCm, or Apple MPS not detected here. We cannot report VRAM reductions." 
+ ) + + @patch("torch.set_num_threads") + @patch("helpers.training.state_tracker.StateTracker.set_global_step") + @patch("helpers.training.state_tracker.StateTracker.set_args") + @patch("helpers.training.state_tracker.StateTracker.set_weight_dtype") + @patch("helpers.training.trainer.Trainer.set_model_family") + @patch("helpers.training.trainer.Trainer.init_noise_schedule") + @patch( + "argparse.ArgumentParser.parse_args", + return_value=MagicMock( + torch_num_threads=2, + train_batch_size=1, + weight_dtype=torch.float32, + optimizer="adamw_bf16", + max_train_steps=2, + num_train_epochs=0, + timestep_bias_portion=0, + metadata_update_interval=100, + gradient_accumulation_steps=1, + mixed_precision="bf16", + report_to="none", + output_dir="output_dir", + ), + ) + def test_misc_init( + self, + mock_argparse, + mock_init_noise_schedule, + mock_set_model_family, + mock_set_weight_dtype, + mock_set_args, + mock_set_global_step, + mock_set_num_threads, + ): + trainer = Trainer() + trainer._misc_init() + mock_set_num_threads.assert_called_with(2) + self.assertEqual( + trainer.state, + {"lr": 0.0, "global_step": 0, "global_resume_step": 0, "first_epoch": 1}, + ) + self.assertEqual(trainer.timesteps_buffer, []) + self.assertEqual(trainer.guidance_values_list, []) + self.assertEqual(trainer.train_loss, 0.0) + self.assertIsNone(trainer.bf) + self.assertIsNone(trainer.grad_norm) + self.assertEqual(trainer.extra_lr_scheduler_kwargs, {}) + mock_set_global_step.assert_called_with(0) + mock_set_args.assert_called_with(trainer.config) + mock_set_weight_dtype.assert_called_with(trainer.config.weight_dtype) + mock_set_model_family.assert_called() + mock_init_noise_schedule.assert_called() + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch( + "helpers.training.trainer.load_scheduler_from_args", + return_value=(Mock(), "flow_matching_value", "noise_scheduler_value"), + ) + def test_init_noise_schedule( + self, mock_load_scheduler_from_args, mock_parse_args, mock_misc_init + ): + trainer = Trainer() + trainer.config = Mock() + trainer.init_noise_schedule() + self.assertEqual(trainer.config.flow_matching, "flow_matching_value") + self.assertEqual(trainer.noise_scheduler, "noise_scheduler_value") + self.assertEqual(trainer.lr, 0.0) + + @patch("helpers.training.trainer.logger") + @patch( + "helpers.training.trainer.model_classes", {"full": ["sdxl", "sd3", "legacy"]} + ) + @patch( + "helpers.training.trainer.model_labels", + {"sdxl": "SDXL", "sd3": "SD3", "legacy": "Legacy"}, + ) + @patch("helpers.training.state_tracker.StateTracker") + def test_set_model_family_default(self, mock_state_tracker, mock_logger): + with patch("helpers.training.trainer.Trainer._misc_init"): + with patch("helpers.training.trainer.Trainer.parse_arguments"): + trainer = Trainer() + trainer.config = Mock(model_family=None) + trainer.config.pretrained_model_name_or_path = "some/path" + trainer.config.pretrained_vae_model_name_or_path = None + trainer.config.vae_path = None + trainer.config.text_encoder_path = None + trainer.config.text_encoder_subfolder = None + trainer.config.model_family = None + + with patch.object(trainer, "_set_model_paths") as mock_set_model_paths: + with patch( + "helpers.training.state_tracker.StateTracker.is_sdxl_refiner", + return_value=False, + ): + trainer.set_model_family() + self.assertEqual(trainer.config.model_type_label, "SDXL") + mock_logger.warning.assert_called() + 
mock_set_model_paths.assert_called() + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + def test_set_model_family_invalid(self, mock_parse_args, mock_misc_init): + trainer = Trainer() + trainer.config = Mock(model_family="invalid_model_family") + trainer.config.pretrained_model_name_or_path = "some/path" + with self.assertRaises(ValueError) as context: + trainer.set_model_family() + self.assertIn( + "Invalid model family specified: invalid_model_family", + str(context.exception), + ) + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.logger") + @patch("helpers.training.state_tracker.StateTracker") + def test_epoch_rollover( + self, mock_state_tracker, mock_logger, mock_parse_args, mock_misc_init + ): + trainer = Trainer() + trainer.state = {"first_epoch": 1, "current_epoch": 1} + trainer.config = Mock( + num_train_epochs=5, + aspect_bucket_disable_rebuild=False, + lr_scheduler="cosine_with_restarts", + ) + trainer.extra_lr_scheduler_kwargs = {} + with patch( + "helpers.training.state_tracker.StateTracker.get_data_backends", + return_value={}, + ): + trainer._epoch_rollover(2) + self.assertEqual(trainer.state["current_epoch"], 2) + self.assertEqual(trainer.extra_lr_scheduler_kwargs["epoch"], 2) + + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + def test_epoch_rollover_same_epoch(self, mock_misc_init, mock_parse_args): + trainer = Trainer( + config={ + "--num_train_epochs": 0, + "--model_family": "pixart_sigma", + "--optimizer": "adamw_bf16", + "--pretrained_model_name_or_path": "some/path", + } + ) + trainer.state = {"first_epoch": 1, "current_epoch": 1} + trainer._epoch_rollover(1) + self.assertEqual(trainer.state["current_epoch"], 1) + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.os.makedirs") + @patch("helpers.training.state_tracker.StateTracker.delete_cache_files") + def test_init_clear_backend_cache_preserve( + self, mock_delete_cache_files, mock_makedirs, mock_parse_args, mock_misc_init + ): + trainer = Trainer() + trainer.config = Mock( + output_dir="/path/to/output", preserve_data_backend_cache=True + ) + trainer.init_clear_backend_cache() + mock_makedirs.assert_called_with("/path/to/output", exist_ok=True) + mock_delete_cache_files.assert_not_called() + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.os.makedirs") + @patch("helpers.training.state_tracker.StateTracker.delete_cache_files") + def test_init_clear_backend_cache_delete( + self, mock_delete_cache_files, mock_makedirs, mock_parse_args, mock_misc_init + ): + trainer = Trainer() + trainer.config = Mock( + output_dir="/path/to/output", preserve_data_backend_cache=False + ) + trainer.init_clear_backend_cache() + mock_makedirs.assert_called_with("/path/to/output", exist_ok=True) + mock_delete_cache_files.assert_called_with(preserve_data_backend_cache=False) + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + 
@patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.huggingface_hub") + @patch("helpers.training.trainer.HubManager") + @patch("helpers.training.state_tracker.StateTracker") + @patch("accelerate.logging.MultiProcessAdapter.log") + def test_init_huggingface_hub( + self, + mock_logger, + mock_state_tracker, + mock_hub_manager_class, + mock_hf_hub, + mock_parse_args, + mock_misc_init, + ): + trainer = Trainer() + trainer.config = Mock(push_to_hub=True, huggingface_token="fake_token") + trainer.accelerator = Mock(is_main_process=True) + mock_hf_hub.whoami = Mock(return_value={"id": "fake_id", "name": "foobar"}) + trainer.init_huggingface_hub(access_token="fake_token") + mock_hf_hub.login.assert_called_with(token="fake_token") + mock_hub_manager_class.assert_called_with(config=trainer.config) + mock_hf_hub.whoami.assert_called() + + @patch("helpers.training.trainer.Trainer._misc_init", return_value=Mock()) + @patch("helpers.training.trainer.Trainer.parse_arguments", return_value=Mock()) + @patch("helpers.training.trainer.logger") + @patch("helpers.training.trainer.os.path.basename", return_value="checkpoint-100") + @patch( + "helpers.training.trainer.os.listdir", + return_value=["checkpoint-100", "checkpoint-200"], + ) + @patch( + "helpers.training.trainer.os.path.join", + side_effect=lambda *args: "/".join(args), + ) + @patch("helpers.training.trainer.os.path.exists", return_value=True) + @patch("helpers.training.trainer.Accelerator") + @patch("helpers.training.state_tracker.StateTracker") + def test_init_resume_checkpoint( + self, + mock_state_tracker, + mock_accelerator_class, + mock_path_exists, + mock_path_join, + mock_os_listdir, + mock_path_basename, + mock_logger, + mock_parse_args, + mock_misc_init, + ): + trainer = Trainer() + trainer.config = Mock( + output_dir="/path/to/output", + resume_from_checkpoint="latest", + total_steps_remaining_at_start=100, + global_resume_step=1, + num_train_epochs=0, + max_train_steps=100, + ) + trainer.accelerator = Mock(num_processes=1) + trainer.state = {"global_step": 0, "first_epoch": 1, "current_epoch": 1} + trainer.optimizer = Mock() + trainer.config.lr_scheduler = "constant" + trainer.config.learning_rate = 0.001 + trainer.config.is_schedulefree = False + trainer.config.overrode_max_train_steps = False + + # Mock lr_scheduler + lr_scheduler = Mock() + lr_scheduler.state_dict.return_value = {"base_lrs": [0.1], "_last_lr": [0.1]} + + with patch( + "helpers.training.state_tracker.StateTracker.get_data_backends", + return_value={}, + ): + with patch( + "helpers.training.state_tracker.StateTracker.get_global_step", + return_value=100, + ): + trainer.init_resume_checkpoint(lr_scheduler=lr_scheduler) + mock_logger.info.assert_called() + trainer.accelerator.load_state.assert_called_with( + "/path/to/output/checkpoint-200" + ) + + # Additional tests can be added for other methods as needed + + +if __name__ == "__main__": + unittest.main() diff --git a/train.py b/train.py index cb63532b..b3c72a18 100644 --- a/train.py +++ b/train.py @@ -8,8 +8,6 @@ logger.setLevel(environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) if __name__ == "__main__": - global bf - bf = None trainer = None try: import multiprocessing @@ -64,4 +62,4 @@ print(e) print(traceback.format_exc()) if trainer is not None and trainer.bf is not None: - bf.stop_fetching() + trainer.bf.stop_fetching()