diff --git a/nodes.py b/nodes.py index 08849df..ff5aaf0 100644 --- a/nodes.py +++ b/nodes.py @@ -79,13 +79,14 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): def pin_weight_to_device(self, key): op_key = key.rsplit('.', 1)[0] - if self.named_modules_to_munmap is not None and op_key in self.named_modules_to_munmap: + if not self.mmap_released and op_key in self.named_modules_to_munmap: # TODO: possible to OOM, find better way to detach self.named_modules_to_munmap[op_key].to(self.load_device).to(self.offload_device) del self.named_modules_to_munmap[op_key] super().pin_weight_to_device(key) mmap_released = False + named_modules_to_munmap = {} def load(self, *args, force_patch_weights=False, **kwargs): if not self.mmap_released: @@ -115,7 +116,7 @@ def load(self, *args, force_patch_weights=False, **kwargs): # TODO: possible to OOM, find better way to detach m.to(self.load_device).to(self.offload_device) self.mmap_released = True - self.named_modules_to_munmap = None + self.named_modules_to_munmap = {} def clone(self, *args, **kwargs): src_cls = self.__class__ @@ -125,6 +126,7 @@ def clone(self, *args, **kwargs): self.__class__ = src_cls # GGUF specific clone values below n.patch_on_device = getattr(self, "patch_on_device", False) + n.mmap_released = getattr(self, "mmap_released", False) if src_cls != GGUFModelPatcher: n.size = 0 # force recalc return n