diff --git a/README.md b/README.md
index 5d035eb6f..cd7744598 100644
--- a/README.md
+++ b/README.md
@@ -165,6 +165,9 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
   - Specify the learning rate and dim (rank) for each block.
   - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only).
 
+- Training scripts can now output the training settings to wandb or TensorBoard logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa!
+  - Some settings, such as API keys and directory specifications, are not output for security reasons.
+
 - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
   - It seems that the model file loading is faster in the WSL environment etc.
   - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
@@ -209,6 +212,9 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
   - The learning rate and dim (rank) can be specified for each block.
   - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details.
 
+- Each training script can now output the training settings to wandb or TensorBoard logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa!
+  - Some settings, such as API keys and directory specifications, are not output for security reasons.
+
 - An option `--disable_mmap_load_safetensors` has been added to disable memory mapping when loading the model's .safetensors files in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
   - Model file loading seems to be faster in WSL environments and the like.
   - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
diff --git a/fine_tune.py b/fine_tune.py
index 77a1a4f30..d865cd2de 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -310,7 +310,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
+        accelerator.init_trackers(
+            "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
+            config=train_util.get_sanitized_config_or_none(args),
+            init_kwargs=init_kwargs,
+        )
 
     # For --sample_at_first
     train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
@@ -354,7 +358,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
 
                 # Predict the noise residual
                 with accelerator.autocast():
@@ -368,7 +374,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
                 if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    )
                     loss = loss.mean([1, 2, 3])
 
                     if args.min_snr_gamma:
@@ -380,7 +388,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
                     loss = loss.mean()  # mean over batch dimension
                 else:
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+                    )
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -471,7 +481,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     accelerator.end_training()
 
-    if is_main_process and (args.save_state or args.save_state_on_train_end): 
+    if is_main_process and (args.save_state or args.save_state_on_train_end):
         train_util.save_state_on_train_end(args, accelerator)
 
     del accelerator  # delete this because memory is used after this point
diff --git a/library/train_util.py b/library/train_util.py
index 84764263e..410471470 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3180,6 +3180,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
     )
+    parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する")
 
     parser.add_argument(
         "--noise_offset",
@@ -3388,7 +3389,15 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
         help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
     )
 
-def filter_sensitive_args(args: argparse.Namespace):
+
+def get_sanitized_config_or_none(args: argparse.Namespace):
+    # if `--log_config` is enabled, return args for logging; otherwise return None
+    # when `--log_config` is enabled, sensitive values are filtered out of args
+    # even if wandb is not enabled, the log is not exposed to the public, but it is safer to filter out sensitive values anyway
+
+    if not args.log_config:
+        return None
+
     sensitive_args = ["wandb_api_key", "huggingface_token"]
     sensitive_path_args = [
         "pretrained_model_name_or_path",
@@ -3402,9 +3411,9 @@ def filter_sensitive_args(args: argparse.Namespace):
     ]
     filtered_args = {}
     for k, v in vars(args).items():
-        # filter out sensitive values
+        # filter out sensitive values and convert to string if necessary
        if k not in sensitive_args + sensitive_path_args:
-            #Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`.
+            # Accelerate values need to have type `bool`, `str`, `float`, `int`, or `None`.
             if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int):
                 filtered_args[k] = v
             # accelerate does not support lists
@@ -3416,6 +3425,7 @@ def filter_sensitive_args(args: argparse.Namespace):
 
     return filtered_args
 
+
 # verify command line args for training
 def verify_command_line_training_args(args: argparse.Namespace):
     # if wandb is enabled, the command line is exposed to the public
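For reference, a minimal sketch of what the new sanitizer returns. The `argparse.Namespace` below is hypothetical; real runs build `args` from a training script's command line, and this assumes it is executed from the repository root:

```python
import argparse

from library import train_util

# hypothetical args for illustration only
args = argparse.Namespace(
    log_config=True,
    learning_rate=1e-4,                                 # plain scalar: kept
    loss_type="l2",                                     # plain string: kept
    wandb_api_key="secret-key",                         # in sensitive_args: dropped
    pretrained_model_name_or_path="/models/sdxl.ckpt",  # in sensitive_path_args: dropped
)

print(train_util.get_sanitized_config_or_none(args))
# {'log_config': True, 'learning_rate': 0.0001, 'loss_type': 'l2'}

args.log_config = False
print(train_util.get_sanitized_config_or_none(args))
# None -- without --log_config, nothing is passed as the tracker config
```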
diff --git a/sdxl_train.py b/sdxl_train.py
index 4c4e38721..11f9892a3 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -589,7 +589,7 @@ def optimizer_hook(parameter: torch.Tensor):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
+        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
 
     # For --sample_at_first
     sdxl_train_util.sample_images(
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index b141965fa..301310901 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -354,7 +354,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
+            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
         )
 
     loss_recorder = train_util.LossRecorder()
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 9490cf6f2..292a0463a 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -324,7 +324,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
+            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
         )
 
     loss_recorder = train_util.LossRecorder()
diff --git a/train_controlnet.py b/train_controlnet.py
index 793f79c7d..9994dd99c 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -344,7 +344,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
+            "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
         )
 
     loss_recorder = train_util.LossRecorder()
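The remaining hunks below apply the same one-line change to each script's tracker setup; only the default run name differs. The shared pattern, summarized as a sketch (the `init_tracker` helper and its `default_name` parameter are illustrative, not part of the codebase):

```python
import argparse

import toml
from accelerate import Accelerator

from library import train_util


def init_tracker(accelerator: Accelerator, args: argparse.Namespace, default_name: str) -> None:
    # illustrative helper summarizing the setup repeated in every training script
    init_kwargs = {}
    if args.wandb_run_name:
        init_kwargs["wandb"] = {"name": args.wandb_run_name}
    if args.log_tracker_config is not None:
        # a TOML file can replace the per-tracker init kwargs entirely
        init_kwargs = toml.load(args.log_tracker_config)
    accelerator.init_trackers(
        default_name if args.log_tracker_name is None else args.log_tracker_name,
        config=train_util.get_sanitized_config_or_none(args),  # None unless --log_config is set
        init_kwargs=init_kwargs,
    )
```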
diff --git a/train_db.py b/train_db.py
index 4f9018293..a5408cd3d 100644
--- a/train_db.py
+++ b/train_db.py
@@ -290,7 +290,7 @@ def train(args):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
+        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
 
     # For --sample_at_first
     train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
diff --git a/train_network.py b/train_network.py
index 401a1c70e..38e4888e8 100644
--- a/train_network.py
+++ b/train_network.py
@@ -774,7 +774,7 @@ def load_model_hook(models, input_dir):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
+            "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
         )
 
     loss_recorder = train_util.LossRecorder()
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 56a387391..184607d1d 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -510,7 +510,7 @@ def train(self, args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
+            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
         )
 
     # function for saving/removing
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 691785239..8eed00fa1 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -407,7 +407,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
+            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
         )
 
     # function for saving/removing
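As background for the `--disable_mmap_load_safetensors` option mentioned in the README notes: its actual implementation (PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266)) is not part of this diff, but the technique amounts to reading the whole file into memory instead of letting safetensors memory-map it. A sketch, assuming the standard `safetensors.torch` API (the helper name is illustrative):

```python
from safetensors.torch import load, load_file


def load_safetensors_state_dict(path: str, disable_mmap: bool = False):
    # illustrative helper, not the actual implementation from PR #1266
    if disable_mmap:
        # read the whole file into memory and deserialize from the buffer;
        # this avoids memory-mapped I/O, which can be slow e.g. on WSL mounts
        with open(path, "rb") as f:
            return load(f.read())
    # default path: load_file() memory-maps the .safetensors file
    return load_file(path)
```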