53 changes: 43 additions & 10 deletions comfy/model_patcher.py
@@ -123,16 +123,50 @@ def move_weight_functions(m, device):
     return memory
 
 class LowVramPatch:
-    def __init__(self, key, patches):
+    def __init__(self, key, patches, op):
         self.key = key
         self.patches = patches
+        convert_func = None
+        revert_func = None
+        op_keys = key.rsplit('.', 1)
+        if len(op_keys) >= 2:
+            try:
+                convert_func = getattr(op, "convert_{}".format(op_keys[1]))
+            except AttributeError:
+                pass
+            try:
+                revert_func = getattr(op, "revert_{}".format(op_keys[1]))
+            except AttributeError:
+                pass
+        self.convert_func = convert_func
+        self.revert_func = revert_func
 
     def __call__(self, weight):
         intermediate_dtype = weight.dtype
-        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
+        #intermediate_dtype has to be one that is supported in math ops
+        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]:
             intermediate_dtype = torch.float32
-            return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
 
-        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
+        final_dtype = weight.dtype
+
+        if weight.dtype != intermediate_dtype:
+            weight = weight.to(intermediate_dtype)
+        if self.convert_func is not None:
+            weight = self.convert_func(weight, inplace=True)
+        weight = comfy.lora.calculate_weight(
+            self.patches[self.key],
+            weight,
+            self.key,
+            intermediate_dtype=intermediate_dtype,
+        )
+        if self.revert_func is not None:
+            weight = self.revert_func(weight, inplace=True)
+        if weight.dtype != final_dtype:
+            weight = comfy.float.stochastic_rounding(
+                weight,
+                final_dtype,
+                seed=string_to_seed(self.key)
+            )
+        return weight
 
 def get_key_weight(model, key):
     set_func = None
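Reading note: the rewritten __call__ is a five-step pipeline. Upcast to a dtype that supports math ops, convert the stored weight into its computable form (via the op's convert_* method, when one exists), apply the patches, revert the conversion, then stochastically round back to the original dtype. A minimal runnable sketch of that flow, with toy lambdas standing in for the real convert/revert methods (lowvram_patch_flow and patch_fn are illustrative names, not part of this PR):

import torch

def lowvram_patch_flow(weight, patch_fn, convert_func=None, revert_func=None):
    # Mirrors LowVramPatch.__call__: pick an intermediate dtype that
    # supports math ops (fp8 storage dtypes, for example, do not).
    intermediate_dtype = weight.dtype
    if intermediate_dtype not in (torch.float32, torch.float16, torch.bfloat16):
        intermediate_dtype = torch.float32

    final_dtype = weight.dtype
    if weight.dtype != intermediate_dtype:
        weight = weight.to(intermediate_dtype)
    if convert_func is not None:
        weight = convert_func(weight)    # e.g. multiply a stored scale back in
    weight = patch_fn(weight)            # stands in for comfy.lora.calculate_weight
    if revert_func is not None:
        weight = revert_func(weight)     # undo the conversion
    if weight.dtype != final_dtype:
        weight = weight.to(final_dtype)  # the real code uses stochastic rounding here
    return weight

scale = torch.tensor(2.0)
w = torch.randn(4, 4, dtype=torch.bfloat16)
out = lowvram_patch_flow(w, lambda x: x + 0.1,
                         convert_func=lambda x: x * scale,
                         revert_func=lambda x: x / scale)
print(out.dtype)  # torch.bfloat16, same as the input weight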
@@ -657,13 +691,13 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
                             if force_patch_weights:
                                 self.patch_weight_to_device(weight_key)
                             else:
-                                m.weight_function = [LowVramPatch(weight_key, self.patches)]
+                                m.weight_function = [LowVramPatch(weight_key, self.patches, m)]
                                 patch_counter += 1
                         if bias_key in self.patches:
                             if force_patch_weights:
                                 self.patch_weight_to_device(bias_key)
                             else:
-                                m.bias_function = [LowVramPatch(bias_key, self.patches)]
+                                m.bias_function = [LowVramPatch(bias_key, self.patches, m)]
                                 patch_counter += 1
 
                         cast_weight = True
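The only functional change in load is the new third argument: the module m is handed to LowVramPatch as op, so the patch can look up convert_weight / revert_weight (or the _bias variants) on it by name. A hedged sketch of that lookup, using a stand-in op class (FakeScaledOp is invented for illustration; LowVramPatch.__init__ uses try/except AttributeError, which the three-argument getattr below is equivalent to):

class FakeScaledOp:
    # Stand-in for an op that stores scaled weights. Only the weight
    # converters are defined, so bias lookups fall back to None.
    def convert_weight(self, weight, inplace=False, **kwargs):
        return weight * 2.0

    def revert_weight(self, weight, inplace=False, **kwargs):
        return weight / 2.0

def lookup_funcs(op, key):
    # Same probing as LowVramPatch.__init__, for keys like "model.layer.weight".
    suffix = key.rsplit('.', 1)[-1]  # "weight" or "bias"
    convert = getattr(op, "convert_{}".format(suffix), None)
    revert = getattr(op, "revert_{}".format(suffix), None)
    return convert, revert

op = FakeScaledOp()
print(lookup_funcs(op, "model.layer.weight"))  # (bound method, bound method)
print(lookup_funcs(op, "model.layer.bias"))    # (None, None)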
@@ -825,10 +859,10 @@ def partially_unload(self, device_to, memory_to_free=0):
                 module_mem += move_weight_functions(m, device_to)
                 if lowvram_possible:
                     if weight_key in self.patches:
-                        m.weight_function.append(LowVramPatch(weight_key, self.patches))
+                        m.weight_function.append(LowVramPatch(weight_key, self.patches, m))
                         patch_counter += 1
                     if bias_key in self.patches:
-                        m.bias_function.append(LowVramPatch(bias_key, self.patches))
+                        m.bias_function.append(LowVramPatch(bias_key, self.patches, m))
                         patch_counter += 1
                     cast_weight = True

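partially_unload appends instead of assigning because weight_function and bias_function may already hold callables; each entry is assumed to take and return a tensor, applied in order when the weight is cast for compute. A minimal sketch of that chaining (apply_weight_functions is illustrative, not the actual helper in comfy.ops):

import torch

def apply_weight_functions(weight, weight_functions):
    # Assumed semantics of m.weight_function: a list of callables
    # applied left to right to the weight tensor.
    for fn in weight_functions:
        weight = fn(weight)
    return weight

w = torch.ones(2, 2)
chain = [lambda t: t * 3.0, lambda t: t + 1.0]
print(apply_weight_functions(w, chain))  # every element becomes 4.0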
@@ -1192,7 +1226,6 @@ def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_pat
             if used:
                 target_device = weight.device
             self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
-        # TODO: properly handle LowVramPatch, if it ends up an issue
         temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True)
         if convert_func is not None:
             temp_weight = convert_func(temp_weight, inplace=True)
7 changes: 7 additions & 0 deletions comfy/ops.py
@@ -415,6 +415,13 @@ def convert_weight(self, weight, inplace=False, **kwargs):
             return weight
         else:
             return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
 
+    def revert_weight(self, weight, inplace=False, **kwargs):
+        if inplace:
+            weight /= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+            return weight
+        else:
+            return weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+
     def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
         weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
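revert_weight is the exact inverse of the existing convert_weight: one multiplies by scale_weight, the other divides, so a convert, patch, revert round trip leaves an unpatched weight numerically unchanged up to floating-point error. A standalone sanity check of the inverse relationship (scale_weight here is a stand-in tensor, not the module buffer):

import torch

scale_weight = torch.tensor(0.5)

def convert_weight(weight):
    # As in the existing method: bring the stored weight into its real range.
    return weight * scale_weight.to(device=weight.device, dtype=weight.dtype)

def revert_weight(weight):
    # New in this PR: divide the scale back out.
    return weight / scale_weight.to(device=weight.device, dtype=weight.dtype)

w = torch.randn(3, 3)
assert torch.allclose(revert_weight(convert_weight(w)), w)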