@@ -843,7 +843,7 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
843843
844844 self .object_patches_backup .clear ()
845845
846- def partially_unload (self , device_to , memory_to_free = 0 ):
846+ def partially_unload (self , device_to , memory_to_free = 0 , force_patch_weights = False ):
847847 with self .use_ejected ():
848848 hooks_unpatched = False
849849 memory_freed = 0
@@ -887,13 +887,19 @@ def partially_unload(self, device_to, memory_to_free=0):
887887 module_mem += move_weight_functions (m , device_to )
888888 if lowvram_possible :
889889 if weight_key in self .patches :
890- _ , set_func , convert_func = get_key_weight (self .model , weight_key )
891- m .weight_function .append (LowVramPatch (weight_key , self .patches , convert_func , set_func ))
892- patch_counter += 1
890+ if force_patch_weights :
891+ self .patch_weight_to_device (weight_key )
892+ else :
893+ _ , set_func , convert_func = get_key_weight (self .model , weight_key )
894+ m .weight_function .append (LowVramPatch (weight_key , self .patches , convert_func , set_func ))
895+ patch_counter += 1
893896 if bias_key in self .patches :
894- _ , set_func , convert_func = get_key_weight (self .model , bias_key )
895- m .bias_function .append (LowVramPatch (bias_key , self .patches , convert_func , set_func ))
896- patch_counter += 1
897+ if force_patch_weights :
898+ self .patch_weight_to_device (bias_key )
899+ else :
900+ _ , set_func , convert_func = get_key_weight (self .model , bias_key )
901+ m .bias_function .append (LowVramPatch (bias_key , self .patches , convert_func , set_func ))
902+ patch_counter += 1
897903 cast_weight = True
898904
899905 if cast_weight :
0 commit comments