add --log_config option to enable/disable outputting the training config
kohya-ss committed May 19, 2024
1 parent 47187f7 commit c68baae
Showing 11 changed files with 42 additions and 16 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -165,6 +165,9 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
- Specify the learning rate and dim (rank) for each block.
  - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only).

- Training scripts can now output their training settings to wandb or TensorBoard logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa!
  - Some settings, such as API keys and directory specifications, are not output for security reasons (an illustration follows this item).
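
For illustration, a hedged sketch (hypothetical values; the actual key lists live in `library/train_util.py` in the diff below) of what does and does not reach the tracker when `--log_config` is set:

```python
# Illustration only: scalar hyperparameters are kept; API keys and path-like
# settings are dropped (see get_sanitized_config_or_none in the diff below).
args = {
    "learning_rate": 1e-4,
    "max_train_steps": 1600,
    "wandb_api_key": "xxxx",                     # filtered: sensitive value
    "pretrained_model_name_or_path": "sd.ckpt",  # filtered: path argument
}
logged = {k: v for k, v in args.items()
          if k not in ("wandb_api_key", "pretrained_model_name_or_path")}
# logged == {"learning_rate": 0.0001, "max_train_steps": 1600}
```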

- An option `--disable_mmap_load_safetensors` has been added to disable memory mapping when loading a model's .safetensors file in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
  - Model file loading seems to be faster in WSL and similar environments (a hedged sketch of the idea follows this item).
  - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
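
The idea, sketched under the assumption that the switch is between `safetensors.torch.load_file` (memory-mapped) and deserializing from an in-memory buffer (the actual implementation in PR #1266 may differ):

```python
# Hedged sketch of --disable_mmap_load_safetensors: load_file() memory-maps the
# file, while load() deserializes from bytes already read into memory.
from safetensors.torch import load, load_file

def load_model_weights(path: str, disable_mmap: bool = False):
    if disable_mmap:
        with open(path, "rb") as f:
            return load(f.read())  # no mmap: one sequential read, then parse
    return load_file(path)         # default: memory-mapped access
```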
@@ -209,6 +212,9 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
- The learning rate and dim (rank) can be specified for each block.
  - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details.

- Each training script can now output its training settings to wandb or TensorBoard logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa!
  - Some settings, such as API keys and directory specifications, are not output for security reasons.

- An option `--disable_mmap_load_safetensors` has been added to disable memory mapping when loading a model's .safetensors file in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
  - Model file loading seems to be faster in WSL and similar environments.
  - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
20 changes: 15 additions & 5 deletions fine_tune.py
@@ -310,7 +310,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
accelerator.init_trackers(
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)

# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
@@ -354,7 +358,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

# Predict the noise residual
with accelerator.autocast():
@@ -368,7 +374,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
@@ -380,7 +388,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -471,7 +481,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

accelerator.end_training()

if is_main_process and (args.save_state or args.save_state_on_train_end):
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)

del accelerator  # remove this since memory is needed after this point
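
For context on the `conditional_loss` calls reformatted above: a hedged sketch of what that helper likely dispatches on (the real implementation lives in `library/train_util.py`, is not part of this diff, and may differ in detail). `huber_c` sets the transition point of the pseudo-Huber loss and is unused for plain L2:

```python
import torch
import torch.nn.functional as F

def conditional_loss_sketch(pred, target, reduction="mean", loss_type="l2", huber_c=0.1):
    """Hedged sketch: dispatch on loss_type as train_util.conditional_loss does."""
    if loss_type == "l2":
        return F.mse_loss(pred, target, reduction=reduction)
    # pseudo-Huber: quadratic near zero like L2, linear in the tails like L1
    loss = 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # reduction == "none"
```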
16 changes: 13 additions & 3 deletions library/train_util.py
@@ -3180,6 +3180,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
)
parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する")

parser.add_argument(
"--noise_offset",
@@ -3388,7 +3389,15 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
)

def filter_sensitive_args(args: argparse.Namespace):

def get_sanitized_config_or_none(args: argparse.Namespace):
# if `--log_config` is enabled, return args for logging; if not, return None.
# when `--log_config` is enabled, filter sensitive values out of args.
# even if wandb is not enabled, the log is not exposed to the public, but sensitive values are filtered anyway to be safe.

if not args.log_config:
return None

sensitive_args = ["wandb_api_key", "huggingface_token"]
sensitive_path_args = [
"pretrained_model_name_or_path",
@@ -3402,9 +3411,9 @@ def filter_sensitive_args(args: argparse.Namespace):
]
filtered_args = {}
for k, v in vars(args).items():
# filter out sensitive values
# filter out sensitive values and convert to string if necessary
if k not in sensitive_args + sensitive_path_args:
#Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`.
# Accelerate values need to have type `bool`, `str`, `float`, `int`, or `None`.
if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int):
filtered_args[k] = v
# accelerate does not support lists
@@ -3416,6 +3425,7 @@ def filter_sensitive_args(args: argparse.Namespace):

return filtered_args


# verify command line args for training
def verify_command_line_training_args(args: argparse.Namespace):
# if wandb is enabled, the command line is exposed to the public
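
A hedged usage sketch of the renamed helper (the `argparse.Namespace` values are hypothetical); note that it returns `None` unless `--log_config` is set, so every call site can pass the result to `init_trackers` unconditionally:

```python
import argparse
from library import train_util  # assumes sd-scripts is on the import path

args = argparse.Namespace(
    log_config=True,
    learning_rate=1e-4,
    wandb_api_key="secret",  # listed in sensitive_args: never logged
)
config = train_util.get_sanitized_config_or_none(args)
assert config is not None and "wandb_api_key" not in config

args.log_config = False
assert train_util.get_sanitized_config_or_none(args) is None
```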
2 changes: 1 addition & 1 deletion sdxl_train.py
@@ -589,7 +589,7 @@ def optimizer_hook(parameter: torch.Tensor):
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)

# For --sample_at_first
sdxl_train_util.sample_images(
2 changes: 1 addition & 1 deletion sdxl_train_control_net_lllite.py
@@ -354,7 +354,7 @@ def train(args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
)

loss_recorder = train_util.LossRecorder()
2 changes: 1 addition & 1 deletion sdxl_train_control_net_lllite_old.py
@@ -324,7 +324,7 @@ def train(args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
)

loss_recorder = train_util.LossRecorder()
2 changes: 1 addition & 1 deletion train_controlnet.py
@@ -344,7 +344,7 @@ def train(args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
)

loss_recorder = train_util.LossRecorder()
2 changes: 1 addition & 1 deletion train_db.py
@@ -290,7 +290,7 @@ def train(args):
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)

# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
2 changes: 1 addition & 1 deletion train_network.py
@@ -774,7 +774,7 @@ def load_model_hook(models, input_dir):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
"network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
)

loss_recorder = train_util.LossRecorder()
2 changes: 1 addition & 1 deletion train_textual_inversion.py
@@ -510,7 +510,7 @@ def train(self, args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
)

# function for saving/removing
2 changes: 1 addition & 1 deletion train_textual_inversion_XTI.py
@@ -407,7 +407,7 @@ def train(args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
)

# function for saving/removing
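
Putting it together: an end-to-end sketch (hedged; the tracker name and argument values are hypothetical, and TensorBoard must be installed) of how the new option feeds Accelerate's trackers, mirroring the pattern repeated across the scripts above:

```python
import argparse
from accelerate import Accelerator
from library import train_util  # assumes sd-scripts is on the import path

args = argparse.Namespace(log_config=True, learning_rate=1e-4, max_train_steps=1600)
accelerator = Accelerator(log_with="tensorboard", project_dir="logs")
accelerator.init_trackers(
    "network_train",
    config=train_util.get_sanitized_config_or_none(args),  # None without --log_config
)
# ... training loop would go here ...
accelerator.end_training()
```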
