Ema (#4)
* Add EMA support in train_network

* Update train_util.py

* Update train_util.py

* add to more scripts. fixes

* Update train_network.py

* Update fine_tune.py

* Update train_db.py

* Update train_network.py

* add info messages

* add sdxl_train

* cleanup

* cleanup

* cleanup

* fix crash when using te lr

* change default exp value

---------

Co-authored-by: vvern999 <[email protected]>
donhardman and vvern999 authored Dec 3, 2023
1 parent a1db70d commit 1c179d5
Showing 10 changed files with 346 additions and 3 deletions.
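
For orientation: after every optimizer step, the new EMAModel keeps a "shadow" copy of the trainable parameters and nudges it toward the live weights. The update implemented in EMAModel.step() below is the standard exponential moving average, shadow = decay * shadow + (1 - decay) * param. A minimal self-contained sketch of that rule (the names ema_update, shadow and params are illustrative only, not part of the patch):

import torch

def ema_update(shadow_params, params, decay):
    # One EMA step, written the same way as EMAModel.step():
    # shadow -= (1 - decay) * (shadow - param), i.e. decay*shadow + (1-decay)*param
    one_minus_decay = 1.0 - decay
    with torch.no_grad():
        for s, p in zip(shadow_params, params):
            s.sub_((s - p) * one_minus_decay)

params = [torch.nn.Parameter(torch.ones(3))]
shadow = [p.clone().detach() for p in params]
with torch.no_grad():
    params[0].add_(1.0)              # pretend an optimizer step moved the weights
ema_update(shadow, params, decay=0.999)
print(shadow[0])                     # shadow moves 0.1% of the way toward the new weights
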
44 changes: 44 additions & 0 deletions fine_tune.py
@@ -38,6 +38,7 @@
apply_debiased_estimation,
get_latent_masks
)
from library.train_util import EMAModel


def train(args):
@@ -246,6 +247,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
unet.to(weight_dtype)
text_encoder.to(weight_dtype)

if args.enable_ema:
#ema_dtype = weight_dtype if (args.full_bf16 or args.full_fp16) else torch.float32
ema = EMAModel(trainable_params, decay=args.ema_decay, beta=args.ema_exp_beta, max_train_steps=args.max_train_steps)
ema.to(accelerator.device, dtype=weight_dtype)
ema = accelerator.prepare(ema)
else:
ema = None

# the accelerator apparently takes care of this for us
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -381,6 +390,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if args.enable_ema:
with torch.no_grad(), accelerator.autocast():
ema.step(trainable_params)

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -435,6 +447,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
if args.enable_ema and not args.ema_save_only_ema_weights and ((epoch + 1) % args.save_every_n_epochs == 0):
temp_name = args.output_name
args.output_name = args.output_name + "-non-EMA"
train_util.save_sd_model_on_epoch_end_or_stepwise(
args,
True,
@@ -450,13 +465,34 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.unwrap_model(unet),
vae,
)
if args.enable_ema and ((epoch + 1) % args.save_every_n_epochs == 0):
args.output_name = temp_name if temp_name else args.output_name
with ema.ema_parameters(trainable_params):
print("Saving EMA:")
train_util.save_sd_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(text_encoder),
accelerator.unwrap_model(unet),
vae,
)

train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

is_main_process = accelerator.is_main_process
if is_main_process:
unet = accelerator.unwrap_model(unet)
text_encoder = accelerator.unwrap_model(text_encoder)
if args.enable_ema:
ema = accelerator.unwrap_model(ema)

accelerator.end_training()

@@ -467,6 +503,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

if is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
if args.enable_ema and not args.ema_save_only_ema_weights:
temp_name = args.output_name
args.output_name = args.output_name + "-non-EMA"
train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae)
args.output_name = temp_name
if args.enable_ema:
print("Saving EMA:")
ema.copy_to(trainable_params)
train_util.save_sd_model_on_train_end(
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
)
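
The wiring above is the pattern this commit repeats in every trainer: build an EMAModel from the trainable parameters, move it to the accelerator device, call ema.step() right after optimizer.step(), and swap the averaged weights in only while saving. A toy, self-contained usage sketch, assuming library/train_util.py from this repository is importable; the Linear model, optimizer, constants and output filename are placeholders, only the EMAModel calls mirror fine_tune.py:

import torch
from library.train_util import EMAModel

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
params = list(model.parameters())

# in fine_tune.py, decay and beta come from --ema_decay / --ema_exp_beta
ema = EMAModel(params, decay=0.999, beta=0, max_train_steps=1000)

for step in range(100):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        ema.step(params)             # update the shadow weights after each optimizer step

# save with the averaged weights swapped in, then automatically restore the live ones
with ema.ema_parameters(params):
    torch.save(model.state_dict(), "model-ema.pt")
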
22 changes: 22 additions & 0 deletions library/sdxl_train_util.py
@@ -265,6 +265,8 @@ def save_sd_model_on_epoch_end_or_stepwise(
vae,
logit_scale,
ckpt_info,
ema = None,
params_to_replace = None,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
@@ -294,6 +296,10 @@ def diffusers_saver(out_dir):
save_dtype=save_dtype,
)

if args.enable_ema and not args.ema_save_only_ema_weights and ema:
temp_name = args.output_name
args.output_name = args.output_name + "-non-EMA"

train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
@@ -306,6 +312,22 @@ def diffusers_saver(out_dir):
sd_saver,
diffusers_saver,
)
args.output_name = temp_name if temp_name else args.output_name
if args.enable_ema and ema:
with ema.ema_parameters(params_to_replace):
print("Saving EMA:")
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
save_stable_diffusion_format,
use_safetensors,
epoch,
num_train_epochs,
global_step,
sd_saver,
diffusers_saver,
)


def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
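
The saving logic added here (and mirrored in fine_tune.py above) writes two checkpoints per save point unless --ema_save_only_ema_weights is set: the live weights under "<output_name>-non-EMA", then the averaged weights under the original name, swapped in only for the duration of the save via ema.ema_parameters(). Reduced to its essentials (save_model and save_with_optional_ema are hypothetical stand-ins for the save_sd_model_*_common helpers, and the output_name bookkeeping is made unconditional here for clarity):

def save_with_optional_ema(args, ema, params_to_replace, save_model):
    original_name = args.output_name
    if args.enable_ema and ema and not args.ema_save_only_ema_weights:
        # the non-averaged checkpoint gets a "-non-EMA" suffix
        args.output_name = original_name + "-non-EMA"
    save_model()                          # first save: live (non-EMA) weights
    args.output_name = original_name
    if args.enable_ema and ema:
        print("Saving EMA:")
        with ema.ema_parameters(params_to_replace):
            save_model()                  # second save: EMA weights, restored afterwards
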
144 changes: 144 additions & 0 deletions library/train_util.py
@@ -18,6 +18,7 @@
Sequence,
Tuple,
Union,
Iterable,
)
from accelerate import Accelerator, InitProcessGroupKwargs
import gc
@@ -69,6 +70,7 @@
# from library.attention_processors import FlashAttnProcessor
# from library.hypernetwork import replace_attentions_for_hypernetwork
from library.original_unet import UNet2DConditionModel
import contextlib

# Tokenizer: use the pre-packaged tokenizer rather than loading it from the checkpoint
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -2345,6 +2347,135 @@ def load_text_encoder_outputs_from_disk(npz_path):

# endregion


# based mostly on https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py
class EMAModel:
"""
Maintains (exponential) moving average of a set of parameters.
"""
def __init__(self, parameters: Iterable[torch.nn.Parameter], decay: float, beta=0, max_train_steps=10000):
parameters = self.get_params_list(parameters)
self.shadow_params = [p.clone().detach() for p in parameters]
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.decay = decay
self.optimization_step = 0
self.collected_params = None
if beta < 0:
raise ValueError('ema_exp_beta must be >= 0')
self.beta = beta
self.max_train_steps = max_train_steps
print(f"len(self.shadow_params): {len(self.shadow_params)}")

def get_params_list(self, parameters: Iterable[torch.nn.Parameter]):
parameters = list(parameters)
if isinstance(parameters[0], dict):
params_list = []
for m in parameters:
params_list.extend(list(m["params"]))
return params_list
else:
return parameters

def get_decay(self, optimization_step: int) -> float:
"""
Get current decay for the exponential moving average.
"""
if self.beta == 0:
return min(self.decay, (1 + optimization_step) / (10 + optimization_step))
else:
# exponential schedule. scales to max_train_steps
x = optimization_step / self.max_train_steps
return min(self.decay, self.decay * (1 - np.exp(-x * self.beta)))

def step(self, parameters: Iterable[torch.nn.Parameter]) -> None:
"""
Update currently maintained parameters.
Call this every time the parameters are updated, such as the result of
the `optimizer.step()` call.
"""
parameters = self.get_params_list(parameters)
one_minus_decay = 1.0 - self.get_decay(self.optimization_step)
self.optimization_step += 1
#print(f" {one_minus_decay}")
#with torch.no_grad():
for s_param, param in zip(self.shadow_params, parameters, strict=True):
tmp = (s_param - param)
#print(torch.sum(tmp))
# tmp will be a new tensor so we can do in-place
tmp.mul_(one_minus_decay)
s_param.sub_(tmp)

def copy_to(self, parameters: Iterable[torch.nn.Parameter] = None) -> None:
"""
Copy current averaged parameters into given collection of parameters.
"""
parameters = self.get_params_list(parameters)
for s_param, param in zip(self.shadow_params, parameters, strict=True):
# print(f"diff: {torch.sum(s_param) - torch.sum(param)}")
param.data.copy_(s_param.data)

def to(self, device=None, dtype=None) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
"""
self.shadow_params = [
p.to(device=device, dtype=dtype)
if p.is_floating_point()
else p.to(device=device)
for p in self.shadow_params
]
return

def store(self, parameters: Iterable[torch.nn.Parameter] = None) -> None:
"""
Save the current parameters for restoring later.
"""
parameters = self.get_params_list(parameters)
self.collected_params = [
param.clone()
for param in parameters
]

def restore(self, parameters: Iterable[torch.nn.Parameter] = None) -> None:
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
"""
if self.collected_params is None:
raise RuntimeError(
"This ExponentialMovingAverage has no `store()`ed weights "
"to `restore()`"
)
parameters = self.get_params_list(parameters)
for c_param, param in zip(self.collected_params, parameters, strict=True):
param.data.copy_(c_param.data)

@contextlib.contextmanager
def ema_parameters(self, parameters: Iterable[torch.nn.Parameter] = None):
r"""
Context manager for validation/inference with averaged parameters.
Equivalent to:
ema.store()
ema.copy_to()
try:
...
finally:
ema.restore()
"""
parameters = self.get_params_list(parameters)
self.store(parameters)
self.copy_to(parameters)
try:
yield
finally:
self.restore(parameters)


# region module replacement
"""
Module replacement for speed-up
@@ -2990,6 +3121,19 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
# default=None,
# help="enable perlin noise and set the octaves / perlin noiseを有効にしてoctavesをこの値に設定する",
# )
parser.add_argument(
"--enable_ema", action="store_true", help="Enable EMA (Exponential Moving Average) of model parameters / モデルパラメータのEMA(指数移動平均)を有効にする "
)
parser.add_argument(
"--ema_decay", type=float, default=0.999, help="Max EMA decay. Typical values: 0.999 - 0.9999 / 最大EMA減衰。標準的な値: 0.999 - 0.9999 "
)
parser.add_argument(
"--ema_exp_beta", type=float, default=15, help="Choose EMA decay schedule. By default: (1+x)/(10+x). If beta is set: use exponential schedule scaled to max_train_steps. If beta>0, recommended values are around 10-15 "
+ "/ EMAの減衰スケジュールを設定する。デフォルト:(1+x)/(10+x)。beta が設定されている場合: max_train_steps にスケーリングされた指数スケジュールを使用する。beta>0 の場合、推奨値は 10-15 程度。 "
)
parser.add_argument(
"--ema_save_only_ema_weights", action="store_true", help="By default both EMA and non-EMA weights are saved. If enabled, saves only EMA / デフォルトでは、EMAウェイトと非EMAウェイトの両方が保存される。有効にすると、EMAのみが保存される "
)
parser.add_argument(
"--multires_noise_discount",
type=float,
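
EMAModel.get_decay() supports two schedules, selected by --ema_exp_beta. With beta == 0 it uses the warm-up form min(decay, (1 + step) / (10 + step)); with beta > 0 it uses min(decay, decay * (1 - exp(-beta * step / max_train_steps))), which ramps up over the whole run. A small sketch that just restates both curves and prints a few values (the step counts and constants are arbitrary):

import numpy as np

def decay_default(step, max_decay=0.999):
    # schedule used when ema_exp_beta == 0
    return min(max_decay, (1 + step) / (10 + step))

def decay_exponential(step, max_decay=0.999, beta=15, max_train_steps=10000):
    # schedule used when ema_exp_beta > 0
    x = step / max_train_steps
    return min(max_decay, max_decay * (1 - np.exp(-x * beta)))

for step in (0, 100, 1000, 5000, 10000):
    print(step, round(decay_default(step), 4), round(decay_exponential(step), 4))
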
40 changes: 40 additions & 0 deletions sdxl_train.py
@@ -41,6 +41,7 @@
get_latent_masks
)
from library.sdxl_original_unet import SdxlUNet2DConditionModel
from library.train_util import EMAModel


UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23
@@ -386,6 +387,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)

if args.enable_ema:
#ema_dtype = weight_dtype if (args.full_bf16 or args.full_fp16) else torch.float
ema = EMAModel(params_to_optimize, decay=args.ema_decay, beta=args.ema_exp_beta, max_train_steps=args.max_train_steps)
ema.to(accelerator.device, dtype=weight_dtype)
ema = accelerator.prepare(ema)
else:
ema = None
# the accelerator apparently takes care of this for us
if train_unet:
unet = accelerator.prepare(unet)
@@ -587,6 +595,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if args.enable_ema:
with torch.no_grad(), accelerator.autocast():
ema.step(params_to_optimize)

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -627,6 +638,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae,
logit_scale,
ckpt_info,
ema=ema,
params_to_replace=params_to_optimize,
)

current_loss = loss.detach().item()  # this is a mean, so batch size shouldn't matter
@@ -673,6 +686,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
vae,
logit_scale,
ckpt_info,
ema=ema,
params_to_replace=params_to_optimize,
)

sdxl_train_util.sample_images(
@@ -692,6 +707,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
unet = accelerator.unwrap_model(unet)
text_encoder1 = accelerator.unwrap_model(text_encoder1)
text_encoder2 = accelerator.unwrap_model(text_encoder2)
if args.enable_ema:
ema = accelerator.unwrap_model(ema)

accelerator.end_training()

@@ -702,6 +719,29 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

if is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
if args.enable_ema and not args.ema_save_only_ema_weights:
temp_name = args.output_name
args.output_name = args.output_name + "-non-EMA"
sdxl_train_util.save_sd_model_on_train_end(
args,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
global_step,
text_encoder1,
text_encoder2,
unet,
vae,
logit_scale,
ckpt_info,
)
args.output_name = temp_name
if args.enable_ema:
print("Saving EMA:")
ema.copy_to(params_to_optimize)

sdxl_train_util.save_sd_model_on_train_end(
args,
src_path,
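
Unlike fine_tune.py, sdxl_train.py hands EMAModel its params_to_optimize, a list of optimizer-style param-group dicts rather than a flat parameter list; EMAModel.get_params_list() flattens either form, so the same ema.step() / copy_to() / ema_parameters() calls work unchanged. A toy sketch of that shape (the two Linear modules and learning rates are illustrative, not the groups sdxl_train.py actually builds):

import torch
from library.train_util import EMAModel

unet = torch.nn.Linear(8, 8)
text_encoder = torch.nn.Linear(8, 8)

# optimizer-style param groups, the shape sdxl_train.py passes as params_to_optimize
params_to_optimize = [
    {"params": list(unet.parameters()), "lr": 1e-4},
    {"params": list(text_encoder.parameters()), "lr": 1e-5},
]

ema = EMAModel(params_to_optimize, decay=0.9999, beta=15, max_train_steps=10000)
with torch.no_grad():
    ema.step(params_to_optimize)         # get_params_list() flattens the groups internally
ema.copy_to(params_to_optimize)          # write the averaged weights back into the live tensors
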
(Diffs for the remaining 6 changed files are not shown.)
