Fix SD15 memory sharing
huchenlei committed Jan 1, 2025
1 parent ab409ee · commit 422ae99
Showing 2 changed files with 14 additions and 24 deletions.
layered_diffusion.py: 1 addition & 3 deletions
@@ -286,11 +286,9 @@ def apply_layered_diffusion_attn_sharing(
         layer_lora_state_dict = load_layer_model_state_dict(model_path)
         work_model = model.clone()
         patcher = AttentionSharingPatcher(
-            work_model, self.frames, use_control=control_img is not None
+            work_model, self.frames, control_img=control_img
         )
         patcher.load_state_dict(layer_lora_state_dict, strict=True)
-        if control_img is not None:
-            patcher.set_control(control_img)
         return (work_model,)


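With this change the node hands the control image straight to the AttentionSharingPatcher constructor and drops the follow-up set_control call. A minimal, hypothetical caller-side sketch of the new pattern, reusing the names from the diff above and assuming model, frames, control_img, and layer_lora_state_dict are already in scope:

# Hedged sketch of the updated calling convention; not code from the commit.
work_model = model.clone()
patcher = AttentionSharingPatcher(
    work_model, frames, control_img=control_img  # control_img may be None
)
patcher.load_state_dict(layer_lora_state_dict, strict=True)
# No patcher.set_control(...) step: when control_img is given, the constructor
# already encodes it and hands the resulting signals to every patched unit.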
lib_layerdiffusion/attention_sharing.py: 13 additions & 21 deletions
@@ -76,7 +76,7 @@ class AttentionSharingUnit(torch.nn.Module):
     # call.
     transformer_options: dict = {}

-    def __init__(self, module, frames=2, use_control=True, rank=256):
+    def __init__(self, module, frames=2, control_signals=None, rank=256):
         super().__init__()

         self.heads = module.heads
@@ -142,9 +142,9 @@ def __init__(self, module, frames=2, use_control=True, rank=256):
             in_features=hidden_size, out_features=hidden_size
         )

+        self.control_signals = control_signals
         self.control_convs = None
-
-        if use_control:
+        if control_signals is not None:
             self.control_convs = [
                 torch.nn.Sequential(
                     torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
@@ -155,7 +155,6 @@ def __init__(self, module, frames=2, use_control=True, rank=256):
             ]
             self.control_convs = torch.nn.ModuleList(self.control_convs)

-        self.control_signals = None

     def forward(self, h, context=None, value=None):
         transformer_options = self.transformer_options
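At the unit level the same idea applies: instead of a use_control flag plus a control_signals attribute that AttentionSharingPatcher.set_control filled in later, the precomputed signals (or None) now arrive through the constructor, and their presence decides whether the control conv stack is built at all. A condensed, hypothetical sketch of just that part of the constructor (the real AttentionSharingUnit also sets up its attention and LoRA projections, and its Sequential stack contains more than the single conv shown here):

import torch

class AttentionSharingUnitSketch(torch.nn.Module):
    def __init__(self, frames=2, control_signals=None):
        super().__init__()
        # Signals are captured at construction time; previously the unit stored
        # None and relied on a later set_control call to fill this attribute.
        self.control_signals = control_signals
        self.control_convs = None
        if control_signals is not None:
            self.control_convs = torch.nn.ModuleList(
                [
                    torch.nn.Sequential(
                        torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1)
                    )
                    for _ in range(frames - 1)
                ]
            )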
@@ -325,36 +324,29 @@ def __init__(self, layer_list):


 class AttentionSharingPatcher(torch.nn.Module):
-    def __init__(self, unet, frames=2, use_control=True, rank=256):
+    def __init__(self, unet, frames=2, control_img=None, rank=256):
         super().__init__()
         model_management.unload_model_clones(unet)
+        control_signals = (
+            AdditionalAttentionCondsEncoder()(control_img.cpu().float() * 2.0 - 1.0)
+            if control_img is not None
+            else None
+        )

         units = []
         for i in range(32):
             real_key = module_mapping_sd15[i]
             attn_module = utils.get_attr(unet.model.diffusion_model, real_key)
             u = AttentionSharingUnit(
-                attn_module, frames=frames, use_control=use_control, rank=rank
+                attn_module,
+                frames=frames,
+                control_signals=control_signals,
+                rank=rank,
             )
             units.append(u)
             unet.add_object_patch("diffusion_model." + real_key, u)

         self.hookers = HookerLayers(units)

-        if use_control:
-            self.kwargs_encoder = AdditionalAttentionCondsEncoder()
-        else:
-            self.kwargs_encoder = None
-
         self.dtype = torch.float32
         if model_management.should_use_fp16(model_management.get_torch_device()):
             self.dtype = torch.float16
             self.hookers.half()
         return
-
-    def set_control(self, img):
-        img = img.cpu().float() * 2.0 - 1.0
-        signals = self.kwargs_encoder(img)
-        for m in self.hookers.layers:
-            m.control_signals = signals
-        return
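
After this change the control image is encoded exactly once, inside the patcher constructor, and the same signal object is handed to all 32 patched attention units; the patcher no longer keeps an encoder of its own or exposes a set_control step. A self-contained sketch of that compute-once, pass-at-construction pattern (TinyCondsEncoder and TinyUnit are stand-ins for illustration, not the repository's classes):

import torch

class TinyCondsEncoder(torch.nn.Module):
    # Stand-in for AdditionalAttentionCondsEncoder: image -> control signals.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, img):
        return self.conv(img)

class TinyUnit(torch.nn.Module):
    # Stand-in for AttentionSharingUnit: only keeps the shared signals.
    def __init__(self, control_signals=None):
        super().__init__()
        self.control_signals = control_signals  # shared reference, not a copy

def build_units(control_img=None, n_units=32):
    # Encode once (as the new __init__ does), then hand the same object to
    # every unit; no encoder is kept around and no later set_control is needed.
    control_signals = (
        TinyCondsEncoder()(control_img.cpu().float() * 2.0 - 1.0)
        if control_img is not None
        else None
    )
    return [TinyUnit(control_signals) for _ in range(n_units)]

units = build_units(control_img=torch.rand(1, 3, 64, 64))
assert all(u.control_signals is units[0].control_signals for u in units)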
