Merge pull request #1054 from akx/mps
Device support improvements (MPS)
kohya-ss authored Jan 31, 2024
2 parents 7f948db + 478156b commit 2ca4d0c
Showing 22 changed files with 91 additions and 75 deletions.
6 changes: 2 additions & 4 deletions fine_tune.py
@@ -2,7 +2,6 @@
 # XXX dropped option: hypernetwork training

 import argparse
-import gc
 import math
 import os
 from multiprocessing import Value
@@ -11,6 +10,7 @@
 from tqdm import tqdm
 import torch

+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex

 init_ipex()
@@ -158,9 +158,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()

         accelerator.wait_for_everyone()

3 changes: 2 additions & 1 deletion finetune/make_captions.py
@@ -15,8 +15,9 @@
 sys.path.append(os.path.dirname(__file__))
 from blip.blip import blip_decoder, is_url
 import library.train_util as train_util
+from library.device_utils import get_preferred_device

-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+DEVICE = get_preferred_device()


 IMAGE_SIZE = 384

4 changes: 2 additions & 2 deletions finetune/make_captions_by_git.py
@@ -10,9 +10,9 @@
 from transformers.generation.utils import GenerationMixin

 import library.train_util as train_util
+from library.device_utils import get_preferred_device


-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+DEVICE = get_preferred_device()

 PATTERN_REPLACE = [
     re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),

4 changes: 3 additions & 1 deletion finetune/prepare_buckets_latents.py
@@ -14,7 +14,9 @@
 import library.model_util as model_util
 import library.train_util as train_util

-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from library.device_utils import get_preferred_device
+
+DEVICE = get_preferred_device()

 IMAGE_TRANSFORMS = transforms.Compose(
     [

9 changes: 4 additions & 5 deletions gen_img_diffusers.py
@@ -66,6 +66,7 @@
 import numpy as np
 import torch

+from library.device_utils import clean_memory, get_preferred_device
 from library.ipex_interop import init_ipex

 init_ipex()
@@ -888,8 +889,7 @@ def __call__(
                 init_latent_dist = self.vae.encode(init_image).latent_dist
                 init_latents = init_latent_dist.sample(generator=generator)
             else:
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
+                clean_memory()
                 init_latents = []
                 for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
                     init_latent_dist = self.vae.encode(
@@ -1047,8 +1047,7 @@ def __call__(
         if vae_batch_size >= batch_size:
             image = self.vae.decode(latents).sample
         else:
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            clean_memory()
             images = []
             for i in tqdm(range(0, batch_size, vae_batch_size)):
                 images.append(
@@ -2325,7 +2324,7 @@ def __getattr__(self, item):
     scheduler.config.clip_sample = True

     # determine the device
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # does not take "mps" into account
+    device = get_preferred_device()

     # create the pipeline copied from the custom pipeline
     if args.vae_slices:

34 changes: 34 additions & 0 deletions library/device_utils.py
@@ -0,0 +1,34 @@
+import functools
+import gc
+
+import torch
+
+try:
+    HAS_CUDA = torch.cuda.is_available()
+except Exception:
+    HAS_CUDA = False
+
+try:
+    HAS_MPS = torch.backends.mps.is_available()
+except Exception:
+    HAS_MPS = False
+
+
+def clean_memory():
+    gc.collect()
+    if HAS_CUDA:
+        torch.cuda.empty_cache()
+    if HAS_MPS:
+        torch.mps.empty_cache()
+
+
+@functools.lru_cache(maxsize=None)
+def get_preferred_device() -> torch.device:
+    if HAS_CUDA:
+        device = torch.device("cuda")
+    elif HAS_MPS:
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+    print(f"get_preferred_device() -> {device}")
+    return device
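
For orientation, here is a minimal usage sketch (not part of the commit) combining the two new helpers as they appear in the diffs above: get_preferred_device() replaces the hard-coded "cuda"/"cpu" choice in the standalone scripts, and clean_memory() replaces the torch.cuda.empty_cache()/gc.collect() pair in the training scripts. The names vae, train_dataset_group, args, and accelerator stand in for the real objects in those scripts.

# Illustrative sketch only; mirrors the call sites changed in this PR.
import torch

from library.device_utils import clean_memory, get_preferred_device

# Standalone scripts (make_captions.py, prepare_buckets_latents.py, ...) pick
# their device this way instead of hard-coding "cuda"/"cpu":
DEVICE = get_preferred_device()  # cuda if available, else mps, else cpu


def cache_latents(vae, train_dataset_group, args, accelerator):
    # Training scripts (fine_tune.py, sdxl_train.py, ...) free memory after
    # caching latents; clean_memory() also clears the MPS cache on Apple silicon.
    with torch.no_grad():
        train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
    vae.to("cpu")
    clean_memory()  # replaces torch.cuda.empty_cache() + gc.collect()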
5 changes: 2 additions & 3 deletions library/sdxl_train_util.py
@@ -1,5 +1,4 @@
 import argparse
-import gc
 import math
 import os
 from typing import Optional
@@ -8,6 +7,7 @@
 from tqdm import tqdm
 from transformers import CLIPTokenizer
 from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
+from library.device_utils import clean_memory
 from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline

 TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
@@ -47,8 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
                 unet.to(accelerator.device)
                 vae.to(accelerator.device)

-            gc.collect()
-            torch.cuda.empty_cache()
+            clean_memory()
         accelerator.wait_for_everyone()

     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info

11 changes: 5 additions & 6 deletions library/train_util.py
@@ -20,7 +20,6 @@
     Union,
 )
 from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
-import gc
 import glob
 import math
 import os
@@ -67,6 +66,7 @@

 # from library.attention_processors import FlashAttnProcessor
 # from library.hypernetwork import replace_attentions_for_hypernetwork
+from library.device_utils import clean_memory
 from library.original_unet import UNet2DConditionModel

 # Tokenizer: use the one provided in advance rather than loading it from the checkpoint
@@ -2279,8 +2279,7 @@ def cache_batch_latents(
             info.latents_flipped = flipped_latent

     # FIXME this slows down caching a lot, specify this as an option
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    clean_memory()


 def cache_batch_text_encoder_outputs(
@@ -3920,6 +3919,7 @@ def prepare_accelerator(args: argparse.Namespace):
         kwargs_handlers=kwargs_handlers,
         dynamo_backend=dynamo_backend,
     )
+    print("accelerator device:", accelerator.device)
     return accelerator


@@ -4006,8 +4006,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
             unet.to(accelerator.device)
             vae.to(accelerator.device)

-            gc.collect()
-            torch.cuda.empty_cache()
+            clean_memory()
         accelerator.wait_for_everyone()

     return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4816,7 +4815,7 @@ def sample_images_common(

     # clear pipeline and cache to reduce vram usage
     del pipeline
-    torch.cuda.empty_cache()
+    clean_memory()

     torch.set_rng_state(rng_state)
     if cuda_rng_state is not None:

4 changes: 3 additions & 1 deletion networks/lora_diffusers.py
@@ -11,6 +11,8 @@
 from transformers import CLIPTextModel
 import torch

+from library.device_utils import get_preferred_device
+

 def make_unet_conversion_map() -> Dict[str, str]:
     unet_conversion_map_layer = []
@@ -476,7 +478,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
     from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
     import torch

-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = get_preferred_device()

     parser = argparse.ArgumentParser()
     parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")

3 changes: 2 additions & 1 deletion networks/lora_interrogator.py
@@ -9,11 +9,12 @@

 import library.model_util as model_util
 import lora
+from library.device_utils import get_preferred_device

 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
 V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from this model

-DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+DEVICE = get_preferred_device()


 def interrogate(args):

12 changes: 5 additions & 7 deletions sdxl_gen_img.py
@@ -18,6 +18,7 @@
 import numpy as np
 import torch

+from library.device_utils import clean_memory, get_preferred_device
 from library.ipex_interop import init_ipex

 init_ipex()
@@ -640,8 +641,7 @@ def __call__(
                 init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist
                 init_latents = init_latent_dist.sample(generator=generator)
             else:
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
+                clean_memory()
                 init_latents = []
                 for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
                     init_latent_dist = self.vae.encode(
@@ -780,8 +780,7 @@ def __call__(
         if vae_batch_size >= batch_size:
            image = self.vae.decode(latents.to(self.vae.dtype)).sample
         else:
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            clean_memory()
             images = []
             for i in tqdm(range(0, batch_size, vae_batch_size)):
                 images.append(
@@ -796,8 +795,7 @@ def __call__(
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()

-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clean_memory()

         if output_type == "pil":
             # image = self.numpy_to_pil(image)
@@ -1497,7 +1495,7 @@ def __getattr__(self, item):
     # scheduler.config.clip_sample = True

     # determine the device
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # does not take "mps" into account
+    device = get_preferred_device()

     # create the pipeline copied from the custom pipeline
     if args.vae_slices:

3 changes: 2 additions & 1 deletion sdxl_minimal_inference.py
@@ -10,6 +10,7 @@
 import numpy as np
 import torch

+from library.device_utils import get_preferred_device
 from library.ipex_interop import init_ipex

 init_ipex()
@@ -85,7 +86,7 @@ def get_timestep_embedding(x, outdim):
     guidance_scale = 7
     seed = None  # 1

-    DEVICE = "cuda"
+    DEVICE = get_preferred_device()
     DTYPE = torch.float16  # bfloat16 may work

     parser = argparse.ArgumentParser()

9 changes: 3 additions & 6 deletions sdxl_train.py
@@ -1,7 +1,6 @@
 # training with captions

 import argparse
-import gc
 import math
 import os
 from multiprocessing import Value
@@ -11,6 +10,7 @@
 from tqdm import tqdm
 import torch

+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex

 init_ipex()
@@ -252,9 +252,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()

         accelerator.wait_for_everyone()

@@ -407,8 +405,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
         text_encoder1.to("cpu", dtype=torch.float32)
         text_encoder2.to("cpu", dtype=torch.float32)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clean_memory()
     else:
         # make sure Text Encoders are on GPU
         text_encoder1.to(accelerator.device)

9 changes: 3 additions & 6 deletions sdxl_train_control_net_lllite.py
@@ -2,7 +2,6 @@
 # training code for ControlNet-LLLite with passing cond_image to U-Net's forward

 import argparse
-import gc
 import json
 import math
 import os
@@ -15,6 +14,7 @@
 from tqdm import tqdm
 import torch

+from library.device_utils import clean_memory
 from library.ipex_interop import init_ipex

 init_ipex()
@@ -164,9 +164,7 @@ def train(args):
             accelerator.is_main_process,
         )
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()

         accelerator.wait_for_everyone()

@@ -291,8 +289,7 @@ def train(args):
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
         text_encoder1.to("cpu", dtype=torch.float32)
         text_encoder2.to("cpu", dtype=torch.float32)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clean_memory()
     else:
         # make sure Text Encoders are on GPU
         text_encoder1.to(accelerator.device)

(Remaining changed files not shown.)
