-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Put the model file in models/photomaker and use PhotoMakerLoader. Then use PhotoMakerEncode with the keyword "photomaker" to apply the image
- Loading branch information
1 parent
b9911dc
commit d1533d9
Showing
4 changed files
with
190 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
import torch | ||
import torch.nn as nn | ||
import folder_paths | ||
import comfy.clip_model | ||
import comfy.clip_vision | ||
import comfy.ops | ||
|
||
# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0 | ||
VISION_CONFIG_DICT = { | ||
"hidden_size": 1024, | ||
"image_size": 224, | ||
"intermediate_size": 4096, | ||
"num_attention_heads": 16, | ||
"num_channels": 3, | ||
"num_hidden_layers": 24, | ||
"patch_size": 14, | ||
"projection_dim": 768, | ||
"hidden_act": "quick_gelu", | ||
} | ||
|
||
class MLP(nn.Module): | ||
def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=comfy.ops): | ||
super().__init__() | ||
if use_residual: | ||
assert in_dim == out_dim | ||
self.layernorm = operations.LayerNorm(in_dim) | ||
self.fc1 = operations.Linear(in_dim, hidden_dim) | ||
self.fc2 = operations.Linear(hidden_dim, out_dim) | ||
self.use_residual = use_residual | ||
self.act_fn = nn.GELU() | ||
|
||
def forward(self, x): | ||
residual = x | ||
x = self.layernorm(x) | ||
x = self.fc1(x) | ||
x = self.act_fn(x) | ||
x = self.fc2(x) | ||
if self.use_residual: | ||
x = x + residual | ||
return x | ||
|
||
|
||
class FuseModule(nn.Module): | ||
def __init__(self, embed_dim, operations): | ||
super().__init__() | ||
self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations) | ||
self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations) | ||
self.layer_norm = operations.LayerNorm(embed_dim) | ||
|
||
def fuse_fn(self, prompt_embeds, id_embeds): | ||
stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) | ||
stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds | ||
stacked_id_embeds = self.mlp2(stacked_id_embeds) | ||
stacked_id_embeds = self.layer_norm(stacked_id_embeds) | ||
return stacked_id_embeds | ||
|
||
def forward( | ||
self, | ||
prompt_embeds, | ||
id_embeds, | ||
class_tokens_mask, | ||
) -> torch.Tensor: | ||
# id_embeds shape: [b, max_num_inputs, 1, 2048] | ||
id_embeds = id_embeds.to(prompt_embeds.dtype) | ||
num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case | ||
batch_size, max_num_inputs = id_embeds.shape[:2] | ||
# seq_length: 77 | ||
seq_length = prompt_embeds.shape[1] | ||
# flat_id_embeds shape: [b*max_num_inputs, 1, 2048] | ||
flat_id_embeds = id_embeds.view( | ||
-1, id_embeds.shape[-2], id_embeds.shape[-1] | ||
) | ||
# valid_id_mask [b*max_num_inputs] | ||
valid_id_mask = ( | ||
torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :] | ||
< num_inputs[:, None] | ||
) | ||
valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()] | ||
|
||
prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) | ||
class_tokens_mask = class_tokens_mask.view(-1) | ||
valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) | ||
# slice out the image token embeddings | ||
image_token_embeds = prompt_embeds[class_tokens_mask] | ||
stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) | ||
assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" | ||
prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) | ||
updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) | ||
return updated_prompt_embeds | ||
|
||
class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection): | ||
def __init__(self): | ||
self.load_device = comfy.model_management.text_encoder_device() | ||
offload_device = comfy.model_management.text_encoder_offload_device() | ||
dtype = comfy.model_management.text_encoder_dtype(self.load_device) | ||
|
||
super().__init__(VISION_CONFIG_DICT, dtype, offload_device, comfy.ops.manual_cast) | ||
self.visual_projection_2 = comfy.ops.manual_cast.Linear(1024, 1280, bias=False) | ||
self.fuse_module = FuseModule(2048, comfy.ops.manual_cast) | ||
|
||
def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask): | ||
b, num_inputs, c, h, w = id_pixel_values.shape | ||
id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) | ||
|
||
shared_id_embeds = self.vision_model(id_pixel_values)[2] | ||
id_embeds = self.visual_projection(shared_id_embeds) | ||
id_embeds_2 = self.visual_projection_2(shared_id_embeds) | ||
|
||
id_embeds = id_embeds.view(b, num_inputs, 1, -1) | ||
id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) | ||
|
||
id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) | ||
updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) | ||
|
||
return updated_prompt_embeds | ||
|
||
|
||
class PhotoMakerLoader: | ||
@classmethod | ||
def INPUT_TYPES(s): | ||
return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}} | ||
|
||
RETURN_TYPES = ("PHOTOMAKER",) | ||
FUNCTION = "load_photomaker_model" | ||
|
||
CATEGORY = "_for_testing/photomaker" | ||
|
||
def load_photomaker_model(self, photomaker_model_name): | ||
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name) | ||
photomaker_model = PhotoMakerIDEncoder() | ||
data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) | ||
if "id_encoder" in data: | ||
data = data["id_encoder"] | ||
photomaker_model.load_state_dict(data) | ||
return (photomaker_model,) | ||
|
||
|
||
class PhotoMakerEncode: | ||
@classmethod | ||
def INPUT_TYPES(s): | ||
return {"required": { "photomaker": ("PHOTOMAKER",), | ||
"image": ("IMAGE",), | ||
"clip": ("CLIP", ), | ||
"text": ("STRING", {"multiline": True, "default": "photograph of photomaker"}), | ||
}} | ||
|
||
RETURN_TYPES = ("CONDITIONING",) | ||
FUNCTION = "apply_photomaker" | ||
|
||
CATEGORY = "_for_testing/photomaker" | ||
|
||
def apply_photomaker(self, photomaker, image, clip, text): | ||
special_token = "photomaker" | ||
pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float() | ||
try: | ||
index = text.split(" ").index(special_token) + 1 | ||
except ValueError: | ||
index = -1 | ||
tokens = clip.tokenize(text, return_word_ids=True) | ||
out_tokens = {} | ||
for k in tokens: | ||
out_tokens[k] = [] | ||
for t in tokens[k]: | ||
f = list(filter(lambda x: x[2] != index, t)) | ||
while len(f) < len(t): | ||
f.append(t[-1]) | ||
out_tokens[k].append(f) | ||
|
||
cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True) | ||
|
||
if index > 0: | ||
token_index = index - 1 | ||
num_id_images = 1 | ||
class_tokens_mask = [True if token_index <= i < token_index+num_id_images else False for i in range(77)] | ||
out = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device), | ||
class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0)) | ||
else: | ||
out = cond | ||
|
||
return ([[out, {"pooled_output": pooled}]], ) | ||
|
||
|
||
NODE_CLASS_MAPPINGS = { | ||
"PhotoMakerLoader": PhotoMakerLoader, | ||
"PhotoMakerEncode": PhotoMakerEncode, | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
d1533d9
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will you be updating this for PhotoMaker v2? 😁
d1533d9
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not working properly, puting photomaker-v2.bin in models/photomaker, restarted comfyui, photomaker loader plus node does not detect the model. Plz provide fix for this.