From 4f203ce40d3a4647d52a2570a228e279dd04b321 Mon Sep 17 00:00:00 2001 From: 2kpr <96332338+2kpr@users.noreply.github.com> Date: Sun, 14 Apr 2024 09:56:58 -0500 Subject: [PATCH 1/6] Fused backward pass --- library/adafactor_fused.py | 106 +++++++++++++++++++++++++++++++++++++ library/train_util.py | 13 +++++ sdxl_train.py | 29 +++++++--- 3 files changed, 142 insertions(+), 6 deletions(-) create mode 100644 library/adafactor_fused.py diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py new file mode 100644 index 000000000..bdfc32ced --- /dev/null +++ b/library/adafactor_fused.py @@ -0,0 +1,106 @@ +import math +import torch +from transformers import Adafactor + +@torch.no_grad() +def adafactor_step_param(self, p, group): + if p.grad is None: + return + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = Adafactor._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = Adafactor._rms(p_data_fp32) + lr = Adafactor._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad ** 2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update) + + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_data_fp32) + + +@torch.no_grad() +def adafactor_step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + adafactor_step_param(self, p, group) + + return loss + +def patch_adafactor_fused(optimizer: Adafactor): + optimizer.step_param = adafactor_step_param.__get__(optimizer) + optimizer.step = adafactor_step.__get__(optimizer) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..46b55c03e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2920,6 +2920,11 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) + parser.add_argument( + "--fused_backward_pass", + action="store_true", + help="Combines backward pass and optimizer step to reduce VRAM usage / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。", + ) def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): @@ -3846,6 +3851,14 @@ def get_optimizer(args, trainable_params): optimizer_type = "AdamW" optimizer_type = optimizer_type.lower() + if args.fused_backward_pass: + assert ( + optimizer_type == "Adafactor".lower() + ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します" + assert ( + args.gradient_accumulation_steps == 1 + ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません" + # 引数を分解する optimizer_kwargs = {} if args.optimizer_args is not None and len(args.optimizer_args) > 0: diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860be..3b28575ed 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -430,6 +430,20 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder2 = accelerator.prepare(text_encoder2) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + if args.fused_backward_pass: + import library.adafactor_fused + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: # move Text Encoders for sampling images. 
Text Encoder doesn't work on CPU with fp16 @@ -619,13 +633,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - optimizer.step() + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() optimizer.zero_grad(set_to_none=True) From 017b82ebe33a2199c8f842c99905f59c54292f56 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 6 May 2024 15:05:42 +0900 Subject: [PATCH 2/6] update help message for fused_backward_pass --- library/train_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 46b55c03e..e3c0229a7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2923,7 +2923,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--fused_backward_pass", action="store_true", - help="Combines backward pass and optimizer step to reduce VRAM usage / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。", + help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL" + + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", ) From b56d5f7801dea45cdbbba8498544e8d2853ad6d6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 6 May 2024 21:35:39 +0900 Subject: [PATCH 3/6] add experimental option to fuse params to optimizer groups --- sdxl_train.py | 114 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 104 insertions(+), 10 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 3b28575ed..c7eea2224 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -345,8 +345,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # calculate number of trainable parameters n_params = 0 - for params in params_to_optimize: - for p in params["params"]: + for group in params_to_optimize: + for p in group["params"]: n_params += p.numel() accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}") @@ -355,7 +355,44 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + if args.fused_optimizer_groups: + # calculate total number of parameters + n_total_params = sum(len(params["params"]) for params in params_to_optimize) + params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) + + # split params into groups + grouped_params = [] + param_group = [] + param_group_lr = -1 + for group in params_to_optimize: + lr = group["lr"] + for p in group["params"]: + if lr != param_group_lr: + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = lr + param_group.append(p) + if len(param_group) == params_per_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = -1 + if param_group: + 
grouped_params.append({"params": param_group, "lr": param_group_lr}) + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + print(len(grouped_params)) + logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 @@ -382,7 +419,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + if args.fused_optimizer_groups: + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする if args.full_fp16: @@ -432,10 +473,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if args.fused_backward_pass: import library.adafactor_fused + library.adafactor_fused.patch_adafactor_fused(optimizer) for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: + def __grad_hook(tensor: torch.Tensor, param_group=param_group): if accelerator.sync_gradients and args.max_grad_norm != 0.0: accelerator.clip_grad_norm_(tensor, args.max_grad_norm) @@ -444,6 +487,36 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): parameter.register_post_accumulate_grad_hook(__grad_hook) + elif args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def optimizer_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad() + + parameter.register_post_accumulate_grad_hook(optimizer_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: # move Text Encoders for sampling images. 
Text Encoder doesn't work on CPU with fp16 @@ -518,6 +591,10 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): for step, batch in enumerate(train_dataloader): current_step.value = global_step + + if args.fused_optimizer_groups: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} + with accelerator.accumulate(*training_models): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) @@ -596,7 +673,9 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -614,7 +693,9 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): or args.masked_loss ): # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) if args.masked_loss: loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -630,11 +711,13 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + ) accelerator.backward(loss) - if not args.fused_backward_pass: + if not (args.fused_backward_pass or args.fused_optimizer_groups): if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] for m in training_models: @@ -642,9 +725,14 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() + elif args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + + if not (args.fused_backward_pass or args.fused_optimizer_groups): + optimizer.zero_grad(set_to_none=True) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: @@ -753,7 +841,7 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): accelerator.end_training() - if args.save_state or args.save_state_on_train_end: + if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -822,6 +910,12 @@ def setup_parser() -> argparse.ArgumentParser: help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", ) + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer 
stepのためのoptimizer数", + ) return parser From 607e041f3de972f2c3030e7c8b43dfc3c2eb2d65 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 14:16:41 +0900 Subject: [PATCH 4/6] chore: Refactor optimizer group --- sdxl_train.py | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index c7eea2224..be2b7166e 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -357,27 +357,37 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print("prepare optimizer, data loader etc.") if args.fused_optimizer_groups: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # This balances memory usage and management complexity. + # calculate total number of parameters n_total_params = sum(len(params["params"]) for params in params_to_optimize) params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) - # split params into groups + # split params into groups, keeping the learning rate the same for all params in a group + # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) grouped_params = [] param_group = [] param_group_lr = -1 for group in params_to_optimize: lr = group["lr"] for p in group["params"]: + # if the learning rate is different for different params, start a new group if lr != param_group_lr: if param_group: grouped_params.append({"params": param_group, "lr": param_group_lr}) param_group = [] param_group_lr = lr + param_group.append(p) + + # if the group has enough parameters, start a new group if len(param_group) == params_per_group: grouped_params.append({"params": param_group, "lr": param_group_lr}) param_group = [] param_group_lr = -1 + if param_group: grouped_params.append({"params": param_group, "lr": param_group_lr}) @@ -388,7 +398,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): optimizers.append(optimizer) optimizer = optimizers[0] # avoid error in the following code - print(len(grouped_params)) logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") else: @@ -420,6 +429,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # lr schedulerを用意する if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] lr_scheduler = lr_schedulers[0] # avoid error in the following code else: @@ -472,6 +482,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) @@ -488,16 +499,20 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): parameter.register_post_accumulate_grad_hook(__grad_hook) elif args.fused_optimizer_groups: + # prepare for additional optimizers and lr schedulers for i in range(1, len(optimizers)): optimizers[i] = accelerator.prepare(optimizers[i]) lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + # counters are used to determine when to step the optimizer global optimizer_hooked_count global 
num_parameters_per_group global parameter_optimizer_map + optimizer_hooked_count = {} num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} + for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: @@ -511,7 +526,7 @@ def optimizer_hook(parameter: torch.Tensor): optimizer_hooked_count[i] += 1 if optimizer_hooked_count[i] == num_parameters_per_group[i]: optimizers[i].step() - optimizers[i].zero_grad() + optimizers[i].zero_grad(set_to_none=True) parameter.register_post_accumulate_grad_hook(optimizer_hook) parameter_optimizer_map[parameter] = opt_idx @@ -593,7 +608,7 @@ def optimizer_hook(parameter: torch.Tensor): current_step.value = global_step if args.fused_optimizer_groups: - optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step with accelerator.accumulate(*training_models): if "latents" in batch and batch["latents"] is not None: @@ -725,14 +740,14 @@ def optimizer_hook(parameter: torch.Tensor): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() - elif args.fused_optimizer_groups: - for i in range(1, len(optimizers)): - lr_schedulers[i].step() - - lr_scheduler.step() - - if not (args.fused_backward_pass or args.fused_optimizer_groups): + lr_scheduler.step() optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: From f3d2cf22ff9ad49e7f8bd68494714fa3bedbd77d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 15:03:02 +0900 Subject: [PATCH 5/6] update README for fused optimizer --- README.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/README.md b/README.md index 859a7618d..4fd97fb25 100644 --- a/README.md +++ b/README.md @@ -139,8 +139,37 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! + - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. + - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. + - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`. + - Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size. + - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`. + - Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. 
+ +- Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) + - Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer. + - Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10. + - Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available. + - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. + - Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side. + - Fixed some bugs when using DeepSpeed. Related [#1247] +- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 + - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 + - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。 + - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。 + - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。 + - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。 + - 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。 + +- SDXL の学習時に optimizer group 機能を追加しました。PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) + - Fused optimizer と同様の原理でメモリ使用量を削減します。学習結果や速度についても同様です。 + - `sdxl_train.py` に `--fused_optimizer_groups 10` のようにグループ数を指定してください。グループ数を増やすとメモリ使用量が削減されますが、速度は遅くなります。ある程度の数までしか効果がないため、4~10 程度を指定すると良いでしょう。 + - 任意の optimizer が使えますが、学習率を自動計算する optimizer (D-Adaptation や Prodigy など)は使えません。gradient accumulation は使えません。 + - `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。 + - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。 - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247] From bee8cee7e8fbeecc05b1c80a1e9e8fadab3210a5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 15:08:52 +0900 Subject: [PATCH 6/6] update README for fused optimizer --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 4fd97fb25..9c7ecad99 100644 --- a/README.md +++ b/README.md @@ -145,7 +145,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`. 
- Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size. - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`. - - Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. + - Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. The more parameters there are, the greater the effect, so it is not effective in other training scripts (LoRA, etc.) where the memory usage peak is elsewhere, and there are no plans to implement it in those training scripts. - Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) - Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer. @@ -162,14 +162,14 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。 - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。 - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。 - - 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。 + - 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。パラメータ数が多いほど効果が大きいため、SDXL の学習以外(LoRA 等)ではほぼ効果がなく(メモリ使用量のピークが他の場所にあるため)、それらの学習スクリプトへの実装予定もありません。 - SDXL の学習時に optimizer group 機能を追加しました。PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) - Fused optimizer と同様の原理でメモリ使用量を削減します。学習結果や速度についても同様です。 - `sdxl_train.py` に `--fused_optimizer_groups 10` のようにグループ数を指定してください。グループ数を増やすとメモリ使用量が削減されますが、速度は遅くなります。ある程度の数までしか効果がないため、4~10 程度を指定すると良いでしょう。 - 任意の optimizer が使えますが、学習率を自動計算する optimizer (D-Adaptation や Prodigy など)は使えません。gradient accumulation は使えません。 - `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。 - - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。 + - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。 - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247]
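
For readers who want the mechanism in isolation: below is a minimal, self-contained sketch of the "fuse backward and step" idea described in the README changes above, using `Tensor.register_post_accumulate_grad_hook` (PyTorch 2.1+). It follows the PyTorch tutorial referenced in the PR comments, creating one tiny optimizer per parameter; the PR itself instead patches Adafactor with a per-parameter `step_param()` (see `library/adafactor_fused.py`). The model, the SGD optimizer, and the helper names are illustrative assumptions, not code from the PR.

```python
import torch
from torch import nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))

    def forward(self, x):
        return self.net(x)

def attach_fused_step(model: nn.Module, lr: float = 1e-3, max_grad_norm: float = 1.0):
    # One small optimizer per parameter, as in the PyTorch tutorial linked from the PR;
    # the PR instead calls optimizer.step_param(tensor, param_group) on a patched Adafactor.
    optimizers = {p: torch.optim.SGD([p], lr=lr) for p in model.parameters() if p.requires_grad}

    for p, opt in optimizers.items():
        def hook(param: torch.Tensor, opt=opt):
            # This parameter's gradient is final: clip, step, then free it immediately,
            # so the full set of gradients is never resident in memory at the same time.
            torch.nn.utils.clip_grad_norm_([param], max_grad_norm)
            opt.step()
            param.grad = None

        p.register_post_accumulate_grad_hook(hook)  # requires PyTorch 2.1+
    return optimizers

if __name__ == "__main__":
    model = TinyNet()
    attach_fused_step(model)
    x, y = torch.randn(4, 16), torch.randn(4, 1)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()  # parameters are updated during backward(); no separate optimizer.step()
```

As in the PR's `__grad_hook`, clipping here is applied per parameter rather than over the global gradient norm, which is part of why the fused path skips the usual `params_to_clip` collection.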
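A similarly hedged sketch of the optimizer-groups variant (`--fused_optimizer_groups`): parameters are split into a small number of groups, one optimizer is created per group, and each group's optimizer steps as soon as all gradients in that group have been accumulated. The group count, the use of AdamW, and the helper names below are assumptions for illustration; the PR additionally keeps parameters with different learning rates in separate groups and resets its per-group counters once per training step rather than inside the hook.

```python
import math
import torch
from torch import nn

def attach_grouped_fused_step(model: nn.Module, num_groups: int = 2, lr: float = 1e-3):
    params = [p for p in model.parameters() if p.requires_grad]
    per_group = math.ceil(len(params) / num_groups)
    param_groups = [params[i:i + per_group] for i in range(0, len(params), per_group)]
    optimizers = [torch.optim.AdamW(g, lr=lr) for g in param_groups]

    # Number of gradients still outstanding for each group in the current step.
    pending = {i: len(g) for i, g in enumerate(param_groups)}

    for group_idx, group in enumerate(param_groups):
        for p in group:
            def hook(param: torch.Tensor, i: int = group_idx):
                pending[i] -= 1
                if pending[i] == 0:
                    # All gradients of this group are ready: step and free them now.
                    optimizers[i].step()
                    optimizers[i].zero_grad(set_to_none=True)
                    pending[i] = len(param_groups[i])  # reset for the next step

            p.register_post_accumulate_grad_hook(hook)  # requires PyTorch 2.1+
    return optimizers

# usage: attach_grouped_fused_step(model, num_groups=4), then call loss.backward() as above
```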