diff --git a/comfy/cli_args.py b/comfy/cli_args.py index cc1f12482e9f..001abd8432f9 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -144,6 +144,7 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" + PinnedMem = "pinned_memory" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) diff --git a/comfy/model_management.py b/comfy/model_management.py index afe78f36ec0e..3e5b977d4375 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1080,6 +1080,36 @@ def cast_to_device(tensor, device, dtype, copy=False): non_blocking = device_supports_non_blocking(device) return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) +def pin_memory(tensor): + if PerformanceFeature.PinnedMem not in args.fast: + return False + + if not is_nvidia(): + return False + + if not is_device_cpu(tensor.device): + return False + + if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0: + return True + + return False + +def unpin_memory(tensor): + if PerformanceFeature.PinnedMem not in args.fast: + return False + + if not is_nvidia(): + return False + + if not is_device_cpu(tensor.device): + return False + + if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0: + return True + + return False + def sage_attention_enabled(): return args.use_sage_attention diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c0b68fb8cff7..aec73349c8e5 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -238,6 +238,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up self.force_cast_weights = False self.patches_uuid = uuid.uuid4() self.parent = None + self.pinned = set() self.attachments: dict[str] = {} self.additional_models: dict[str, list[ModelPatcher]] = {} @@ -618,6 +619,21 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False): else: set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + def pin_weight_to_device(self, key): + weight, set_func, convert_func = get_key_weight(self.model, key) + if comfy.model_management.pin_memory(weight): + self.pinned.add(key) + + def unpin_weight(self, key): + if key in self.pinned: + weight, set_func, convert_func = get_key_weight(self.model, key) + comfy.model_management.unpin_memory(weight) + self.pinned.remove(key) + + def unpin_all_weights(self): + for key in list(self.pinned): + self.unpin_weight(key) + def _load_list(self): loading = [] for n, m in self.model.named_modules(): @@ -683,6 +699,8 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False patch_counter += 1 cast_weight = True + for param in params: + self.pin_weight_to_device("{}.{}".format(n, param)) else: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) @@ -713,7 +731,9 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False continue for param in params: - self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) + key = "{}.{}".format(n, param) + self.unpin_weight(key) + self.patch_weight_to_device(key, device_to=device_to) logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) m.comfy_patched_weights = True @@ -762,6 +782,7 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): self.eject_model() if unpatch_weights: self.unpatch_hooks() + self.unpin_all_weights() if self.model.model_lowvram: for m in self.model.modules(): move_weight_functions(m, device_to) @@ -857,6 +878,9 @@ def partially_unload(self, device_to, memory_to_free=0): memory_freed += module_mem logging.debug("freed {}".format(n)) + for param in params: + self.pin_weight_to_device("{}.{}".format(n, param)) + self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed