Merge branch 'master' into feat/any_from_backend
melMass committed Jun 5, 2024
2 parents df03d4c + b1fd26f commit 3af37b1
Showing 76 changed files with 310 additions and 17,806 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -53,7 +53,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Ctrl + M | Mute/unmute selected nodes |
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Delete/Backspace | Delete the current graph |
| Ctrl + Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
| Ctrl/Shift + Click | Add clicked node to selection |
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
@@ -106,7 +106,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins

This is the command to install the nightly with ROCm 6.0 which might have some performance improvements:

```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.0```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.1```

### NVIDIA

@@ -116,7 +116,7 @@ Nvidia users should install stable pytorch using this command:

This is the command to install pytorch nightly instead which might have performance improvements:

```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124```

#### Troubleshooting

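The two nightly-wheel commands above differ only in the index URL, so a quick way to confirm which backend the installed build actually targets is to inspect `torch.version`. A minimal sketch, not part of the commit, assuming the wheel is installed in the active Python environment:

```python
import torch

# On a ROCm nightly torch.version.hip is set (e.g. "6.1..."), on a CUDA nightly
# torch.version.cuda is set (e.g. "12.4"); the other field is None.
print("torch:", torch.__version__)
print("HIP:", torch.version.hip)
print("CUDA:", torch.version.cuda)
print("GPU visible:", torch.cuda.is_available())
```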
1 change: 0 additions & 1 deletion comfy/ldm/cascade/stage_b.py
@@ -17,7 +17,6 @@
"""

import math
import numpy as np
import torch
from torch import nn
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
1 change: 0 additions & 1 deletion comfy/ldm/cascade/stage_c.py
@@ -18,7 +18,6 @@

import torch
from torch import nn
import numpy as np
import math
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
# from .controlnet import ControlNetDeliverer
2 changes: 0 additions & 2 deletions comfy/ldm/models/autoencoder.py
@@ -1,6 +1,4 @@
import torch
# import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

19 changes: 15 additions & 4 deletions comfy/ldm/modules/attention.py
@@ -3,7 +3,7 @@
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from typing import Optional
import logging

from .diffusionmodules.util import AlphaBlender, timestep_embedding
@@ -19,12 +19,13 @@
import comfy.ops
ops = comfy.ops.disable_weight_init

FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()

def get_attn_precision(attn_precision):
if args.dont_upcast_attention:
return None
if attn_precision is None and args.force_upcast_attention:
return torch.float32
if FORCE_UPCAST_ATTENTION_DTYPE is not None:
return FORCE_UPCAST_ATTENTION_DTYPE
return attn_precision

def exists(val):
@@ -313,9 +314,19 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None):
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape
dim_head //= heads

disabled_xformers = False

if BROKEN_XFORMERS:
if b * heads > 65535:
return attention_pytorch(q, k, v, heads, mask)
disabled_xformers = True

if not disabled_xformers:
if torch.jit.is_tracing() or torch.jit.is_scripting():
disabled_xformers = True

if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask)

q, k, v = map(
lambda t: t.reshape(b, -1, heads, dim_head),
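The change to `attention_xformers` folds its bail-out conditions into a single `disabled_xformers` flag before falling back to `attention_pytorch`. A minimal standalone sketch of that decision; the names mirror the diff, but this is an illustration rather than the module itself:

```python
import torch

def should_disable_xformers(b: int, heads: int, broken_xformers: bool) -> bool:
    # Broken xformers builds fail once batch * heads exceeds 65535.
    if broken_xformers and b * heads > 65535:
        return True
    # xformers ops are not traceable/scriptable, so fall back under torch.jit.
    if torch.jit.is_tracing() or torch.jit.is_scripting():
        return True
    return False

# When this returns True, the diff routes the call to attention_pytorch(q, k, v, heads, mask).
```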
31 changes: 28 additions & 3 deletions comfy/model_management.py
@@ -2,9 +2,9 @@
import logging
from enum import Enum
from comfy.cli_args import args
import comfy.utils
import torch
import sys
import platform

class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -629,8 +629,18 @@ def supports_dtype(device, dtype): #TODO
def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
return False
if directml_enabled:
return False
return True

def device_should_use_non_blocking(device):
if not device_supports_non_blocking(device):
return False
return False
# return True #TODO: figure out why this causes issues
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others


def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False
@@ -642,7 +652,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif is_intel_xpu():
device_supports_cast = True

non_blocking = device_supports_non_blocking(device)
non_blocking = device_should_use_non_blocking(device)

if device_supports_cast:
if copy:
@@ -683,8 +693,22 @@ def pytorch_attention_flash_attention():
#TODO: more reliable way of checking for flash attention?
if is_nvidia(): #pytorch flash attention only works on Nvidia
return True
if is_intel_xpu():
return True
return False

def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
try:
if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
upcast = True
except:
pass
if upcast:
return torch.float32
else:
return None

def get_free_memory(dev=None, torch_free_too=False):
global directml_enabled
if dev is None:
@@ -857,6 +881,7 @@ def unload_all_models():


def resolve_lowvram_weight(weight, model, key): #TODO: remove
print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
return weight

#TODO: might be cleaner to put this somewhere else
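The new `force_upcast_attention_dtype()` is evaluated once at import time in attention.py and also turns upcasting on for OSX Sonoma 14.5, where fp16 attention produces black images. A hedged, standalone restatement of that logic; the `force_flag` parameter stands in for `args.force_upcast_attention` and is not part of the real signature:

```python
import platform
import torch

def force_upcast_attention_dtype(force_flag: bool = False):
    # Upcast when explicitly requested, or on OSX Sonoma 14.5 (black image bug).
    upcast = force_flag
    try:
        if platform.mac_ver()[0] in ['14.5']:
            upcast = True
    except Exception:
        pass
    return torch.float32 if upcast else None
```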
78 changes: 53 additions & 25 deletions comfy/model_patcher.py
@@ -6,17 +6,29 @@

import comfy.utils
import comfy.model_management
from comfy.types import UnetWrapperFunction

def apply_weight_decompose(dora_scale, weight):

def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
weight.transpose(0, 1)
.reshape(weight.shape[1], -1)
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight.shape[1], *[1] * (weight.dim() - 1))
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)

return weight * (dora_scale / weight_norm).type(weight.dtype)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight


def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
@@ -64,9 +76,7 @@ def __init__(self, model, load_device, offload_device, size=0, current_device=No
def model_size(self):
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
self.size = comfy.model_management.module_size(self.model)
self.model_keys = set(model_sd.keys())
return self.size

def clone(self):
@@ -78,7 +88,6 @@ def clone(self):

n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
return n
@@ -117,7 +126,7 @@ def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_op
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True

def set_model_unet_function_wrapper(self, unet_wrapper_function):
def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction):
self.model_options["model_function_wrapper"] = unet_wrapper_function

def set_model_denoise_mask_function(self, denoise_mask_function):
@@ -198,8 +207,9 @@ def model_dtype(self):

def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = set()
model_sd = self.model.state_dict()
for k in patches:
if k in self.model_keys:
if k in model_sd:
p.add(k)
current_patches = self.patches.get(k, [])
current_patches.append((strength_patch, patches[k], strength_model))
@@ -326,7 +336,7 @@ def __call__(self, weight):

def calculate_weight(self, patches, weight, key):
for p in patches:
alpha = p[0]
strength = p[0]
v = p[1]
strength_model = p[2]

@@ -344,26 +354,31 @@ def calculate_weight(self, patches, weight, key):

if patch_type == "diff":
w1 = v[0]
if alpha != 0.0:
if strength != 0.0:
if w1.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
weight += strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
dora_scale = v[4]
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0

if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
else:
weight += ((strength * alpha) * lora_diff).type(weight.dtype)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
@@ -400,19 +415,26 @@ def calculate_weight(self, patches, weight, key):
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
alpha = v[2] / dim
else:
alpha = 1.0

try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
else:
weight += ((strength * alpha) * lora_diff).type(weight.dtype)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha *= v[2] / w1b.shape[0]
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0

w2a = v[3]
w2b = v[4]
dora_scale = v[7]
Expand All @@ -435,14 +457,18 @@ def calculate_weight(self, patches, weight, key):
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))

try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
else:
weight += ((strength * alpha) * lora_diff).type(weight.dtype)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha *= v[4] / v[0].shape[0]
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0

dora_scale = v[5]

@@ -452,9 +478,11 @@
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)

try:
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if dora_scale is not None:
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
else:
weight += ((strength * alpha) * lora_diff).type(weight.dtype)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
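`weight_decompose` replaces `apply_weight_decompose` and now receives the LoRA diff, alpha and patch strength directly: it renormalizes the patched weight against `dora_scale` and blends the result back into the original weight when `strength != 1.0`. A toy-sized sketch of the same math on a 2D weight, with shapes and values chosen purely for illustration:

```python
import torch

def weight_decompose_sketch(dora_scale, weight, lora_diff, alpha, strength):
    lora_diff = lora_diff * alpha
    weight_calc = weight + lora_diff.to(weight.dtype)
    # Per-column norm of the patched weight, reshaped to broadcast over it.
    weight_norm = (
        weight_calc.transpose(0, 1)
        .reshape(weight_calc.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
        .transpose(0, 1)
    )
    weight_calc = weight_calc * (dora_scale / weight_norm).to(weight.dtype)
    if strength != 1.0:
        # Blend the decomposed weight back into the original by `strength`.
        weight = weight + strength * (weight_calc - weight)
    else:
        weight = weight_calc
    return weight

w = torch.randn(4, 3)
diff = torch.randn(4, 3) * 0.01
print(weight_decompose_sketch(torch.ones(1, 3), w, diff, alpha=1.0, strength=0.5).shape)
```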
2 changes: 1 addition & 1 deletion comfy/ops.py
@@ -21,7 +21,7 @@

def cast_bias_weight(s, input):
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
if s.bias_function is not None:
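`cast_bias_weight` now consults `device_should_use_non_blocking` (which currently always returns False) instead of `device_supports_non_blocking`. The cast it guards looks roughly like the sketch below; this is a hedged illustration, not the module's code:

```python
import torch

def cast_like(param: torch.Tensor, reference: torch.Tensor, non_blocking: bool = False) -> torch.Tensor:
    # Move a weight/bias to the input tensor's device and dtype; non_blocking is
    # whatever device_should_use_non_blocking() decides for the target device.
    return param.to(device=reference.device, dtype=reference.dtype, non_blocking=non_blocking)
```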
1 change: 0 additions & 1 deletion comfy/sd.py
@@ -14,7 +14,6 @@
from . import clip_vision
from . import gligen
from . import diffusers_convert
from . import model_base
from . import model_detection

from . import sd1_clip
1 change: 0 additions & 1 deletion comfy/sd2_clip.py
@@ -1,5 +1,4 @@
from comfy import sd1_clip
import torch
import os

class SD2ClipHModel(sd1_clip.SDClipModel):