diff --git a/sdxl_train.py b/sdxl_train.py
index c7eea2224..be2b7166e 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -357,27 +357,37 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     accelerator.print("prepare optimizer, data loader etc.")
 
     if args.fused_optimizer_groups:
+        # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
+        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters.
+        # This balances memory usage and management complexity.
+
         # calculate total number of parameters
         n_total_params = sum(len(params["params"]) for params in params_to_optimize)
         params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
 
-        # split params into groups
+        # split params into groups, keeping the learning rate the same for all params in a group
+        # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders)
         grouped_params = []
         param_group = []
         param_group_lr = -1
         for group in params_to_optimize:
             lr = group["lr"]
             for p in group["params"]:
+                # if the learning rate is different for different params, start a new group
                 if lr != param_group_lr:
                     if param_group:
                         grouped_params.append({"params": param_group, "lr": param_group_lr})
                     param_group = []
                     param_group_lr = lr
+
                 param_group.append(p)
+
+                # if the group has enough parameters, start a new group
                 if len(param_group) == params_per_group:
                     grouped_params.append({"params": param_group, "lr": param_group_lr})
                     param_group = []
                     param_group_lr = -1
+
         if param_group:
             grouped_params.append({"params": param_group, "lr": param_group_lr})
 
@@ -388,7 +398,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             optimizers.append(optimizer)
 
         optimizer = optimizers[0]  # avoid error in the following code
-        print(len(grouped_params))
         logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
 
     else:
@@ -420,6 +429,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     # prepare lr scheduler
     if args.fused_optimizer_groups:
+        # prepare lr schedulers for each optimizer
         lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
         lr_scheduler = lr_schedulers[0]  # avoid error in the following code
     else:
@@ -472,6 +482,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
     if args.fused_backward_pass:
+        # use fused optimizer for backward pass: other optimizers will be supported in the future
        import library.adafactor_fused
 
         library.adafactor_fused.patch_adafactor_fused(optimizer)
@@ -488,16 +499,20 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                     parameter.register_post_accumulate_grad_hook(__grad_hook)
 
     elif args.fused_optimizer_groups:
+        # prepare for additional optimizers and lr schedulers
         for i in range(1, len(optimizers)):
             optimizers[i] = accelerator.prepare(optimizers[i])
             lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
 
+        # counters are used to determine when to step the optimizer
         global optimizer_hooked_count
         global num_parameters_per_group
         global parameter_optimizer_map
+
         optimizer_hooked_count = {}
         num_parameters_per_group = [0] * len(optimizers)
         parameter_optimizer_map = {}
+
         for opt_idx, optimizer in enumerate(optimizers):
             for param_group in optimizer.param_groups:
                 for parameter in param_group["params"]:
@@ -511,7 +526,7 @@ def optimizer_hook(parameter: torch.Tensor):
                             optimizer_hooked_count[i] += 1
                             if optimizer_hooked_count[i] == num_parameters_per_group[i]:
                                 optimizers[i].step()
-                                optimizers[i].zero_grad()
+                                optimizers[i].zero_grad(set_to_none=True)
 
                         parameter.register_post_accumulate_grad_hook(optimizer_hook)
                         parameter_optimizer_map[parameter] = opt_idx
@@ -593,7 +608,7 @@ def optimizer_hook(parameter: torch.Tensor):
             current_step.value = global_step
 
             if args.fused_optimizer_groups:
-                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
+                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset counter for each step
 
             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
@@ -725,14 +740,14 @@ def optimizer_hook(parameter: torch.Tensor):
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                     optimizer.step()
-                elif args.fused_optimizer_groups:
-                    for i in range(1, len(optimizers)):
-                        lr_schedulers[i].step()
-
-                lr_scheduler.step()
-
-                if not (args.fused_backward_pass or args.fused_optimizer_groups):
+                    lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
+                else:
+                    # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
+                    lr_scheduler.step()
+                    if args.fused_optimizer_groups:
+                        for i in range(1, len(optimizers)):
+                            lr_schedulers[i].step()
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
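
Note on the grouping pass in the first hunk: because a group is closed both when the learning rate changes and when the size cap is reached, the actual number of optimizers can exceed `--fused_optimizer_groups`. A minimal standalone sketch of that logic (toy stand-ins for `params_to_optimize`; not the sd-scripts code itself) that can be run to see this:

```python
import math

# toy stand-in for params_to_optimize: two lr regions, as with U-Net vs. text encoders
params_to_optimize = [
    {"params": [f"unet_{i}" for i in range(5)], "lr": 1e-4},
    {"params": [f"te_{i}" for i in range(3)], "lr": 1e-5},
]
fused_optimizer_groups = 2  # requested number of groups

n_total_params = sum(len(g["params"]) for g in params_to_optimize)
params_per_group = math.ceil(n_total_params / fused_optimizer_groups)

grouped_params, param_group, param_group_lr = [], [], -1
for group in params_to_optimize:
    lr = group["lr"]
    for p in group["params"]:
        if lr != param_group_lr:  # lr changed: close the current group
            if param_group:
                grouped_params.append({"params": param_group, "lr": param_group_lr})
            param_group, param_group_lr = [], lr
        param_group.append(p)
        if len(param_group) == params_per_group:  # size cap reached: close the group
            grouped_params.append({"params": param_group, "lr": param_group_lr})
            param_group, param_group_lr = [], -1
if param_group:
    grouped_params.append({"params": param_group, "lr": param_group_lr})

print([(len(g["params"]), g["lr"]) for g in grouped_params])
# -> [(4, 0.0001), (1, 0.0001), (3, 1e-05)]  # the lr boundary forces a third group
```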
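And the hook mechanism wired up in the later hunks, reduced to a runnable toy following the PyTorch tutorial linked in the diff. This is only a sketch: `AdamW` stands in for the patched Adafactor, the model is a dummy, and the variable names mirror the diff for readability rather than being the sd-scripts API. Each optimizer steps inside `loss.backward()` as soon as the last gradient it owns has been accumulated, so a group's gradients are freed before the rest of the backward pass finishes:

```python
import torch

# toy model; two optimizers, each owning half of the parameters
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 1))
params = list(model.parameters())
half = -(-len(params) // 2)  # ceil division, as math.ceil does in the diff
groups = [params[:half], params[half:]]

optimizers = [torch.optim.AdamW(g, lr=1e-3) for g in groups]
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
num_parameters_per_group = [len(g) for g in groups]
parameter_optimizer_map = {p: i for i, g in enumerate(groups) for p in g}

def optimizer_hook(parameter: torch.Tensor):
    # fires after this parameter's gradient has been fully accumulated;
    # once the owning optimizer has seen all of its parameters, step it
    # and free its gradients immediately instead of at the end of the step
    i = parameter_optimizer_map[parameter]
    optimizer_hooked_count[i] += 1
    if optimizer_hooked_count[i] == num_parameters_per_group[i]:
        optimizers[i].step()
        optimizers[i].zero_grad(set_to_none=True)

for p in params:
    # per-tensor hook, available since PyTorch 2.1
    p.register_post_accumulate_grad_hook(optimizer_hook)

# one training step: reset the counters first, exactly as the diff does at
# the top of the step loop, then let backward() drive the optimizer steps
for i in optimizer_hooked_count:
    optimizer_hooked_count[i] = 0
loss = torch.nn.functional.mse_loss(model(torch.randn(4, 8)), torch.randn(4, 1))
loss.backward()  # optimizer steps happen inside this call via the hooks
assert all(p.grad is None for p in params)  # grads already consumed and freed
```

This is also why the last two hunks of the diff only reset the counters and step the lr schedulers in the fused branches: by the time the step loop resumes after `backward()`, every `optimizer.step()` and `zero_grad()` has already run inside the hooks.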