RAM optimization round 2 #12599

Merged 5 commits on Aug 19, 2023
Changes from all commits
extensions-builtin/Lora/networks.py (4 additions, 1 deletion)
@@ -304,7 +304,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
 
     weights_backup = getattr(self, "network_weights_backup", None)
-    if weights_backup is None:
+    if weights_backup is None and wanted_names != ():
+        if current_names != ():
+            raise RuntimeError("no backup weights found and current weights are not unchanged")
+
         if isinstance(self, torch.nn.MultiheadAttention):
             weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
         else:
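The change above makes network_apply_weights skip the CPU backup of a layer's original weights unless a network is actually about to be applied, and raise if the weights were already modified while no backup exists. A minimal standalone sketch of that guard for a single Linear layer; ensure_backup and its arguments are illustrative names, not part of the extension's API:

```python
import torch

def ensure_backup(module: torch.nn.Linear, wanted_names: tuple, current_names: tuple):
    """Back up original weights to CPU only when a network will actually be applied."""
    weights_backup = getattr(module, "network_weights_backup", None)

    if weights_backup is None and wanted_names != ():
        if current_names != ():
            # weights were already altered by some network but never backed up,
            # so the originals could not be restored later
            raise RuntimeError("no backup weights found and current weights are not unchanged")

        # copy=True keeps the backup independent of the live (possibly GPU) tensor
        module.network_weights_backup = module.weight.to("cpu", copy=True)

    return getattr(module, "network_weights_backup", None)
```

Skipping the backup when wanted_names is empty appears to be where the RAM saving comes from: layers that never have a network applied no longer keep a second CPU copy of their weights.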
modules/sd_disable_initialization.py (53 additions, 10 deletions)
@@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper):
     ```
     """
 
-    def __init__(self, state_dict, device):
+    def __init__(self, state_dict, device, weight_dtype_conversion=None):
         super().__init__()
         self.state_dict = state_dict
         self.device = device
+        self.weight_dtype_conversion = weight_dtype_conversion or {}
+        self.default_dtype = self.weight_dtype_conversion.get('')
+
+    def get_weight_dtype(self, key):
+        key_first_term, _ = key.split('.', 1)
+        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
 
     def __enter__(self):
         if shared.cmd_opts.disable_model_loading_ram_optimization:
@@ -167,23 +173,60 @@ def __enter__(self):
         sd = self.state_dict
         device = self.device
 
-        def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
-            params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
+        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
+            used_param_keys = []
 
-            for name, param in params:
-                if param.is_meta:
-                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
+            for name, param in module._parameters.items():
+                if param is None:
+                    continue
 
-            original(self, state_dict, prefix, *args, **kwargs)
+                key = prefix + name
+                sd_param = sd.pop(key, None)
+                if sd_param is not None:
+                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
+                    used_param_keys.append(key)
 
-            for name, _ in params:
+                if param.is_meta:
+                    dtype = sd_param.dtype if sd_param is not None else param.dtype
+                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
+
+            for name in module._buffers:
                 key = prefix + name
-                if key in sd:
-                    del sd[key]
 
+                sd_param = sd.pop(key, None)
+                if sd_param is not None:
+                    state_dict[key] = sd_param
+                    used_param_keys.append(key)
+
+            original(module, state_dict, prefix, *args, **kwargs)
+
+            for key in used_param_keys:
+                state_dict.pop(key, None)
+
+        def load_state_dict(original, module, state_dict, strict=True):
+            """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
+            because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
+            all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
+
+            In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
+
+            The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
+            the function and does not call the original) the state dict will just fail to load because weights
+            would be on the meta device.
+            """
+
+            if state_dict == sd:
+                state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
+
+            original(module, state_dict, strict=strict)
+
+        module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
+        module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
         linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
         conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
         mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
+        layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
+        group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.restore()
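The docstring inside load_state_dict above is the heart of the optimization. Here is a self-contained sketch of the same trick on a bare torch.nn.Linear, using a plain monkeypatch instead of the ReplaceHelper machinery from this file; real_sd, meta_sd and patched_load are illustrative names only:

```python
import torch

# the single dict that actually holds the weights
real_sd = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

# what torch is allowed to copy around: same keys, but tensors on the meta device (no storage)
meta_sd = {k: v.to(device="meta") for k, v in real_sd.items()}

layer = torch.nn.Linear(4, 4)
original_load = torch.nn.Module._load_from_state_dict

def patched_load(self, state_dict, prefix, *args, **kwargs):
    # swap the meta placeholders for the real tensors, one module at a time
    for name in list(self._parameters):
        key = prefix + name
        if key in real_sd:
            state_dict[key] = real_sd.pop(key)
    original_load(self, state_dict, prefix, *args, **kwargs)

torch.nn.Module._load_from_state_dict = patched_load
try:
    # however many internal copies of meta_sd torch makes, they cost no RAM
    layer.load_state_dict(meta_sd)
finally:
    torch.nn.Module._load_from_state_dict = original_load
```

The failure mode the docstring warns about applies to this sketch too: if a module overrides _load_from_state_dict without calling the original, the meta placeholders are never swapped out and loading fails.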
modules/sd_models.py (20 additions, 2 deletions)
@@ -343,7 +343,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         model.to(memory_format=torch.channels_last)
         timer.record("apply channels_last")
 
-    if not shared.cmd_opts.no_half:
+    if shared.cmd_opts.no_half:
+        model.float()
+        timer.record("apply float()")
+    else:
         vae = model.first_stage_model
         depth_model = getattr(model, 'depth_model', None)

@@ -518,6 +521,13 @@ def send_model_to_cpu(m):
     devices.torch_gc()
 
 
+def model_target_device():
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        return devices.cpu
+    else:
+        return devices.device
+
+
 def send_model_to_device(m):
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
@@ -579,7 +589,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     timer.record("create model")
 
-    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+    if shared.cmd_opts.no_half:
+        weight_dtype_conversion = None
+    else:
+        weight_dtype_conversion = {
+            'first_stage_model': None,
+            '': torch.float16,
+        }
+
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
         timer.record("load weights from state dict")

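The last hunk ties the pieces together: unless --no-half is set, load_model asks LoadStateDictOnMeta to convert checkpoint tensors to float16 as they are read, while keys under first_stage_model (the VAE) keep their stored dtype; get_weight_dtype picks the target dtype from the part of each key before the first dot. A small sketch of that lookup with illustrative keys (the helper mirrors get_weight_dtype but is written standalone):

```python
import torch

weight_dtype_conversion = {
    'first_stage_model': None,   # VAE weights: keep whatever dtype the checkpoint stores
    '': torch.float16,           # default for everything else
}
default_dtype = weight_dtype_conversion.get('')

def get_weight_dtype(key):
    # same rule as the diff: dispatch on the first dot-separated component of the key
    key_first_term, _ = key.split('.', 1)
    return weight_dtype_conversion.get(key_first_term, default_dtype)

print(get_weight_dtype("model.diffusion_model.input_blocks.0.0.weight"))  # torch.float16
print(get_weight_dtype("first_stage_model.encoder.conv_in.weight"))       # None

# inside load_from_state_dict each tensor is converted on the fly;
# .to(dtype=None) is a no-op, so first_stage_model tensors pass through unchanged
checkpoint_tensor = torch.randn(8, 8)   # stand-in for a float32 state-dict entry
assert checkpoint_tensor.to(dtype=get_weight_dtype("model.x.weight")).dtype == torch.float16
```

Converting while loading means the float32 checkpoint never has to sit in RAM alongside a half-precision copy of the model, which appears to be the point of passing weight_dtype_conversion instead of halving the model after the fact.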