Deep speed #1139

Merged: 22 commits, Mar 26, 2024
Commits
dfe08f3 support deepspeed (BootsofLagrangian, Feb 3, 2024)
64873c1 fix offload_optimizer_device typo (BootsofLagrangian, Feb 5, 2024)
2824312 fix vae type error during training sdxl (BootsofLagrangian, Feb 5, 2024)
4295f91 fix all trainer about vae (BootsofLagrangian, Feb 5, 2024)
3970bf4 maybe fix branch to run offloading (BootsofLagrangian, Feb 5, 2024)
7d2a926 apply offloading method runable for all trainer (BootsofLagrangian, Feb 5, 2024)
6255661 fix full_fp16 compatible and train_step (BootsofLagrangian, Feb 7, 2024)
2445a5b remove test requirements (BootsofLagrangian, Feb 7, 2024)
a98feca forgot setting mixed_precision for deepspeed. sorry (BootsofLagrangian, Feb 7, 2024)
03f0816 the reason not working grad accum steps found. it was becasue of my a… (BootsofLagrangian, Feb 9, 2024)
4d5186d refactored codes, some function moved into train_utils.py (BootsofLagrangian, Feb 22, 2024)
eefb3cc Merge branch 'deep-speed' into deepspeed (kohya-ss, Feb 27, 2024)
0e4a573 Merge pull request #1101 from BootsofLagrangian/deepspeed (kohya-ss, Feb 27, 2024)
e3ccf8f make deepspeed_utils (kohya-ss, Feb 27, 2024)
97524f1 Merge branch 'dev' into deep-speed (kohya-ss, Mar 12, 2024)
86e40fa Merge branch 'dev' into deep-speed (kohya-ss, Mar 17, 2024)
fbb98f1 Merge branch 'dev' into deep-speed (kohya-ss, Mar 20, 2024)
d945602 Fix most of ZeRO stage uses optimizer partitioning (BootsofLagrangian, Mar 20, 2024)
a35e7bd Merge pull request #1200 from BootsofLagrangian/deep-speed (kohya-ss, Mar 20, 2024)
993b2ab Merge branch 'dev' into deep-speed (kohya-ss, Mar 24, 2024)
c24422f Merge branch 'dev' into deep-speed (kohya-ss, Mar 25, 2024)
a2b8531 make each script consistent, fix to work w/o DeepSpeed (kohya-ss, Mar 25, 2024)
40 changes: 30 additions & 10 deletions fine_tune.py
@@ -10,7 +10,9 @@
from tqdm import tqdm

import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from accelerate.utils import set_seed
@@ -42,6 +44,7 @@
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)

cache_latents = args.cache_latents
@@ -108,6 +111,7 @@ def train(args):

# prepare dtypes compatible with mixed precision and cast as appropriate
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype

# load models
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -158,7 +162,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# prepare for training
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
@@ -191,7 +195,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)

for m in training_models:
m.requires_grad_(True)
@@ -246,13 +250,23 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
unet.to(weight_dtype)
text_encoder.to(weight_dtype)

# accelerator apparently takes care of this for us
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# accelerator apparently takes care of this for us
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

# experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
@@ -311,13 +325,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # does not seem to support multiple models, so do it this way for now
with accelerator.accumulate(*training_models):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# convert to latents
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
latents = latents * 0.18215
b_size = latents.shape[0]

@@ -477,6 +491,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
@@ -492,6 +507,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)

return parser

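The gist of the fine_tune.py change above: Accelerate's DeepSpeed integration prepares only a single model, so when --train_text_encoder is set, the U-Net and text encoder are wrapped into one module via prepare_deepspeed_model (defined in the new library/deepspeed_utils.py below) and passed through accelerator.prepare together with the optimizer; training_models then holds just the wrapper, and accelerator.accumulate(*training_models) works unchanged in both branches. A minimal sketch of why the wrapping suffices, using toy stand-in modules rather than the real models:

```python
import torch

# Toy stand-ins for the real U-Net and text encoder (hypothetical shapes).
unet = torch.nn.Linear(4, 4)
text_encoder = torch.nn.Embedding(10, 4)

# Same idea as DeepSpeedWrapper below: one nn.Module owning both sub-models.
wrapper = torch.nn.Module()
wrapper.models = torch.nn.ModuleDict({"unet": unet, "text_encoder": text_encoder})

# The wrapper's parameters() spans both sub-modules, so a single optimizer
# (and hence a single DeepSpeed engine) covers everything that is trained.
n_wrapped = sum(p.numel() for p in wrapper.parameters())
n_separate = sum(p.numel() for m in (unet, text_encoder) for p in m.parameters())
assert n_wrapped == n_separate

optimizer = torch.optim.AdamW(wrapper.parameters(), lr=1e-4)
```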
139 changes: 139 additions & 0 deletions library/deepspeed_utils.py
@@ -0,0 +1,139 @@
import os
import argparse
import torch
from accelerate import DeepSpeedPlugin, Accelerator

from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def add_deepspeed_arguments(parser: argparse.ArgumentParser):
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
parser.add_argument(
"--offload_optimizer_device",
type=str,
default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
)
parser.add_argument(
"--offload_optimizer_nvme_path",
type=str,
default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--offload_param_device",
type=str,
default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--offload_param_nvme_path",
type=str,
default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--zero3_init_flag",
action="store_true",
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
"Only applicable with ZeRO Stage-3.",
)
parser.add_argument(
"--zero3_save_16bit_model",
action="store_true",
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
)
parser.add_argument(
"--fp16_master_weights_and_gradients",
action="store_true",
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
)


def prepare_deepspeed_args(args: argparse.Namespace):
if not args.deepspeed:
return

# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
args.max_data_loader_n_workers = 1


def prepare_deepspeed_plugin(args: argparse.Namespace):
if not args.deepspeed:
return None

try:
import deepspeed
except ImportError as e:
logger.error(
"deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
)
exit(1)

deepspeed_plugin = DeepSpeedPlugin(
zero_stage=args.zero_stage,
gradient_accumulation_steps=args.gradient_accumulation_steps,
gradient_clipping=args.max_grad_norm,
offload_optimizer_device=args.offload_optimizer_device,
offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
offload_param_device=args.offload_param_device,
offload_param_nvme_path=args.offload_param_nvme_path,
zero3_init_flag=args.zero3_init_flag,
zero3_save_16bit_model=args.zero3_save_16bit_model,
)
deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
)
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
if args.full_fp16 or args.fp16_master_weights_and_gradients:
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
logger.info("[DeepSpeed] full fp16 enable.")
else:
logger.info(
"[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
)

if args.offload_optimizer_device is not None:
logger.info("[DeepSpeed] start to manually build cpu_adam.")
deepspeed.ops.op_builder.CPUAdamBuilder().load()
logger.info("[DeepSpeed] building cpu_adam done.")

return deepspeed_plugin


# The Accelerate library does not support multiple models with DeepSpeed, so we wrap multiple models into a single model.
def prepare_deepspeed_model(args: argparse.Namespace, **models):
# remove None from models
models = {k: v for k, v in models.items() if v is not None}

class DeepSpeedWrapper(torch.nn.Module):
def __init__(self, **kw_models) -> None:
super().__init__()
self.models = torch.nn.ModuleDict()

for key, model in kw_models.items():
if isinstance(model, list):
model = torch.nn.ModuleList(model)
assert isinstance(
model, torch.nn.Module
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
self.models.update(torch.nn.ModuleDict({key: model}))

def get_models(self):
return self.models

ds_model = DeepSpeedWrapper(**models)
return ds_model
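Composing the helpers in this new file: add_deepspeed_arguments registers the flags, prepare_deepspeed_plugin builds an Accelerate DeepSpeedPlugin whose global batch size is micro-batch × gradient-accumulation × world size, and prepare_deepspeed_model wraps the trainable modules for a single prepare call. A minimal sketch of the intended flow, assuming DeepSpeed is installed, a toy stand-in model, and a DeepSpeed-capable launch (accelerate launch); the field values below are illustrative and are normally supplied by train_util's argument groups:

```python
import argparse
import os

import torch
from accelerate import Accelerator

from library import deepspeed_utils

parser = argparse.ArgumentParser()
deepspeed_utils.add_deepspeed_arguments(parser)
args = parser.parse_args(["--deepspeed", "--zero_stage", "2"])

# Fields the helpers read; normally provided by train_util's argument groups.
args.train_batch_size = 4
args.gradient_accumulation_steps = 8
args.max_grad_norm = 1.0
args.mixed_precision = "fp16"
args.full_fp16 = False
args.max_data_loader_n_workers = 8

deepspeed_utils.prepare_deepspeed_args(args)  # clamps dataloader workers to 1
os.environ.setdefault("WORLD_SIZE", "1")      # set by the launcher in real runs

plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
# global batch = micro batch * grad accum steps * world size = 4 * 8 * 1 = 32
assert plugin.deepspeed_config["train_batch_size"] == 32

# In the training scripts this happens inside train_util.prepare_accelerator();
# instantiating the engine assumes a DeepSpeed-capable launch (accelerate launch).
accelerator = Accelerator(mixed_precision=args.mixed_precision, deepspeed_plugin=plugin)

unet = torch.nn.Linear(4, 4)  # toy stand-in for the real U-Net
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
optimizer = torch.optim.AdamW(ds_model.parameters(), lr=1e-4)
ds_model, optimizer = accelerator.prepare(ds_model, optimizer)
```

Note the ordering: because most ZeRO stages partition optimizer state, the wrapped model and the optimizer must go through accelerator.prepare in the same call, which is why every script's DeepSpeed branch prepares them together (see commit d945602 and the inline comment in sdxl_train.py below).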
1 change: 0 additions & 1 deletion library/sdxl_train_util.py
@@ -24,7 +24,6 @@


def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
11 changes: 8 additions & 3 deletions library/train_util.py
Expand Up @@ -69,6 +69,7 @@
import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging

setup_logging()
@@ -4095,6 +4096,10 @@ def load_tokenizer(args: argparse.Namespace):


def prepare_accelerator(args: argparse.Namespace):
"""
This function also prepares the DeepSpeed plugin.
"""

if args.logging_dir is None:
logging_dir = None
else:
@@ -4140,13 +4145,16 @@ def prepare_accelerator(args: argparse.Namespace):
),
)
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=log_with,
project_dir=logging_dir,
kwargs_handlers=kwargs_handlers,
dynamo_backend=dynamo_backend,
deepspeed_plugin=deepspeed_plugin,
)
print("accelerator device:", accelerator.device)
return accelerator
@@ -4217,7 +4225,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une


def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
# load models for each process
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
@@ -4228,7 +4235,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
accelerator.device if args.lowram else "cpu",
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
)

# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
Expand All @@ -4237,7 +4243,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio

clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()

return text_encoder, vae, unet, load_stable_diffusion_format


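A note on the load_target_model loop shown above (the deleted comment read "# load models for each process"): checkpoint loading is serialized across ranks with a barrier on every iteration, so at most one process per node is reading weights into host RAM at a time. The pattern in isolation, as a sketch with a hypothetical load_checkpoint stand-in:

```python
from accelerate import Accelerator

accelerator = Accelerator()

def load_checkpoint():
    # hypothetical stand-in for the real _load_target_model(...) call
    return object()

# Ranks take turns: only the matching process loads; everyone else waits
# at the barrier, so at most one rank per node is reading the checkpoint.
for pi in range(accelerator.state.num_processes):
    if pi == accelerator.state.local_process_index:
        model = load_checkpoint()
    accelerator.wait_for_everyone()
```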
38 changes: 29 additions & 9 deletions sdxl_train.py
@@ -11,11 +11,12 @@

import torch
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import sdxl_model_util
from library import deepspeed_utils, sdxl_model_util

import library.train_util as train_util

@@ -97,6 +98,7 @@ def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)

assert (
@@ -398,18 +400,33 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)

# accelerator apparently takes care of this for us
if train_unet:
unet = accelerator.prepare(unet)
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
if train_text_encoder1:
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)

optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
unet=unet if train_unet else None,
text_encoder1=text_encoder1 if train_text_encoder1 else None,
text_encoder2=text_encoder2 if train_text_encoder2 else None,
)
# most ZeRO stages use optimizer partitioning, so we have to prepare the optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]

else:
# accelerator apparently takes care of this for us
if train_unet:
unet = accelerator.prepare(unet)
if train_text_encoder1:
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

# move to CPU when caching the TextEncoder outputs
if args.cache_text_encoder_outputs:
@@ -424,6 +441,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
# During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine does.
# -> But we think it's ok to patch the accelerator even if DeepSpeed is enabled.
train_util.patch_accelerator_for_fp16_training(accelerator)

# resume training
@@ -744,6 +763,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
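Circling back to the full_fp16 comment patched above: in prepare_deepspeed_plugin, --full_fp16 (or --fp16_master_weights_and_gradients) only turns into DeepSpeed's fp16_master_weights_and_grads when the optimizer is offloaded to CPU under ZeRO stage 2, i.e. the DeepSpeedCPUAdam path; any other combination just logs a notice and ignores the flag. A tiny restatement of that gate for sanity-checking a config up front (a sketch mirroring the condition, not code from the PR):

```python
from typing import Optional

def full_fp16_supported(zero_stage: int, offload_optimizer_device: Optional[str]) -> bool:
    # Mirrors the gate in library/deepspeed_utils.py: fp16 master weights and
    # gradients need ZeRO-2 with the optimizer offloaded to CPU (DeepSpeedCPUAdam).
    return zero_stage == 2 and offload_optimizer_device == "cpu"

assert full_fp16_supported(2, "cpu")         # the only supported combination
assert not full_fp16_supported(3, "cpu")     # ZeRO-3: not supported
assert not full_fp16_supported(2, None)      # no CPU offload: not supported
```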