chore: Refactor optimizer group
kohya-ss committed May 12, 2024
1 parent b56d5f7 commit 607e041
Showing 1 changed file with 26 additions and 11 deletions.
sdxl_train.py: 37 changes (26 additions & 11 deletions)

@@ -357,27 +357,37 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     accelerator.print("prepare optimizer, data loader etc.")

     if args.fused_optimizer_groups:
         # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
         # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters.
         # This balances memory usage and management complexity.

         # calculate total number of parameters
         n_total_params = sum(len(params["params"]) for params in params_to_optimize)
         params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)

-        # split params into groups
+        # split params into groups, keeping the learning rate the same for all params in a group
+        # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders)
         grouped_params = []
         param_group = []
         param_group_lr = -1
         for group in params_to_optimize:
             lr = group["lr"]
             for p in group["params"]:
+                # if the learning rate is different for different params, start a new group
+                if lr != param_group_lr:
+                    if param_group:
+                        grouped_params.append({"params": param_group, "lr": param_group_lr})
+                        param_group = []
+                    param_group_lr = lr
+
                 param_group.append(p)

+                # if the group has enough parameters, start a new group
                 if len(param_group) == params_per_group:
                     grouped_params.append({"params": param_group, "lr": param_group_lr})
                     param_group = []
+                    param_group_lr = -1

         if param_group:
             grouped_params.append({"params": param_group, "lr": param_group_lr})
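To make the new grouping behaviour concrete, here is a self-contained sketch of the same algorithm outside the trainer. The function name group_parameters and the toy parameter groups are illustrative only, not part of the commit:

```python
import math

def group_parameters(params_to_optimize, n_groups):
    # greedy split into roughly n_groups groups, never mixing learning rates within a group
    n_total_params = sum(len(g["params"]) for g in params_to_optimize)
    params_per_group = math.ceil(n_total_params / n_groups)

    grouped_params, param_group, param_group_lr = [], [], -1
    for group in params_to_optimize:
        lr = group["lr"]
        for p in group["params"]:
            if lr != param_group_lr:  # learning rate changed: close the current group
                if param_group:
                    grouped_params.append({"params": param_group, "lr": param_group_lr})
                    param_group = []
                param_group_lr = lr
            param_group.append(p)
            if len(param_group) == params_per_group:  # group is full: close it
                grouped_params.append({"params": param_group, "lr": param_group_lr})
                param_group, param_group_lr = [], -1
    if param_group:
        grouped_params.append({"params": param_group, "lr": param_group_lr})
    return grouped_params

# toy example: 5 "U-Net" params at lr=1e-4 and 2 "text encoder" params at lr=1e-5
params = [{"params": list(range(5)), "lr": 1e-4}, {"params": [5, 6], "lr": 1e-5}]
print([(len(g["params"]), g["lr"]) for g in group_parameters(params, 3)])
# [(3, 0.0001), (2, 0.0001), (2, 1e-05)]
```

The learning-rate boundary closes the second group early, which is why the number of groups (and therefore optimizers) can end up larger than args.fused_optimizer_groups.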
@@ -388,7 +398,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             optimizers.append(optimizer)
         optimizer = optimizers[0]  # avoid error in the following code

-        print(len(grouped_params))
         logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")

     else:
@@ -420,6 +429,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

     # prepare lr scheduler
     if args.fused_optimizer_groups:
+        # prepare lr schedulers for each optimizer
         lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
         lr_scheduler = lr_schedulers[0]  # avoid error in the following code
     else:
@@ -472,6 +482,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

     if args.fused_backward_pass:
+        # use fused optimizer for backward pass: other optimizers will be supported in the future
         import library.adafactor_fused

         library.adafactor_fused.patch_adafactor_fused(optimizer)
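patch_adafactor_fused itself is not part of this diff; the idea is that the patched optimizer can step a single parameter as soon as its gradient has been accumulated, so each gradient can be freed immediately instead of being held until a whole optimizer.step(). A minimal sketch of that pattern, with plain SGD update math as a stand-in and an assumed step_param helper name (the real patch lives in library/adafactor_fused.py):

```python
import torch

def patch_sgd_fused(optimizer: torch.optim.SGD) -> torch.optim.SGD:
    # illustrative stand-in for patch_adafactor_fused: attach a per-parameter step
    @torch.no_grad()
    def step_param(tensor: torch.Tensor, param_group: dict):
        if tensor.grad is not None:
            tensor.add_(tensor.grad, alpha=-param_group["lr"])  # plain SGD update
            tensor.grad = None  # free the gradient right away to cut peak memory

    optimizer.step_param = step_param
    return optimizer

model = torch.nn.Linear(4, 4)
optimizer = patch_sgd_fused(torch.optim.SGD(model.parameters(), lr=0.1))
for param_group in optimizer.param_groups:
    for parameter in param_group["params"]:
        if parameter.requires_grad:

            def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                optimizer.step_param(tensor, param_group)

            parameter.register_post_accumulate_grad_hook(__grad_hook)

model(torch.randn(2, 4)).sum().backward()  # parameters are updated during backward
```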
@@ -488,16 +499,20 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                     parameter.register_post_accumulate_grad_hook(__grad_hook)

     elif args.fused_optimizer_groups:
+        # prepare for additional optimizers and lr schedulers
         for i in range(1, len(optimizers)):
             optimizers[i] = accelerator.prepare(optimizers[i])
             lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])

+        # counters are used to determine when to step the optimizer
         global optimizer_hooked_count
         global num_parameters_per_group
         global parameter_optimizer_map

         optimizer_hooked_count = {}
         num_parameters_per_group = [0] * len(optimizers)
         parameter_optimizer_map = {}

         for opt_idx, optimizer in enumerate(optimizers):
             for param_group in optimizer.param_groups:
                 for parameter in param_group["params"]:
@@ -511,7 +526,7 @@ def optimizer_hook(parameter: torch.Tensor):
                             optimizer_hooked_count[i] += 1
                             if optimizer_hooked_count[i] == num_parameters_per_group[i]:
                                 optimizers[i].step()
-                                optimizers[i].zero_grad()
+                                optimizers[i].zero_grad(set_to_none=True)

                         parameter.register_post_accumulate_grad_hook(optimizer_hook)
                         parameter_optimizer_map[parameter] = opt_idx
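The counter machinery above is easiest to follow end to end in a runnable miniature. The two Linear modules, the SGD optimizers, and the loss below are stand-ins; register_post_accumulate_grad_hook is the same PyTorch (2.1+) API the trainer uses:

```python
import torch

# two tiny parameter "groups", each owned by its own optimizer (stand-ins)
models = [torch.nn.Linear(4, 4) for _ in range(2)]
optimizers = [torch.optim.SGD(m.parameters(), lr=0.1) for m in models]

# step an optimizer only once all parameters in its group have gradients
num_parameters_per_group = [len(list(m.parameters())) for m in models]
optimizer_hooked_count = {}
parameter_optimizer_map = {}

for opt_idx, model in enumerate(models):
    for parameter in model.parameters():
        if parameter.requires_grad:

            def optimizer_hook(parameter: torch.Tensor):
                i = parameter_optimizer_map[parameter]
                optimizer_hooked_count[i] += 1
                if optimizer_hooked_count[i] == num_parameters_per_group[i]:
                    optimizers[i].step()
                    optimizers[i].zero_grad(set_to_none=True)  # free grads at once

            # fires as soon as this parameter's gradient has been accumulated
            parameter.register_post_accumulate_grad_hook(optimizer_hook)
            parameter_optimizer_map[parameter] = opt_idx

# one training step: reset the counters, then backward; the hooks do the rest
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
loss = sum(m(torch.randn(2, 4)).sum() for m in models)
loss.backward()  # each optimizer steps the moment its group's grads are ready
```

Resetting optimizer_hooked_count at the top of every step is exactly what the training-loop hunk below does.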
@@ -593,7 +608,7 @@ def optimizer_hook(parameter: torch.Tensor):
             current_step.value = global_step

             if args.fused_optimizer_groups:
-                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
+                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset counter for each step

             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
@@ -725,14 +740,14 @@ def optimizer_hook(parameter: torch.Tensor):
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                     optimizer.step()
-                elif args.fused_optimizer_groups:
-                    for i in range(1, len(optimizers)):
-                        lr_schedulers[i].step()
-
-                lr_scheduler.step()
-
-                if not (args.fused_backward_pass or args.fused_optimizer_groups):
+                    lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
+                else:
+                    # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
+                    lr_scheduler.step()
+                    if args.fused_optimizer_groups:
+                        for i in range(1, len(optimizers)):
+                            lr_schedulers[i].step()

                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients: