diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 1fd03d9d16a1..79e697be4e49 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -123,16 +123,50 @@ def move_weight_functions(m, device):
     return memory
 
 class LowVramPatch:
-    def __init__(self, key, patches):
+    def __init__(self, key, patches, op):
         self.key = key
         self.patches = patches
+        convert_func = None
+        revert_func = None
+        op_keys = key.rsplit('.', 1)
+        if len(op_keys) >= 2:
+            try:
+                convert_func = getattr(op, "convert_{}".format(op_keys[1]))
+            except AttributeError:
+                pass
+            try:
+                revert_func = getattr(op, "revert_{}".format(op_keys[1]))
+            except AttributeError:
+                pass
+        self.convert_func = convert_func
+        self.revert_func = revert_func
+
     def __call__(self, weight):
         intermediate_dtype = weight.dtype
-        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
+        #intermediate_dtype has to be one that is supported in math ops
+        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]:
             intermediate_dtype = torch.float32
-            return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
-
-        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
+        final_dtype = weight.dtype
+
+        if weight.dtype != intermediate_dtype:
+            weight = weight.to(intermediate_dtype)
+        if self.convert_func is not None:
+            weight = self.convert_func(weight, inplace=True)
+        weight = comfy.lora.calculate_weight(
+            self.patches[self.key],
+            weight,
+            self.key,
+            intermediate_dtype=intermediate_dtype,
+        )
+        if self.revert_func is not None:
+            weight = self.revert_func(weight, inplace=True)
+        if weight.dtype != final_dtype:
+            weight = comfy.float.stochastic_rounding(
+                weight,
+                final_dtype,
+                seed=string_to_seed(self.key)
+            )
+        return weight
 
 def get_key_weight(model, key):
     set_func = None
@@ -657,13 +691,13 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
                         if force_patch_weights:
                             self.patch_weight_to_device(weight_key)
                         else:
-                            m.weight_function = [LowVramPatch(weight_key, self.patches)]
+                            m.weight_function = [LowVramPatch(weight_key, self.patches, m)]
                             patch_counter += 1
                     if bias_key in self.patches:
                         if force_patch_weights:
                             self.patch_weight_to_device(bias_key)
                         else:
-                            m.bias_function = [LowVramPatch(bias_key, self.patches)]
+                            m.bias_function = [LowVramPatch(bias_key, self.patches, m)]
                             patch_counter += 1
 
                     cast_weight = True
@@ -825,10 +859,10 @@ def partially_unload(self, device_to, memory_to_free=0):
                     module_mem += move_weight_functions(m, device_to)
                     if lowvram_possible:
                         if weight_key in self.patches:
-                            m.weight_function.append(LowVramPatch(weight_key, self.patches))
+                            m.weight_function.append(LowVramPatch(weight_key, self.patches, m))
                             patch_counter += 1
                         if bias_key in self.patches:
-                            m.bias_function.append(LowVramPatch(bias_key, self.patches))
+                            m.bias_function.append(LowVramPatch(bias_key, self.patches, m))
                             patch_counter += 1
                         cast_weight = True
 
@@ -1192,7 +1226,6 @@ def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_pat
             if used:
                 target_device = weight.device
             self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device)
-        # TODO: properly handle LowVramPatch, if it ends up an issue
         temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True)
         if convert_func is not None:
             temp_weight = convert_func(temp_weight, inplace=True)
diff --git a/comfy/ops.py b/comfy/ops.py
index 9d7dedd374b6..fca32a4e2c90 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -415,6 +415,13 @@ def convert_weight(self, weight, inplace=False, **kwargs):
                     return weight
                 else:
                     return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+
+            def revert_weight(self, weight, inplace=False, **kwargs):
+                if inplace:
+                    weight /= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+                    return weight
+                else:
+                    return weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype)
 
             def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
                 weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
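
Why the convert/revert pair matters: scaled ops store the weight pre-divided by scale_weight, so the LoRA math has to run in the converted ("math") space and the result has to be mapped back to storage space before stochastic rounding. Below is a minimal standalone sketch of the round trip LowVramPatch.__call__ now performs; ScaledOp is a hypothetical stand-in for the real op classes, and the `+= 0.1` stands in for comfy.lora.calculate_weight:

import torch

class ScaledOp:
    # Hypothetical stand-in: the stored weight is the "real" weight
    # divided by scale_weight, as with the scaled fp8 ops.
    def __init__(self, scale):
        self.scale_weight = torch.tensor(scale)

    def convert_weight(self, weight, inplace=False, **kwargs):
        # storage space -> math space
        if inplace:
            weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
            return weight
        return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)

    def revert_weight(self, weight, inplace=False, **kwargs):
        # math space -> storage space (exact inverse of convert_weight)
        if inplace:
            weight /= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
            return weight
        return weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype)

op = ScaledOp(0.5)
stored = torch.randn(4, 4)

w = op.convert_weight(stored.clone(), inplace=True)  # into math space
w += 0.1                                             # stand-in for the LoRA patch
w = op.revert_weight(w, inplace=True)                # back into storage space

# Skipping the revert step would write the patched weight back still
# multiplied by scale_weight, i.e. it would be scaled twice at cast time.

(The real __call__ additionally runs the math in intermediate_dtype and stochastic-rounds back to final_dtype; the sketch elides that.)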
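
For reference, the suffix dispatch in LowVramPatch.__init__ is what lets plain ops keep the old behavior for free. Illustrative sketch only: the key is made up, a plain nn.Linear stands in for the module m passed as op, and getattr's default argument is used as an equivalent of the try/except AttributeError in the actual change:

import torch

op = torch.nn.Linear(4, 4)   # stands in for the module m that owns the weight
op_key = "diffusion_model.blocks.0.attn.qkv.weight"   # hypothetical patch key
suffix = op_key.rsplit('.', 1)[1]                     # -> "weight" (or "bias")
convert_func = getattr(op, "convert_" + suffix, None) # e.g. convert_weight, else None
revert_func = getattr(op, "revert_" + suffix, None)   # e.g. revert_weight, else None
# A plain Linear defines neither method, so both stay None and __call__
# reduces to the old behavior: calculate_weight plus stochastic rounding.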