Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support deepspeed #1101

Merged
merged 12 commits into from
Feb 27, 2024
38 changes: 27 additions & 11 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def train(args):

# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype

# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
Expand Down Expand Up @@ -158,7 +159,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
Expand Down Expand Up @@ -191,7 +192,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)

for m in training_models:
m.requires_grad_(True)
Expand All @@ -218,7 +219,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
persistent_workers=args.persistent_data_loader_workers,
)

Expand All @@ -230,7 +231,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)

# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)

Expand All @@ -246,13 +247,28 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
unet.to(weight_dtype)
text_encoder.to(weight_dtype)

# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if args.deepspeed:
training_models_dict = {}
training_models_dict["unet"] = unet
if args.train_text_encoder: training_models_dict["text_encoder"] = text_encoder

ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)

training_models = []
unet = ds_model.models["unet"]
training_models.append(unet)
if args.train_text_encoder:
text_encoder = ds_model.models["text_encoder"]
training_models.append(text_encoder)

else: # acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
Expand Down
1 change: 0 additions & 1 deletion library/sdxl_train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@


def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
Expand Down
111 changes: 107 additions & 4 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
Tuple,
Union,
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
from accelerate import DeepSpeedPlugin
import glob
import math
import os
Expand Down Expand Up @@ -3242,6 +3243,52 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
)

# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
parser.add_argument(
"--zero_stage",
type=int, default=2,
choices=[0, 1, 2, 3],
help="Possible options are 0,1,2,3."
)
parser.add_argument(
"--offload_optimizer_device",
type=str, default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."
)
parser.add_argument(
"--offload_optimizer_nvme_path",
type=str, default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--offload_param_device",
type=str, default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--offload_param_nvme_path",
type=str, default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--zero3_init_flag",
action="store_true",
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
"Only applicable with ZeRO Stage-3."
)
parser.add_argument(
"--zero3_save_16bit_model",
action="store_true",
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."
)
parser.add_argument(
"--fp16_master_weights_and_gradients",
action="store_true",
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32."
)

def verify_training_args(args: argparse.Namespace):
r"""
Expand Down Expand Up @@ -4088,17 +4135,76 @@ def prepare_accelerator(args: argparse.Namespace):
),
)
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
deepspeed_plugin = prepare_deepspeed_plugin(args)

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=log_with,
project_dir=logging_dir,
kwargs_handlers=kwargs_handlers,
dynamo_backend=dynamo_backend,
deepspeed_plugin=deepspeed_plugin,
)
print("accelerator device:", accelerator.device)
return accelerator

def prepare_deepspeed_plugin(args: argparse.Namespace):
    """Build a DeepSpeedPlugin from the parsed CLI arguments.

    Returns None when DeepSpeed training is not requested, otherwise a
    configured ``accelerate.DeepSpeedPlugin`` to pass to ``Accelerator``.
    Exits the process if ``--deepspeed`` was given but the deepspeed
    package is not installed.
    """
    # --deepspeed is declared with action="store_true", so it is False (not
    # None) when the flag is absent. Checking "is None" would never fire and
    # DeepSpeed would be configured unconditionally.
    if not args.deepspeed:
        return None

    try:
        import deepspeed
    except ImportError:
        print("deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed")
        exit(1)

    deepspeed_plugin = DeepSpeedPlugin(
        zero_stage=args.zero_stage,
        gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_clipping=args.max_grad_norm,
        offload_optimizer_device=args.offload_optimizer_device, offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
        offload_param_device=args.offload_param_device, offload_param_nvme_path=args.offload_param_nvme_path,
        zero3_init_flag=args.zero3_init_flag, zero3_save_16bit_model=args.zero3_save_16bit_model,
    )
    deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
    # NOTE(review): assumes the process was started by a distributed launcher
    # that sets WORLD_SIZE (accelerate/torchrun do); raises KeyError otherwise.
    deepspeed_plugin.deepspeed_config['train_batch_size'] = \
        args.train_batch_size * args.gradient_accumulation_steps * int(os.environ['WORLD_SIZE'])
    deepspeed_plugin.set_mixed_precision(args.mixed_precision)
    if args.mixed_precision.lower() == "fp16":
        deepspeed_plugin.deepspeed_config['fp16']['initial_scale_power'] = 0  # preventing overflow.
    if args.full_fp16 or args.fp16_master_weights_and_gradients:
        # fp16 master weights/grads are only supported by DeepSpeedCPUAdam
        # with ZeRO stage 2 + CPU optimizer offload.
        if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
            deepspeed_plugin.deepspeed_config['fp16']['fp16_master_weights_and_grads'] = True
            print("[DeepSpeed] full fp16 enable.")
        else:
            print("[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage.")

    if args.offload_optimizer_device is not None:
        # Pre-build the CPUAdam op here so the (slow) JIT compile happens once,
        # up front, instead of inside the first training step.
        print('[DeepSpeed] start to manually build cpu_adam.')
        deepspeed.ops.op_builder.CPUAdamBuilder().load()
        print('[DeepSpeed] building cpu_adam done.')

    return deepspeed_plugin

def prepare_deepspeed_model(args: argparse.Namespace, **models):
    """Bundle one or more training models into a single nn.Module.

    DeepSpeed wraps exactly one module per engine, so the given models
    (e.g. unet / text_encoder) are collected into an nn.ModuleDict under
    their keyword names. A list value is first wrapped in an nn.ModuleList.
    """

    class DeepSpeedWrapper(torch.nn.Module):
        def __init__(self, **kw_models) -> None:
            super().__init__()
            self.models = torch.nn.ModuleDict()
            for key, model in kw_models.items():
                # Lists of modules (e.g. multiple text encoders) become a ModuleList.
                if isinstance(model, list):
                    model = torch.nn.ModuleList(model)
                assert isinstance(model, torch.nn.Module), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
                self.models[key] = model

        def get_models(self):
            return self.models

    return DeepSpeedWrapper(**models)

def prepare_dtype(args: argparse.Namespace):
weight_dtype = torch.float32
Expand Down Expand Up @@ -4165,7 +4271,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une


def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
# load models for each process
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
Expand All @@ -4176,7 +4281,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
accelerator.device if args.lowram else "cpu",
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
)

# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
Expand All @@ -4185,7 +4289,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio

clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()

return text_encoder, vae, unet, load_stable_diffusion_format


Expand Down
62 changes: 43 additions & 19 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
persistent_workers=args.persistent_data_loader_workers,
)

Expand Down Expand Up @@ -398,18 +398,41 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)

# acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
if train_text_encoder1:
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)

optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if args.deepspeed:
training_models_dict = {}
if train_unet:
training_models_dict["unet"] = unet
if train_text_encoder1:
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
training_models_dict["text_encoder1"] = text_encoder1
if train_text_encoder2:
training_models_dict["text_encoder2"] = text_encoder2
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)

training_models = [] # override training_models
if train_unet:
unet = ds_model.models["unet"]
training_models.append(unet)
if train_text_encoder1:
text_encoder1 = ds_model.models["text_encoder1"]
training_models.append(text_encoder1)
if train_text_encoder2:
text_encoder2 = ds_model.models["text_encoder2"]
training_models.append(text_encoder2)

else: # acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
if train_text_encoder1:
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
Expand All @@ -423,7 +446,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder2.to(accelerator.device)

# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
if args.full_fp16 and not args.deepspeed:
# During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the grad scaler; the DeepSpeed engine does it instead.
train_util.patch_accelerator_for_fp16_training(accelerator)

# resumeする
Expand Down Expand Up @@ -484,18 +508,18 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
with torch.no_grad():
with torch.no_grad():  # NOTE(review): why does this block differ from the one in train_network.py?
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)

# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR

if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
Expand Down
32 changes: 24 additions & 8 deletions train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def train(args):
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
persistent_workers=args.persistent_data_loader_workers,
)

Expand Down Expand Up @@ -219,15 +219,31 @@ def train(args):
text_encoder.to(weight_dtype)

# acceleratorがなんかよろしくやってくれるらしい
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
if args.deepspeed:
training_models_dict = {}
training_models_dict["unet"] = unet
if train_text_encoder: training_models_dict["text_encoder"] = text_encoder

ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)

training_models = []
unet = ds_model.models["unet"]
training_models.append(unet)
if train_text_encoder:
text_encoder = ds_model.models["text_encoder"]
training_models.append(text_encoder)

else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error

# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
Expand Down
Loading