diff --git a/README.md b/README.md index 859a7618d..9c7ecad99 100644 --- a/README.md +++ b/README.md @@ -139,8 +139,37 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! + - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. + - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. + - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`. + - Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size. + - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`. + - Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. The more parameters there are, the greater the effect, so it is not effective in other training scripts (LoRA, etc.) where the memory usage peak is elsewhere, and there are no plans to implement it in those training scripts. + +- Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) + - Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer. + - Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10. + - Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available. + - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. + - Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side. + - Fixed some bugs when using DeepSpeed. Related [#1247] +- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 + - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 + - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。 + - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。 + - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。 + - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。 + - 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。パラメータ数が多いほど効果が大きいため、SDXL の学習以外(LoRA 等)ではほぼ効果がなく(メモリ使用量のピークが他の場所にあるため)、それらの学習スクリプトへの実装予定もありません。 + +- SDXL の学習時に optimizer group 機能を追加しました。PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) + - Fused optimizer と同様の原理でメモリ使用量を削減します。学習結果や速度についても同様です。 + - `sdxl_train.py` に `--fused_optimizer_groups 10` のようにグループ数を指定してください。グループ数を増やすとメモリ使用量が削減されますが、速度は遅くなります。ある程度の数までしか効果がないため、4~10 程度を指定すると良いでしょう。 + - 任意の optimizer が使えますが、学習率を自動計算する optimizer (D-Adaptation や Prodigy など)は使えません。gradient accumulation は使えません。 + - `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。 + - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。 - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247] diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py new file mode 100644 index 000000000..bdfc32ced --- /dev/null +++ b/library/adafactor_fused.py @@ -0,0 +1,106 @@ +import math +import torch +from transformers import Adafactor + +@torch.no_grad() +def adafactor_step_param(self, p, group): + if p.grad is None: + return + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = Adafactor._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = Adafactor._rms(p_data_fp32) + lr = Adafactor._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad ** 2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update) + + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_data_fp32) + + +@torch.no_grad() +def adafactor_step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + adafactor_step_param(self, p, group) + + return loss + +def patch_adafactor_fused(optimizer: Adafactor): + optimizer.step_param = adafactor_step_param.__get__(optimizer) + optimizer.step = adafactor_step.__get__(optimizer) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..e3c0229a7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2920,6 +2920,12 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) + parser.add_argument( + "--fused_backward_pass", + action="store_true", + help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL" + + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", + ) def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): @@ -3846,6 +3852,14 @@ def get_optimizer(args, trainable_params): optimizer_type = "AdamW" optimizer_type = optimizer_type.lower() + if args.fused_backward_pass: + assert ( + optimizer_type == "Adafactor".lower() + ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します" + assert ( + args.gradient_accumulation_steps == 1 + ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません" + # 引数を分解する optimizer_kwargs = {} if args.optimizer_args is not None and len(args.optimizer_args) > 0: diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860be..be2b7166e 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -345,8 +345,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # calculate number of trainable parameters n_params = 0 - for params in params_to_optimize: - for p in params["params"]: + for group in params_to_optimize: + for p in group["params"]: n_params += p.numel() accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}") @@ -355,7 +355,53 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + if args.fused_optimizer_groups: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # This balances memory usage and management complexity. + + # calculate total number of parameters + n_total_params = sum(len(params["params"]) for params in params_to_optimize) + params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) + + # split params into groups, keeping the learning rate the same for all params in a group + # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) + grouped_params = [] + param_group = [] + param_group_lr = -1 + for group in params_to_optimize: + lr = group["lr"] + for p in group["params"]: + # if the learning rate is different for different params, start a new group + if lr != param_group_lr: + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = lr + + param_group.append(p) + + # if the group has enough parameters, start a new group + if len(param_group) == params_per_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = -1 + + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 @@ -382,7 +428,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする if args.full_fp16: @@ -430,6 +481,57 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder2 = accelerator.prepare(text_encoder2) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future + import library.adafactor_fused + + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + + elif args.fused_optimizer_groups: + # prepare for additional optimizers and lr schedulers + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + # counters are used to determine when to step the optimizer + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def optimizer_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad(set_to_none=True) + + parameter.register_post_accumulate_grad_hook(optimizer_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 @@ -504,6 +606,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): for step, batch in enumerate(train_dataloader): current_step.value = global_step + + if args.fused_optimizer_groups: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step + with accelerator.accumulate(*training_models): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) @@ -582,7 +688,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -600,7 +708,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): or args.masked_loss ): # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) if args.masked_loss: loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -616,18 +726,28 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + ) accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + if not (args.fused_backward_pass or args.fused_optimizer_groups): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: @@ -736,7 +856,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.end_training() - if args.save_state or args.save_state_on_train_end: + if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -805,6 +925,12 @@ def setup_parser() -> argparse.ArgumentParser: help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", ) + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + ) return parser