Showing 9 changed files with 326 additions and 21 deletions.
@@ -0,0 +1,215 @@
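This new file implements the Clip Aesthetic feature: it builds an averaged CLIP image embedding from a folder of reference images, exposes Gradio controls for the weight, step count, learning rate, and slerp options, and applies the embedding to the prompt conditioning at generation time via the AestheticCLIP hook.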
import copy
import itertools
import os
from pathlib import Path
import html
import gc

import gradio as gr
import torch
from PIL import Image
from torch import optim

from modules import shared
from transformers import CLIPProcessor
from tqdm.auto import tqdm, trange
from modules.shared import opts, device


def get_all_images_in_folder(folder):
    return [os.path.join(folder, f) for f in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]


def check_is_valid_image_file(filename):
    return filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.tiff', '.webp'))


def batched(dataset, total, n=1):
    for ndx in range(0, total, n):
        yield [dataset[i] for i in range(ndx, min(ndx + n, total))]


def iter_to_batched(iterable, n=1):
    # Yield successive n-sized chunks from any iterable; the last chunk may be shorter.
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk
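A quick sketch of how iter_to_batched chunks a lazy iterable (the values below are made up purely for illustration; note the final chunk may be shorter than n):

list(iter_to_batched(range(10), n=4))
# -> [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9)]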


def create_ui():
    with gr.Group():
        with gr.Accordion("Open for Clip Aesthetic!", open=False):
            with gr.Row():
                aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight",
                                             value=0.9)
                aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)

            with gr.Row():
                aesthetic_lr = gr.Textbox(label='Aesthetic learning rate',
                                          placeholder="Aesthetic learning rate", value="0.0001")
                aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
                aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()),
                                             label="Aesthetic imgs embedding",
                                             value="None")

            with gr.Row():
                aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
                                                 placeholder="This text is used to rotate the feature space of the imgs embs",
                                                 value="")
                aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01,
                                                  value=0.1)
                aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)

    return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
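For context, a hedged sketch of how a caller might consume the eight components returned by create_ui (this unpacking is an assumption about the call site, presumably the txt2img/img2img tabs, and is not part of this file):

# Hypothetical call site: unpack the controls so their values can later be fed
# into AestheticCLIP.set_aesthetic_params when a generation run starts.
(aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp,
 aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle,
 aesthetic_text_negative) = create_ui()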


def generate_imgs_embd(name, folder, batch_size):
    model = shared.clip_model.to(device)
    processor = CLIPProcessor.from_pretrained(model.name_or_path)

    with torch.no_grad():
        embs = []
        for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
                          desc=f"Generating embeddings for {name}"):
            if shared.state.interrupted:
                break
            inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
            outputs = model.get_image_features(**inputs).cpu()
            embs.append(torch.clone(outputs))
            inputs.to("cpu")
            del inputs, outputs

        # Average all per-image CLIP features into a single aesthetic embedding.
        embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)

        # The generated embedding will be located here
        path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
        torch.save(embs, path)

    model = model.cpu()
    del processor
    del embs
    gc.collect()
    torch.cuda.empty_cache()
    res = f"""
    Done generating embedding for {name}!
    Aesthetic embedding saved to {html.escape(path)}
    """
    shared.update_aesthetic_embeddings()
    return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
                              value="None"), \
           gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()),
                              label="Imgs embedding",
                              value="None"), res, ""
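A minimal usage sketch for generate_imgs_embd, assuming a local folder of reference images (the name and path below are hypothetical):

# Hypothetical invocation: average the CLIP features of every image in the
# folder into a single "my_style.pt" embedding, processing 8 images per batch.
update_a, update_b, status, _ = generate_imgs_embd("my_style", "/path/to/images", batch_size=8)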


def slerp(low, high, val):
    # Spherical linear interpolation between two batches of vectors.
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res
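As a sanity check on slerp (a minimal sketch with made-up 2-D vectors): val=0 returns low exactly, val=1 returns high, and intermediate values follow the great-circle weights sin((1 - val) * omega) / sin(omega) and sin(val * omega) / sin(omega):

a = torch.tensor([[1.0, 0.0]])
b = torch.tensor([[0.0, 1.0]])
print(slerp(a, b, 0.0))  # tensor([[1., 0.]])
print(slerp(a, b, 1.0))  # tensor([[0., 1.]])
print(slerp(a, b, 0.5))  # tensor([[0.7071, 0.7071]]), halfway along the arc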


class AestheticCLIP:
    def __init__(self):
        self.skip = False
        self.aesthetic_steps = 0
        self.aesthetic_weight = 0
        self.aesthetic_lr = 0
        self.slerp = False
        self.aesthetic_text_negative = False
        self.aesthetic_slerp_angle = 0
        self.aesthetic_imgs_text = ""

        self.image_embs_name = None
        self.image_embs = None
        self.load_image_embs(None)

    def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
                             aesthetic_slerp=True, aesthetic_imgs_text="",
                             aesthetic_slerp_angle=0.15,
                             aesthetic_text_negative=False):
        self.aesthetic_imgs_text = aesthetic_imgs_text
        self.aesthetic_slerp_angle = aesthetic_slerp_angle
        self.aesthetic_text_negative = aesthetic_text_negative
        self.slerp = aesthetic_slerp
        self.aesthetic_lr = aesthetic_lr
        self.aesthetic_weight = aesthetic_weight
        self.aesthetic_steps = aesthetic_steps
        self.load_image_embs(image_embs_name)

    def set_skip(self, skip):
        self.skip = skip

    def load_image_embs(self, image_embs_name):
        # Treat "None" and the empty string as "no embedding selected".
        if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
            image_embs_name = None
            self.image_embs_name = None
        if image_embs_name is not None and self.image_embs_name != image_embs_name:
            # Load and L2-normalize the saved embedding; cached until the name changes.
            self.image_embs_name = image_embs_name
            self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
            self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
            self.image_embs.requires_grad_(False)

    def __call__(self, z, remade_batch_tokens):
        if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None:
            tokenizer = shared.sd_model.cond_stage_model.tokenizer
            if not opts.use_old_emphasis_implementation:
                remade_batch_tokens = [
                    [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in
                    remade_batch_tokens]

            tokens = torch.asarray(remade_batch_tokens).to(device)

            # Fine-tune a throwaway copy of the CLIP model so the optimization
            # below never leaks into the shared instance.
            model = copy.deepcopy(shared.clip_model).to(device)
            model.requires_grad_(True)
            if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
                text_embs_2 = model.get_text_features(
                    **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
                if self.aesthetic_text_negative:
                    text_embs_2 = self.image_embs - text_embs_2
                text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
                img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
            else:
                img_embs = self.image_embs

            with torch.enable_grad():
                # Optimize the copied text encoder to maximize the cosine
                # similarity between the prompt embedding and the aesthetic
                # image embedding.
                optimizer = optim.Adam(
                    model.text_model.parameters(), lr=self.aesthetic_lr
                )

                for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
                    text_embs = model.get_text_features(input_ids=tokens)
                    text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
                    sim = text_embs @ img_embs.T
                    loss = -sim
                    optimizer.zero_grad()
                    loss.mean().backward()
                    optimizer.step()

                # Re-encode the tokens with the tuned copy, honoring the Clip-skip setting.
                zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
                if opts.CLIP_stop_at_last_layers > 1:
                    zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
                    zn = model.text_model.final_layer_norm(zn)
                else:
                    zn = zn.last_hidden_state
                model.cpu()
                del model
                gc.collect()
                torch.cuda.empty_cache()
            zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1)
            if self.slerp:
                z = slerp(z, zn, self.aesthetic_weight)
            else:
                z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight

        return z
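Finally, a hedged sketch of the intended lifecycle (the call site is an assumption; this file only defines the hook): an AestheticCLIP instance is configured from the UI values before a run, then invoked on each prompt conditioning tensor:

# Hypothetical usage: "my_style" is a made-up embedding name saved earlier by
# generate_imgs_embd; z is the prompt conditioning of shape [batch, 77 * chunks, dim].
aesthetic = AestheticCLIP()
aesthetic.set_aesthetic_params(aesthetic_lr=0.0001, aesthetic_weight=0.9, aesthetic_steps=5,
                               image_embs_name="my_style", aesthetic_slerp=False)
z = aesthetic(z, remade_batch_tokens)  # nudges z toward the aesthetic image embedding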