
Commit

Merge branch 'kohya-ss:main' into min-SNR
AI-Casanova authored Mar 21, 2023
2 parents a265225 + aee343a commit 795a6bd
Showing 25 changed files with 253 additions and 82 deletions.
44 changes: 15 additions & 29 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,35 +127,21 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

## Change History


- 19 Mar. 2023, 2023/3/19:
- Add a function to load the training config from a `.toml` file to each training script. Thanks to Linaqruf for this great contribution!
- Specify the `.toml` file with `--config_file`. The file consists of `key=value` entries, and the keys are the same as the command-line options. See [#241](https://github.com/kohya-ss/sd-scripts/pull/241) for details.
- All sub-sections are combined into a single dictionary (the section names are ignored).
- Omitted arguments take the default values of the command-line arguments.
- Command-line arguments override the values in the `.toml` file.
- With the `--output_config` option, you can write the current command-line options to the `.toml` file specified with `--config_file`. Use it as a template.
- Add `--lr_scheduler_type` and `--lr_scheduler_args` arguments to each training script for custom LR schedulers. Thanks to Isotr0py! [#271](https://github.com/kohya-ss/sd-scripts/pull/271)
- The format is the same as for the custom optimizer options.
- Sample image generation during training now supports prompt weighting, and the prompt length limit has been relaxed. Thanks to mio2333! [#288](https://github.com/kohya-ss/sd-scripts/pull/288)
- `( )`, `(xxxx:1.2)` and `[ ]` can be used.
- Fix an exception when training a local model in Diffusers format with `train_network.py`. Thanks to orenwang! [#290](https://github.com/kohya-ss/sd-scripts/pull/290)
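
The section-flattening and override behavior of `--config_file` described above can be sketched as follows. This is a hypothetical illustration, not the scripts' actual implementation; the option names and the `.toml` section names are made up for the example:

```python
import argparse


def flatten_sections(config: dict) -> dict:
    # All sub-sections are combined into a single flat dict;
    # the section names themselves are discarded.
    flat = {}
    for key, value in config.items():
        if isinstance(value, dict):
            flat.update(flatten_sections(value))
        else:
            flat[key] = value
    return flat


def merge_config(parser: argparse.ArgumentParser, argv: list, config: dict) -> argparse.Namespace:
    # Values from the .toml replace the parser defaults, so any argument
    # given explicitly on the command line still wins.
    parser.set_defaults(**flatten_sections(config))
    return parser.parse_args(argv)


parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=1e-5)
parser.add_argument("--train_batch_size", type=int, default=1)

# Stand-in for a parsed .toml such as:
#   [optimizer]
#   learning_rate = 1e-4
#   [dataset]
#   train_batch_size = 4
toml_config = {"optimizer": {"learning_rate": 1e-4}, "dataset": {"train_batch_size": 4}}

args = merge_config(parser, ["--train_batch_size", "8"], toml_config)
print(args.learning_rate)     # taken from the .toml
print(args.train_batch_size)  # command line overrides the .toml
```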

- 11 Mar. 2023, 2023/3/11:
- Fix an issue where `svd_merge_lora.py` caused a device-related error.
- 21 Mar. 2023, 2023/3/21:
- Add a `--vae_batch_size` option to each training script for faster latents caching. This batches the VAE calls.
- Start with `2` or `4` depending on your VRAM size.
- Fix the number of training steps calculated when `--gradient_accumulation_steps` and `--max_train_epochs` are specified, so that training stops at the specified epoch. Thanks to tsukimiya!
- Extract the argument-parser setup into functions so that external scripts can reuse it. Thanks to robertsmieja!
- Fix an issue during training when `--full_path` is specified and the `.npz` file does not exist.
- Support uppercase image file extensions on non-Windows environments.
- Fix `resize_lora.py` to work correctly with dynamic-rank LoRA (where the rank differs per LoRA module, including `conv_dim != network_dim`). Thanks to toshiaki!
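
The corrected step count with gradient accumulation (visible in the `fine_tune.py` diff below) can be sketched like this; the function name is illustrative:

```python
import math


def max_train_steps(max_train_epochs: int, dataloader_len: int,
                    num_processes: int, gradient_accumulation_steps: int) -> int:
    # One optimizer step consumes num_processes * gradient_accumulation_steps
    # dataloader batches, so the per-epoch step count is divided accordingly,
    # rounded up so a partial accumulation window still counts as a step.
    steps_per_epoch = math.ceil(dataloader_len / num_processes / gradient_accumulation_steps)
    return max_train_epochs * steps_per_epoch


# e.g. 10 epochs, 100 batches per epoch, 1 process, accumulate over 4 batches
print(max_train_steps(10, 100, 1, 4))  # 250
```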


- Sample image generation:
Expand Down
14 changes: 10 additions & 4 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae)
train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
Expand Down Expand Up @@ -194,7 +194,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

# lr schedulerを用意する
Expand Down Expand Up @@ -240,7 +240,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
Expand Down Expand Up @@ -387,7 +387,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
print("model saved.")


if __name__ == "__main__":
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()

train_util.add_sd_models_arguments(parser)
Expand All @@ -400,6 +400,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")

return parser


if __name__ == "__main__":
parser = setup_parser()

args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)

Expand Down
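
The `setup_parser` refactor applied to `fine_tune.py` above (and repeated across the files below) follows one pattern: build and return the parser in a function so external scripts can import it without triggering training. A minimal sketch with an illustrative argument:

```python
import argparse


def setup_parser() -> argparse.ArgumentParser:
    # Building the parser in a function lets other scripts import it,
    # e.g. to generate docs or wrap the trainer, without running __main__.
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_batch_size", type=int, default=1)
    return parser


if __name__ == "__main__":
    args = setup_parser().parse_args([])  # empty argv just for demonstration
    print(args.train_batch_size)
```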
8 changes: 7 additions & 1 deletion finetune/clean_captions_and_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,19 @@ def main(args):
print("done!")


if __name__ == '__main__':
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
# parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--debug", action="store_true", help="debug mode")

return parser


if __name__ == '__main__':
parser = setup_parser()

args, unknown = parser.parse_known_args()
if len(unknown) == 1:
print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
Expand Down
8 changes: 7 additions & 1 deletion finetune/make_captions.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def run_batch(path_imgs):
print("done!")


if __name__ == '__main__':
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
Expand All @@ -153,6 +153,12 @@ def run_batch(path_imgs):
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
parser.add_argument("--debug", action="store_true", help="debug mode")

return parser


if __name__ == '__main__':
parser = setup_parser()

args = parser.parse_args()

# スペルミスしていたオプションを復元する
Expand Down
8 changes: 7 additions & 1 deletion finetune/make_captions_by_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def run_batch(path_imgs):
print("done!")


if __name__ == '__main__':
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
Expand All @@ -141,5 +141,11 @@ def run_batch(path_imgs):
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
parser.add_argument("--debug", action="store_true", help="debug mode")

return parser


if __name__ == '__main__':
parser = setup_parser()

args = parser.parse_args()
main(args)
8 changes: 7 additions & 1 deletion finetune/merge_captions_to_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def main(args):
print("done!")


if __name__ == '__main__':
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
Expand All @@ -61,6 +61,12 @@ def main(args):
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--debug", action="store_true", help="debug mode")

return parser


if __name__ == '__main__':
parser = setup_parser()

args = parser.parse_args()

# スペルミスしていたオプションを復元する
Expand Down
8 changes: 7 additions & 1 deletion finetune/merge_dd_tags_to_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def main(args):
print("done!")


if __name__ == '__main__':
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
Expand All @@ -61,5 +61,11 @@ def main(args):
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")

return parser


if __name__ == '__main__':
parser = setup_parser()

args = parser.parse_args()
main(args)
8 changes: 7 additions & 1 deletion finetune/prepare_buckets_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def process_batch(is_last):
print("done!")


if __name__ == '__main__':
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
Expand Down Expand Up @@ -257,5 +257,11 @@ def process_batch(is_last):
parser.add_argument("--skip_existing", action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")

return parser


if __name__ == '__main__':
parser = setup_parser()

args = parser.parse_args()
main(args)
8 changes: 7 additions & 1 deletion finetune/tag_images_by_wd14_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def run_batch(path_imgs):
print("done!")


if __name__ == '__main__':
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
Expand All @@ -191,6 +191,12 @@ def run_batch(path_imgs):
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode")

return parser


if __name__ == '__main__':
parser = setup_parser()

args = parser.parse_args()

# スペルミスしていたオプションを復元する
Expand Down
8 changes: 7 additions & 1 deletion gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2690,7 +2690,7 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
print("done!")


if __name__ == '__main__':
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()

parser.add_argument("--v2", action='store_true', help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
Expand Down Expand Up @@ -2786,5 +2786,11 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')

return parser


if __name__ == '__main__':
parser = setup_parser()

args = parser.parse_args()
main(args)
69 changes: 53 additions & 16 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@

# region dataset

IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]


class ImageInfo:
Expand Down Expand Up @@ -675,10 +674,19 @@ def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_s
def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])

def cache_latents(self, vae):
# TODO ここを高速化したい
def cache_latents(self, vae, vae_batch_size=1):
# ちょっと速くした
print("caching latents.")
for info in tqdm(self.image_data.values()):

image_infos = list(self.image_data.values())

# sort by resolution
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])

# split by resolution
batches = []
batch = []
for info in image_infos:
subset = self.image_to_subset[info.image_key]

if info.latents_npz is not None:
Expand All @@ -689,18 +697,42 @@ def cache_latents(self, vae):
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
continue

image = self.load_image(info.absolute_path)
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
# if last member of batch has different resolution, flush the batch
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
batches.append(batch)
batch = []

batch.append(info)

img_tensor = self.image_transforms(image)
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
# if number of data in batch is enough, flush the batch
if len(batch) >= vae_batch_size:
batches.append(batch)
batch = []

if len(batch) > 0:
batches.append(batch)

# iterate batches
for batch in tqdm(batches, smoothing=1, total=len(batches)):
images = []
for info in batch:
image = self.load_image(info.absolute_path)
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
image = self.image_transforms(image)
images.append(image)

img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)

latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
info.latents = latent

if subset.flip_aug:
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
img_tensor = self.image_transforms(image)
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
img_tensors = torch.flip(img_tensors, dims=[3])
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
info.latents_flipped = latent

def get_image_size(self, image_path):
image = Image.open(image_path)
Expand Down Expand Up @@ -1197,6 +1229,10 @@ def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
npz_file_flip = None
return npz_file_norm, npz_file_flip

# if not full path, check image_dir. if image_dir is None, return None
if subset.image_dir is None:
return None, None

# image_key is relative path
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
Expand Down Expand Up @@ -1237,10 +1273,10 @@ def add_replacement(self, str_from, str_to):
# for dataset in self.datasets:
# dataset.make_buckets()

def cache_latents(self, vae):
def cache_latents(self, vae, vae_batch_size=1):
for i, dataset in enumerate(self.datasets):
print(f"[Dataset {i}]")
dataset.cache_latents(vae)
dataset.cache_latents(vae, vae_batch_size)

def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
Expand Down Expand Up @@ -1986,6 +2022,7 @@ def add_dataset_arguments(
action="store_true",
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)",
)
parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
parser.add_argument(
"--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
)
Expand Down
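
The batching logic added to `cache_latents` above — sort by bucket resolution, then cut batches whenever the resolution changes or `vae_batch_size` is reached — can be sketched in isolation. Plain dicts stand in for the `ImageInfo` objects here:

```python
def make_batches(infos, vae_batch_size):
    # Sort by pixel count so same-sized images end up adjacent, then split
    # into batches that never mix resolutions and never exceed vae_batch_size.
    infos = sorted(infos, key=lambda info: info["reso"][0] * info["reso"][1])
    batches, batch = [], []
    for info in infos:
        # flush when the resolution changes
        if batch and batch[-1]["reso"] != info["reso"]:
            batches.append(batch)
            batch = []
        batch.append(info)
        # flush when the batch is full
        if len(batch) >= vae_batch_size:
            batches.append(batch)
            batch = []
    if batch:
        batches.append(batch)
    return batches


infos = [{"reso": (512, 512)}, {"reso": (512, 768)},
         {"reso": (512, 512)}, {"reso": (512, 768)}]
for b in make_batches(infos, 2):
    print([i["reso"] for i in b])
```

Each resulting batch is uniform in resolution, so its images can be stacked into one tensor and sent through `vae.encode` in a single call.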
