Skip to content

Commit

Permalink
Merge pull request #1101 from BootsofLagrangian/deepspeed
Browse files Browse the repository at this point in the history
support deepspeed
  • Loading branch information
kohya-ss authored Feb 27, 2024
2 parents 074d32a + eefb3cc commit 0e4a573
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 56 deletions.
38 changes: 27 additions & 11 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def train(args):

# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype

# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
Expand Down Expand Up @@ -158,7 +159,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
Expand Down Expand Up @@ -191,7 +192,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)

for m in training_models:
m.requires_grad_(True)
Expand All @@ -218,7 +219,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
persistent_workers=args.persistent_data_loader_workers,
)

Expand All @@ -230,7 +231,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)

# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)

Expand All @@ -246,13 +247,28 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
unet.to(weight_dtype)
text_encoder.to(weight_dtype)

# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if args.deepspeed:
training_models_dict = {}
training_models_dict["unet"] = unet
if args.train_text_encoder: training_models_dict["text_encoder"] = text_encoder

ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)

training_models = []
unet = ds_model.models["unet"]
training_models.append(unet)
if args.train_text_encoder:
text_encoder = ds_model.models["text_encoder"]
training_models.append(text_encoder)

else: # acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
Expand Down
1 change: 0 additions & 1 deletion library/sdxl_train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@


def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
Expand Down
111 changes: 107 additions & 4 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
Tuple,
Union,
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
from accelerate import DeepSpeedPlugin
import glob
import math
import os
Expand Down Expand Up @@ -3242,6 +3243,52 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
)

# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
parser.add_argument(
"--zero_stage",
type=int, default=2,
choices=[0, 1, 2, 3],
help="Possible options are 0,1,2,3."
)
parser.add_argument(
"--offload_optimizer_device",
type=str, default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."
)
parser.add_argument(
"--offload_optimizer_nvme_path",
type=str, default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--offload_param_device",
type=str, default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--offload_param_nvme_path",
type=str, default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--zero3_init_flag",
action="store_true",
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
"Only applicable with ZeRO Stage-3."
)
parser.add_argument(
"--zero3_save_16bit_model",
action="store_true",
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."
)
parser.add_argument(
"--fp16_master_weights_and_gradients",
action="store_true",
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32."
)

def verify_training_args(args: argparse.Namespace):
r"""
Expand Down Expand Up @@ -4088,17 +4135,76 @@ def prepare_accelerator(args: argparse.Namespace):
),
)
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
deepspeed_plugin = prepare_deepspeed_plugin(args)

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=log_with,
project_dir=logging_dir,
kwargs_handlers=kwargs_handlers,
dynamo_backend=dynamo_backend,
deepspeed_plugin=deepspeed_plugin,
)
print("accelerator device:", accelerator.device)
return accelerator

def prepare_deepspeed_plugin(args: argparse.Namespace):
if args.deepspeed is None: return None
try:
import deepspeed
except ImportError as e:
print("deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed")
exit(1)

deepspeed_plugin = DeepSpeedPlugin(
zero_stage=args.zero_stage,
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_clipping=args.max_grad_norm,
offload_optimizer_device=args.offload_optimizer_device, offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
offload_param_device=args.offload_param_device, offload_param_nvme_path=args.offload_param_nvme_path,
zero3_init_flag=args.zero3_init_flag, zero3_save_16bit_model=args.zero3_save_16bit_model,
)
deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
deepspeed_plugin.deepspeed_config['train_batch_size'] = \
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ['WORLD_SIZE'])
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config['fp16']['initial_scale_power'] = 0 # preventing overflow.
if args.full_fp16 or args.fp16_master_weights_and_gradients:
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
deepspeed_plugin.deepspeed_config['fp16']['fp16_master_weights_and_grads'] = True
print("[DeepSpeed] full fp16 enable.")
else:
print("[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage.")

if args.offload_optimizer_device is not None:
print('[DeepSpeed] start to manually build cpu_adam.')
deepspeed.ops.op_builder.CPUAdamBuilder().load()
print('[DeepSpeed] building cpu_adam done.')

return deepspeed_plugin

def prepare_deepspeed_model(args: argparse.Namespace, **models):
class DeepSpeedWrapper(torch.nn.Module):
def __init__(self, **kw_models) -> None:
super().__init__()
self.models = torch.nn.ModuleDict()

for key, model in kw_models.items():
if isinstance(model, list):
model = torch.nn.ModuleList(model)
assert isinstance(model, torch.nn.Module), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
self.models.update(
torch.nn.ModuleDict(
{key: model}
)
)

def get_models(self):
return self.models

ds_model = DeepSpeedWrapper(**models)
return ds_model

def prepare_dtype(args: argparse.Namespace):
weight_dtype = torch.float32
Expand Down Expand Up @@ -4165,7 +4271,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une


def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
# load models for each process
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
Expand All @@ -4176,7 +4281,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
accelerator.device if args.lowram else "cpu",
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
)

# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
Expand All @@ -4185,7 +4289,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio

clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()

return text_encoder, vae, unet, load_stable_diffusion_format


Expand Down
62 changes: 43 additions & 19 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
persistent_workers=args.persistent_data_loader_workers,
)

Expand Down Expand Up @@ -398,18 +398,41 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)

# acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
if train_text_encoder1:
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)

optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if args.deepspeed:
training_models_dict = {}
if train_unet:
training_models_dict["unet"] = unet
if train_text_encoder1:
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
training_models_dict["text_encoder1"] = text_encoder1
if train_text_encoder2:
training_models_dict["text_encoder2"] = text_encoder2
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)

training_models = [] # override training_models
if train_unet:
unet = ds_model.models["unet"]
training_models.append(unet)
if train_text_encoder1:
text_encoder1 = ds_model.models["text_encoder1"]
training_models.append(text_encoder1)
if train_text_encoder2:
text_encoder2 = ds_model.models["text_encoder2"]
training_models.append(text_encoder2)

else: # acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
if train_text_encoder1:
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
Expand All @@ -423,7 +446,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder2.to(accelerator.device)

# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
if args.full_fp16 and not args.deepspeed:
# During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
train_util.patch_accelerator_for_fp16_training(accelerator)

# resumeする
Expand Down Expand Up @@ -484,18 +508,18 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
with torch.no_grad():
with torch.no_grad(): # why this block differ within train_network.py?
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)

# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR

if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
Expand Down
32 changes: 24 additions & 8 deletions train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def train(args):
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
persistent_workers=args.persistent_data_loader_workers,
)

Expand Down Expand Up @@ -219,15 +219,31 @@ def train(args):
text_encoder.to(weight_dtype)

# acceleratorがなんかよろしくやってくれるらしい
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
if args.deepspeed:
training_models_dict = {}
training_models_dict["unet"] = unet
if train_text_encoder: training_models_dict["text_encoder"] = text_encoder

ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)

training_models = []
unet = ds_model.models["unet"]
training_models.append(unet)
if train_text_encoder:
text_encoder = ds_model.models["text_encoder"]
training_models.append(text_encoder)

else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error

# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
Expand Down
Loading

0 comments on commit 0e4a573

Please sign in to comment.