
Commit 27e067c

Implement the USO subject identity LoRA. (#9674)
Use the LoRA with the FluxKontextMultiReferenceLatentMethod node set to "uso" and a ReferenceLatent node carrying the reference image.
1 parent 9b15155 commit 27e067c

4 files changed, +32 −3 lines

comfy/ldm/flux/model.py

Lines changed: 8 additions & 2 deletions
@@ -233,12 +233,18 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None
             h = 0
             w = 0
             index = 0
-            index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
+            ref_latents_method = kwargs.get("ref_latents_method", "offset")
             for ref in ref_latents:
-                if index_ref_method:
+                if ref_latents_method == "index":
                     index += 1
                     h_offset = 0
                     w_offset = 0
+                elif ref_latents_method == "uso":
+                    index = 0
+                    h_offset = h_len * patch_size + h
+                    w_offset = w_len * patch_size + w
+                    h += ref.shape[-2]
+                    w += ref.shape[-1]
                 else:
                     index = 1
                     h_offset = 0

comfy/lora.py

Lines changed: 4 additions & 0 deletions
@@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}):
                 key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
                 key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
                 key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
+        for k in sdk:
+            hidden_size = model.model_config.unet_config.get("hidden_size", 0)
+            if k.endswith(".weight") and ".linear1." in k:
+                key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3))
 
     if isinstance(model, comfy.model_base.GenmoMochi):
         for k in sdk:
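In Flux single blocks, linear1 is a fused projection whose first hidden_size * 3 output rows hold the q/k/v weights, with the remaining rows feeding the MLP. The (k, (0, 0, hidden_size * 3)) value looks like a (dim, start, length) slice spec, telling the loader to apply a linear1_qkv LoRA to only those rows. A toy torch sketch of the effect, under that assumption:

```python
# Sketch of patching only the qkv slice of a fused linear1 weight, assuming
# (0, 0, hidden_size * 3) reads as (dim, start, length). Toy sizes, not Flux's.
import torch

hidden_size = 64                               # 3072 in Flux dev
mlp_hidden = hidden_size * 4
linear1 = torch.nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden)

rank = 4
lora_down = torch.randn(rank, hidden_size)     # a "...linear1_qkv" LoRA pair
lora_up = torch.randn(hidden_size * 3, rank)

with torch.no_grad():
    delta = lora_up @ lora_down                    # (hidden_size * 3, hidden_size)
    linear1.weight[:hidden_size * 3, :] += delta   # qkv rows only; MLP untouched
```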

comfy/lora_convert.py

Lines changed: 19 additions & 0 deletions
@@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
 def convert_lora_wan_fun(sd): #Wan Fun loras
     return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
 
+def convert_uso_lora(sd):
+    sd_out = {}
+    for k in sd:
+        tensor = sd[k]
+        k_to = "diffusion_model.{}".format(k.replace(".down.weight", ".lora_down.weight")
+                                            .replace(".up.weight", ".lora_up.weight")
+                                            .replace(".qkv_lora2.", ".txt_attn.qkv.")
+                                            .replace(".qkv_lora1.", ".img_attn.qkv.")
+                                            .replace(".proj_lora1.", ".img_attn.proj.")
+                                            .replace(".proj_lora2.", ".txt_attn.proj.")
+                                            .replace(".qkv_lora.", ".linear1_qkv.")
+                                            .replace(".proj_lora.", ".linear2.")
+                                            .replace(".processor.", ".")
+                                            )
+        sd_out[k_to] = tensor
+    return sd_out
+
 
 def convert_lora(sd):
     if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
         return convert_lora_bfl_control(sd)
     if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
         return convert_lora_wan_fun(sd)
+    if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
+        return convert_uso_lora(sd)
     return sd
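The fingerprint keys checked in convert_lora correspond to the last single block (37) and last double block (18) of Flux dev, so a USO state dict is detected without extra metadata. A quick round-trip of the renaming on two toy entries (the import assumes a ComfyUI checkout on the path; tensor shapes are placeholders):

```python
import torch
from comfy.lora_convert import convert_lora    # assumes a ComfyUI checkout

uso_sd = {
    "single_blocks.37.processor.qkv_lora.up.weight": torch.zeros(1),
    "double_blocks.18.processor.qkv_lora2.up.weight": torch.zeros(1),
}
print(sorted(convert_lora(uso_sd)))            # fingerprint matches, so it converts
# ['diffusion_model.double_blocks.18.txt_attn.qkv.lora_up.weight',
#  'diffusion_model.single_blocks.37.linear1_qkv.lora_up.weight']
```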

comfy_extras/nodes_flux.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ class FluxKontextMultiReferenceLatentMethod:
     def INPUT_TYPES(s):
         return {"required": {
             "conditioning": ("CONDITIONING", ),
-            "reference_latents_method": (("offset", "index"), ),
+            "reference_latents_method": (("offset", "index", "uso"), ),
             }}
 
     RETURN_TYPES = ("CONDITIONING",)
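The node itself only records the user's choice. Below is a plausible sketch of how the selection could be written into the conditioning (not the node's verbatim body, though node_helpers.conditioning_set_values is a real ComfyUI helper); the model side then reads it back as the ref_latents_method kwarg seen in the first diff, with the key rename presumably handled by the model's conditioning plumbing:

```python
# Plausible sketch, not the node's verbatim code.
import node_helpers

class FluxKontextMultiReferenceLatentMethod:
    # ... INPUT_TYPES as in the diff above ...
    def append(self, conditioning, reference_latents_method):
        # conditioning_set_values writes key/value pairs into every
        # conditioning entry; the method name "append" is an assumption.
        c = node_helpers.conditioning_set_values(
            conditioning, {"reference_latents_method": reference_latents_method})
        return (c,)
```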
