Commit 645c974

Authored by KohakuBlueleaf, kohya-ss, and shirayu

Dev (#10)

* Final implementation
* Skip the final 1 step
* fix alpha mask without disk cache, closes kohya-ss#1351, ref kohya-ss#1339
* update for corner cases
* Bump crate-ci/typos from 1.19.0 to 1.21.0, fix typos, and update _typos.toml (closes kohya-ss#1307)
* set static graph flag when DDP, ref kohya-ss#1363
* make forward/backward paths the same, ref kohya-ss#1363
* update README
* add grad_hook after restoring state, closes kohya-ss#1344
* fix cache_latents/text_encoder_outputs to work
* show file name if error in load_image, ref kohya-ss#1385

---------

Co-authored-by: Kohya S <[email protected]>
Co-authored-by: Kohya S <[email protected]>
Co-authored-by: Yuta Hayashibe <[email protected]>
1 parent 7ffc83a · commit 645c974

11 files changed: +196 −53 lines

.github/workflows/typos.yml (+1, -1)

@@ -18,4 +18,4 @@ jobs:
       - uses: actions/checkout@v4
 
       - name: typos-action
-        uses: crate-ci/typos@v1.19.0
+        uses: crate-ci/typos@v1.21.0

README.md (+18)
@@ -178,6 +178,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 
 - The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds!
 
+- `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf!
+  - This resolves the issue where the order of data loading from DataSet changes when resuming training.
+  - Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before).
+  - If `--resume` is specified, the step saved in the state is used.
+  - Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`).
+
 - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
   - It seems that the model file loading is faster in the WSL environment etc.
   - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
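The options described in the entry above amount to fast-forwarding the data loading. The following is a minimal, self-contained sketch of that idea only; it is not the actual `train_network.py` implementation, which also restores epoch and RNG state (the `skipped_dataloader` comment in `library/train_util.py` below hints that Accelerate's batch-skipping is involved). The dataset and variable names here are illustrative.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# toy dataset and loader; names are illustrative, not from sd-scripts
dataset = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1))
loader = DataLoader(dataset, batch_size=4, shuffle=True,
                    generator=torch.Generator().manual_seed(0))

skip_until_initial_step = True
initial_step = 7  # with --resume this would come from the saved state

global_step = 0
for batch in loader:
    if skip_until_initial_step and global_step < initial_step:
        global_step += 1  # consume the batch without training on it
        continue
    # ... the actual training step would run here ...
    global_step += 1
```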
@@ -235,6 +241,12 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
 
 - SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。
 
+- `train_network.py` および `sdxl_train_network.py` で、学習再開時に DataSet の読み込み順についても復元できるようになりました。PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) KohakuBlueleaf 氏に感謝します。
+  - これにより、学習再開時に DataSet の読み込み順が変わってしまう問題が解消されます。
+  - `--skip_until_initial_step` オプションを指定すると、指定したステップまで DataSet 読み込みをスキップします。指定しない場合の動作は変わりません(DataSet の最初から読み込みます)
+  - `--resume` オプションを指定すると、state に保存されたステップ数が使用されます。
+  - `--initial_step` または `--initial_epoch` オプションを指定すると、指定したステップまたはエポックまで DataSet 読み込みをスキップします。これらのオプションは `--skip_until_initial_step` と併用してください。またこれらのオプションは `--resume` と併用しなくても使えます(`--network_weights` を用いた学習再開時などにお使いください)。
+
 - SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。
   - WSL 環境等でモデルファイルの読み込みが高速化されるようです。
   - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。
@@ -253,6 +265,12 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します
 
 - `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。
 
+### Jun 23, 2024 / 2024-06-23:
+
+- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.)
+
+- `cache_latents.py` および `cache_text_encoder_outputs.py` が動作しなくなっていたのを修正しました。(次回リリースに含まれます。)
+
 ### Apr 7, 2024 / 2024-04-07: v0.8.7
 
 - The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results.

_typos.toml (+2)

@@ -2,6 +2,7 @@
 # Instruction: https://github.com/marketplace/actions/typos-action#getting-started
 
 [default.extend-identifiers]
+ddPn08="ddPn08"
 
 [default.extend-words]
 NIN="NIN"
@@ -27,6 +28,7 @@ rik="rik"
 koo="koo"
 yos="yos"
 wn="wn"
+hime="hime"
 
 
 [files]

library/ipex/attention.py (+1, -1)

@@ -5,7 +5,7 @@
 
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
 
-# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
+# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
 
 sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
 attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))

library/train_util.py (+32, -16)
@@ -657,8 +657,16 @@ def set_caching_mode(self, mode):
 
     def set_current_epoch(self, epoch):
         if not self.current_epoch == epoch:  # epochが切り替わったらバケツをシャッフルする
-            self.shuffle_buckets()
-            self.current_epoch = epoch
+            if epoch > self.current_epoch:
+                logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
+                num_epochs = epoch - self.current_epoch
+                for _ in range(num_epochs):
+                    self.current_epoch += 1
+                    self.shuffle_buckets()
+                # self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader?
+            else:
+                logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
+                self.current_epoch = epoch
 
     def set_current_step(self, step):
         self.current_step = step
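Why the catch-up loop matters: if the bucket shuffling is driven by a seeded RNG, replaying one shuffle per skipped epoch reproduces the order an uninterrupted run would have reached, while a single shuffle (the old behaviour) diverges. A toy illustration with plain `random`, not the actual bucket manager:

```python
import random

def bucket_order(num_shuffles: int, seed: int = 42) -> list:
    """Order of 8 toy buckets after shuffling num_shuffles times with a fixed seed."""
    rng = random.Random(seed)
    buckets = list(range(8))
    for _ in range(num_shuffles):
        rng.shuffle(buckets)
    return buckets

full_run = bucket_order(3)    # trained through epochs 1..3, shuffling once per epoch
old_resume = bucket_order(1)  # old set_current_epoch: one shuffle regardless of the jump
new_resume = bucket_order(3)  # new set_current_epoch: one shuffle per skipped epoch

print(full_run == new_resume)  # True  -> matches the uninterrupted run
print(full_run == old_resume)  # almost certainly False -> the order used to diverge
```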
@@ -1265,7 +1273,8 @@ def __getitem__(self, index):
             if subset.alpha_mask:
                 if img.shape[2] == 4:
                     alpha_mask = img[:, :, 3]  # [H,W]
-                    alpha_mask = transforms.ToTensor()(alpha_mask)  # 0-255 -> 0-1
+                    alpha_mask = alpha_mask.astype(np.float32) / 255.0  # 0.0~1.0
+                    alpha_mask = torch.FloatTensor(alpha_mask)
                 else:
                     alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32)
             else:
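One plausible reading of why `transforms.ToTensor()` was replaced here (the commit only shows the swap, so treat this as an assumption): `ToTensor` turns a 2-D `uint8` alpha channel into a `[1, H, W]` tensor, while the new two-step conversion keeps the `[H, W]` shape the surrounding `# [H,W]` comments expect. A quick standalone check:

```python
import numpy as np
import torch
from torchvision import transforms

# toy RGBA image; the alpha channel is a 2-D uint8 array
img = np.zeros((4, 6, 4), dtype=np.uint8)
img[:, :, 3] = 128
alpha = img[:, :, 3]  # [H, W]

old_style = transforms.ToTensor()(alpha)                         # adds a channel dimension
new_style = torch.FloatTensor(alpha.astype(np.float32) / 255.0)  # keeps [H, W]

print(old_style.shape, new_style.shape)              # torch.Size([1, 4, 6]) torch.Size([4, 6])
print(float(old_style.max()), float(new_style.max()))  # both scaled to roughly 0.5
```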
@@ -2211,7 +2220,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
 # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
 def load_latents_from_disk(
     npz_path,
-) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
+) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
     npz = np.load(npz_path)
     if "latents" not in npz:
         raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
@@ -2229,7 +2238,7 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli
     if flipped_latents_tensor is not None:
         kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
     if alpha_mask is not None:
-        kwargs["alpha_mask"] = alpha_mask  # ndarray
+        kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
     np.savez(
         npz_path,
         latents=latents_tensor.float().cpu().numpy(),
@@ -2425,16 +2434,20 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
     return train_dataset_group
 
 
-def load_image(image_path, alpha=False):
-    image = Image.open(image_path)
-    if alpha:
-        if not image.mode == "RGBA":
-            image = image.convert("RGBA")
-    else:
-        if not image.mode == "RGB":
-            image = image.convert("RGB")
-    img = np.array(image, np.uint8)
-    return img
+def load_image(image_path, alpha=False):
+    try:
+        with Image.open(image_path) as image:
+            if alpha:
+                if not image.mode == "RGBA":
+                    image = image.convert("RGBA")
+            else:
+                if not image.mode == "RGB":
+                    image = image.convert("RGB")
+            img = np.array(image, np.uint8)
+            return img
+    except (IOError, OSError) as e:
+        logger.error(f"Error loading file: {image_path}")
+        raise e
 
 
 # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
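A hedged usage sketch of the new error path (the file path below is made up): an unreadable or truncated image now gets its name logged before the exception propagates, which is what "show file name if error in load_image" in the commit message refers to.

```python
from library import train_util  # assumes the sd-scripts repo root is on PYTHONPATH

try:
    img = train_util.load_image("dataset/broken_or_missing.png", alpha=True)  # hypothetical path
except (IOError, OSError):
    # the log now contains "Error loading file: dataset/broken_or_missing.png",
    # so the offending image can be located and fixed or removed
    pass
```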
@@ -2496,8 +2509,9 @@ def cache_batch_latents(
             if image.shape[2] == 4:
                 alpha_mask = image[:, :, 3]  # [H,W]
                 alpha_mask = alpha_mask.astype(np.float32) / 255.0
+                alpha_mask = torch.FloatTensor(alpha_mask)  # [H,W]
             else:
-                alpha_mask = np.ones_like(image[:, :, 0], dtype=np.float32)
+                alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32)  # [H,W]
         else:
             alpha_mask = None
         alpha_masks.append(alpha_mask)
@@ -5554,6 +5568,8 @@ def add(self, *, epoch: int, step: int, loss: float) -> None:
         if epoch == 0:
             self.loss_list.append(loss)
         else:
+            while len(self.loss_list) <= step:
+                self.loss_list.append(0.0)
             self.loss_total -= self.loss_list[step]
             self.loss_list[step] = loss
         self.loss_total += loss
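The new padding loop guards against indexing past the end of `loss_list`, which can happen when the recorder starts out empty but `step` is already large, for example after resuming mid-epoch. A minimal reconstruction of that situation (a sketch, not the actual `LossRecorder`):

```python
loss_list = []           # recorder state right after a resume: nothing recorded yet
step, loss = 120, 0.0123

# Old behaviour: loss_list[step] would raise IndexError on an empty list.
# New behaviour: pad with zeros up to `step`, then overwrite in place.
while len(loss_list) <= step:
    loss_list.append(0.0)
loss_total = -loss_list[step] + loss  # mirrors the subtract-then-add bookkeeping
loss_list[step] = loss

print(len(loss_list), loss_list[step], loss_total)  # 121 0.0123 0.0123
```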

networks/control_net_lllite_for_train.py (+4, -8)
@@ -7,8 +7,10 @@
 import torch
 from library import sdxl_original_unet
 from library.utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)
 
 # input_blocksに適用するかどうか / if True, input_blocks are not applied
@@ -103,19 +105,15 @@ def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplie
         add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
 
         self.cond_image = None
-        self.cond_emb = None
 
     def set_cond_image(self, cond_image):
         self.cond_image = cond_image
-        self.cond_emb = None
 
     def forward(self, x):
         if not self.enabled:
             return super().forward(x)
 
-        if self.cond_emb is None:
-            self.cond_emb = self.lllite_conditioning1(self.cond_image)
-        cx = self.cond_emb
+        cx = self.lllite_conditioning1(self.cond_image)  # make forward and backward compatible
 
         # reshape / b,c,h,w -> b,h*w,c
         n, c, h, w = cx.shape
@@ -159,9 +157,7 @@ def forward(self, x):  # , cond_image=None):
         if not self.enabled:
             return super().forward(x)
 
-        if self.cond_emb is None:
-            self.cond_emb = self.lllite_conditioning1(self.cond_image)
-        cx = self.cond_emb
+        cx = self.lllite_conditioning1(self.cond_image)
 
         cx = torch.cat([cx, self.down(x)], dim=1)
         cx = self.mid(cx)

sdxl_train.py (+25, -21)
@@ -481,6 +481,26 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         text_encoder2 = accelerator.prepare(text_encoder2)
     optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
+    # TextEncoderの出力をキャッシュするときにはCPUへ移動する
+    if args.cache_text_encoder_outputs:
+        # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
+        text_encoder1.to("cpu", dtype=torch.float32)
+        text_encoder2.to("cpu", dtype=torch.float32)
+        clean_memory_on_device(accelerator.device)
+    else:
+        # make sure Text Encoders are on GPU
+        text_encoder1.to(accelerator.device)
+        text_encoder2.to(accelerator.device)
+
+    # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
+    if args.full_fp16:
+        # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
+        # -> But we think it's ok to patch accelerator even if deepspeed is enabled.
+        train_util.patch_accelerator_for_fp16_training(accelerator)
+
+    # resumeする
+    train_util.resume_from_local_or_hf_if_specified(accelerator, args)
+
     if args.fused_backward_pass:
         # use fused optimizer for backward pass: other optimizers will be supported in the future
         import library.adafactor_fused
@@ -532,26 +552,6 @@ def optimizer_hook(parameter: torch.Tensor):
                 parameter_optimizer_map[parameter] = opt_idx
                 num_parameters_per_group[opt_idx] += 1
 
-    # TextEncoderの出力をキャッシュするときにはCPUへ移動する
-    if args.cache_text_encoder_outputs:
-        # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
-        text_encoder1.to("cpu", dtype=torch.float32)
-        text_encoder2.to("cpu", dtype=torch.float32)
-        clean_memory_on_device(accelerator.device)
-    else:
-        # make sure Text Encoders are on GPU
-        text_encoder1.to(accelerator.device)
-        text_encoder2.to(accelerator.device)
-
-    # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
-    if args.full_fp16:
-        # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
-        # -> But we think it's ok to patch accelerator even if deepspeed is enabled.
-        train_util.patch_accelerator_for_fp16_training(accelerator)
-
-    # resumeする
-    train_util.resume_from_local_or_hf_if_specified(accelerator, args)
-
     # epoch数を計算する
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -589,7 +589,11 @@ def optimizer_hook(parameter: torch.Tensor):
         init_kwargs["wandb"] = {"name": args.wandb_run_name}
     if args.log_tracker_config is not None:
         init_kwargs = toml.load(args.log_tracker_config)
-    accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
+    accelerator.init_trackers(
+        "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
+        config=train_util.get_sanitized_config_or_none(args),
+        init_kwargs=init_kwargs,
+    )
 
     # For --sample_at_first
     sdxl_train_util.sample_images(

sdxl_train_control_net_lllite.py (+3)
@@ -289,6 +289,9 @@ def train(args):
     # acceleratorがなんかよろしくやってくれるらしい
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
 
+    if isinstance(unet, DDP):
+        unet._set_static_graph()  # avoid error for multiple use of the parameter
+
     if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
     else:
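`_set_static_graph()` is a (private) hook on PyTorch's `DistributedDataParallel` that declares the autograd graph and the set of used parameters as fixed across iterations, which lets DDP tolerate parameters that participate more than once per forward/backward, as the inline comment notes. A standalone single-process sketch (gloo backend, CPU, toy model) just to show where the call sits; it is not the training script itself:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# single-process "distributed" setup so the example runs on its own
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
ddp_model = DDP(model)  # CPU; no device_ids needed

if isinstance(ddp_model, DDP):
    ddp_model._set_static_graph()  # same guard as in the training script

out = ddp_model(torch.randn(2, 4)).sum()
out.backward()
dist.destroy_process_group()
```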

tools/cache_latents.py (+4, -2)
@@ -16,15 +16,15 @@
     ConfigSanitizer,
     BlueprintGenerator,
 )
-from library.utils import setup_logging
-
+from library.utils import setup_logging, add_logging_arguments
 setup_logging()
 import logging
 
 logger = logging.getLogger(__name__)
 
 
 def cache_to_disk(args: argparse.Namespace) -> None:
+    setup_logging(args, reset=True)
     train_util.prepare_dataset_args(args, True)
 
     # check cache latents arg
@@ -97,6 +97,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
 
     # acceleratorを準備する
     logger.info("prepare accelerator")
+    args.deepspeed = False
     accelerator = train_util.prepare_accelerator(args)
 
     # mixed precisionに対応した型を用意しておき適宜castする
@@ -176,6 +177,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
 def setup_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
 
+    add_logging_arguments(parser)
     train_util.add_sd_models_arguments(parser)
     train_util.add_training_arguments(parser, True)
     train_util.add_dataset_arguments(parser, True, True, True)

tools/cache_text_encoder_outputs.py (+4, -1)
@@ -16,12 +16,13 @@
     ConfigSanitizer,
     BlueprintGenerator,
 )
-from library.utils import setup_logging
+from library.utils import setup_logging, add_logging_arguments
 setup_logging()
 import logging
 logger = logging.getLogger(__name__)
 
 def cache_to_disk(args: argparse.Namespace) -> None:
+    setup_logging(args, reset=True)
     train_util.prepare_dataset_args(args, True)
 
     # check cache arg
@@ -99,6 +100,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
 
     # acceleratorを準備する
     logger.info("prepare accelerator")
+    args.deepspeed = False
     accelerator = train_util.prepare_accelerator(args)
 
     # mixed precisionに対応した型を用意しておき適宜castする
@@ -171,6 +173,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
 def setup_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
 
+    add_logging_arguments(parser)
     train_util.add_sd_models_arguments(parser)
     train_util.add_training_arguments(parser, True)
     train_util.add_dataset_arguments(parser, True, True, True)
