Skip to content

Commit

Permalink
Refactor device determination to function; add MPS fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Jan 23, 2024
1 parent d676500 commit 32dca12
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 14 deletions.
3 changes: 2 additions & 1 deletion finetune/make_captions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
sys.path.append(os.path.dirname(__file__))
from blip.blip import blip_decoder, is_url
import library.train_util as train_util
from library.device_utils import get_preferred_device

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()


IMAGE_SIZE = 384
Expand Down
4 changes: 2 additions & 2 deletions finetune/make_captions_by_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from transformers.generation.utils import GenerationMixin

import library.train_util as train_util
from library.device_utils import get_preferred_device


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()

PATTERN_REPLACE = [
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
Expand Down
4 changes: 3 additions & 1 deletion finetune/prepare_buckets_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
import library.model_util as model_util
import library.train_util as train_util

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from library.device_utils import get_preferred_device

DEVICE = get_preferred_device()

IMAGE_TRANSFORMS = transforms.Compose(
[
Expand Down
4 changes: 2 additions & 2 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
import torch

from library.ipex_interop import init_ipex
from library.device_utils import clean_memory
from library.device_utils import clean_memory, get_preferred_device

init_ipex()

Expand Down Expand Up @@ -2324,7 +2324,7 @@ def __getattr__(self, item):
scheduler.config.clip_sample = True

# deviceを決定する
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない
device = get_preferred_device()

# custom pipelineをコピったやつを生成する
if args.vae_slices:
Expand Down
29 changes: 27 additions & 2 deletions library/device_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,34 @@
import functools
import gc

import torch

try:
HAS_CUDA = torch.cuda.is_available()
except Exception:
HAS_CUDA = False

try:
HAS_MPS = torch.backends.mps.is_available()
except Exception:
HAS_MPS = False


def clean_memory():
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
if HAS_CUDA:
torch.cuda.empty_cache()
if HAS_MPS:
torch.mps.empty_cache()


@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
if HAS_CUDA:
device = torch.device("cuda")
elif HAS_MPS:
device = torch.device("mps")
else:
device = torch.device("cpu")
print(f"get_preferred_device() -> {device}")
return device
4 changes: 3 additions & 1 deletion networks/lora_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from transformers import CLIPTextModel
import torch

from library.device_utils import get_preferred_device


def make_unet_conversion_map() -> Dict[str, str]:
unet_conversion_map_layer = []
Expand Down Expand Up @@ -476,7 +478,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = get_preferred_device()

parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
Expand Down
3 changes: 2 additions & 1 deletion networks/lora_interrogator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@

import library.model_util as model_util
import lora
from library.device_utils import get_preferred_device

TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE = get_preferred_device()


def interrogate(args):
Expand Down
4 changes: 2 additions & 2 deletions sdxl_gen_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import torch

from library.device_utils import clean_memory
from library.device_utils import clean_memory, get_preferred_device
from library.ipex_interop import init_ipex

init_ipex()
Expand Down Expand Up @@ -1495,7 +1495,7 @@ def __getattr__(self, item):
# scheduler.config.clip_sample = True

# deviceを決定する
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない
device = get_preferred_device()

# custom pipelineをコピったやつを生成する
if args.vae_slices:
Expand Down
3 changes: 2 additions & 1 deletion sdxl_minimal_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import torch

from library.device_utils import get_preferred_device
from library.ipex_interop import init_ipex

init_ipex()
Expand Down Expand Up @@ -85,7 +86,7 @@ def get_timestep_embedding(x, outdim):
guidance_scale = 7
seed = None # 1

DEVICE = "cuda"
DEVICE = get_preferred_device()
DTYPE = torch.float16 # bfloat16 may work

parser = argparse.ArgumentParser()
Expand Down
4 changes: 3 additions & 1 deletion tools/latent_upscaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from tqdm import tqdm
from PIL import Image

from library.device_utils import get_preferred_device


class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
Expand Down Expand Up @@ -255,7 +257,7 @@ def create_upscaler(**kwargs):

# another interface: upscale images with a model for given images from command line
def upscale_images(args: argparse.Namespace):
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()
us_dtype = torch.float16 # TODO: support fp32/bf16
os.makedirs(args.output_dir, exist_ok=True)

Expand Down

0 comments on commit 32dca12

Please sign in to comment.