diff --git a/fine_tune.py b/fine_tune.py index a0350ce18..3c4a5a26b 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -520,6 +520,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) train(args) diff --git a/library/train_util.py b/library/train_util.py index 1a46f6a7d..c13bb68ee 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1890,7 +1890,7 @@ def __init__( subset.image_dir, False, None, - subset.caption_extension, + subset.caption_extension, subset.cache_info, subset.num_repeats, subset.shuffle_caption, @@ -3358,6 +3358,60 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser): ) +# verify command line args for training +def verify_command_line_training_args(args: argparse.Namespace): + # if wandb is enabled, the command line is exposed to the public + # check whether sensitive options are included in the command line arguments + # if so, warn or inform the user to move them to the configuration file + # wandbが有効な場合、コマンドラインが公開される + # 学習用のコマンドライン引数に敏感なオプションが含まれているかどうかを確認し、 + # 含まれている場合は設定ファイルに移動するようにユーザーに警告または通知する + + wandb_enabled = args.log_with is not None and args.log_with != "tensorboard" # "all" or "wandb" + if not wandb_enabled: + return + + sensitive_args = ["wandb_api_key", "huggingface_token"] + sensitive_path_args = [ + "pretrained_model_name_or_path", + "vae", + "tokenizer_cache_dir", + "train_data_dir", + "conditioning_data_dir", + "reg_data_dir", + "output_dir", + "logging_dir", + ] + + for arg in sensitive_args: + if getattr(args, arg, None) is not None: + logger.warning( + f"wandb is enabled, but option `{arg}` is included in the command line. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file." + + f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれています。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。" + ) + + # if path is absolute, it may include sensitive information + for arg in sensitive_path_args: + if getattr(args, arg, None) is not None and os.path.isabs(getattr(args, arg)): + logger.info( + f"wandb is enabled, but option `{arg}` is included in the command line and it is an absolute path. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file or use relative path." + + f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれており、絶対パスです。コマンドラインは公開されるため、`.toml`ファイルに移動するか、相対パスを使用することをお勧めします。" + ) + + if getattr(args, "config_file", None) is not None: + logger.info( + f"wandb is enabled, but option `config_file` is included in the command line. Because the command line is exposed to the public, please be careful about the information included in the path." + + f" / wandbが有効で、かつオプション `config_file` がコマンドラインに含まれています。コマンドラインは公開されるため、パスに含まれる情報にご注意ください。" + ) + + # other sensitive options + if args.huggingface_repo_id is not None and args.huggingface_repo_visibility != "public": + logger.info( + f"wandb is enabled, but option huggingface_repo_id is included in the command line and huggingface_repo_visibility is not 'public'. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file." + + f" / wandbが有効で、かつオプション huggingface_repo_id がコマンドラインに含まれており、huggingface_repo_visibility が 'public' ではありません。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。" + ) + + def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable diff --git a/sdxl_train.py b/sdxl_train.py index 816598e04..f6d277494 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -812,6 +812,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) train(args) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 9eaaa19f2..e880b57de 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -612,6 +612,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) train(args) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index e55a58896..0ea64b824 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -580,6 +580,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) train(args) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index d33239d92..83969bb1d 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -178,6 +178,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) trainer = SdxlNetworkTrainer() diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 257d181ad..5df739e28 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -131,6 +131,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) trainer = SdxlTextualInversionTrainer() diff --git a/train_controlnet.py b/train_controlnet.py index 0cb0405fd..90cac0410 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -617,6 +617,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) train(args) diff --git a/train_db.py b/train_db.py index 0a152f224..c3b7339f3 100644 --- a/train_db.py +++ b/train_db.py @@ -523,6 +523,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) train(args) diff --git a/train_network.py b/train_network.py index 8fe98f126..fcf4cd9b6 100644 --- a/train_network.py +++ b/train_network.py @@ -1101,6 +1101,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) trainer = NetworkTrainer() diff --git a/train_textual_inversion.py b/train_textual_inversion.py index e7083596f..02edf9525 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -806,6 +806,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) trainer = TextualInversionTrainer() diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 861d48d1d..f0723f2a7 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -714,6 +714,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() + train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) train(args)