Commit 5d8b198

Add unload_textual_inversion method (#6656)
* Add unload_textual_inversion
* Fix dicts in tokenizer
* Fix quality
* Apply suggestions from code review
* Fix variable name after last update

---------

Co-authored-by: YiYi Xu <[email protected]>
1 parent acd1962 commit 5d8b198

File tree

1 file changed: 88 additions, 0 deletions


src/diffusers/loaders/textual_inversion.py

Lines changed: 88 additions & 0 deletions
@@ -453,3 +453,91 @@ def load_textual_inversion(
            self.enable_sequential_cpu_offload()

    def unload_textual_inversion(
        self,
        tokens: Optional[Union[str, List[str]]] = None,
    ):
        r"""
        Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`]

        Example:
        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")

        # Example 1
        pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
        pipeline.load_textual_inversion("sd-concepts-library/moeb-style")

        # Remove all token embeddings
        pipeline.unload_textual_inversion()

        # Example 2
        pipeline.load_textual_inversion("sd-concepts-library/moeb-style")
        pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")

        # Remove just one token
        pipeline.unload_textual_inversion("<moe-bius>")
        ```
        """

        tokenizer = getattr(self, "tokenizer", None)
        text_encoder = getattr(self, "text_encoder", None)

        # Get textual inversion tokens and ids
        token_ids = []
        last_special_token_id = None

        if tokens:
            if isinstance(tokens, str):
                tokens = [tokens]
            for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
                if not added_token.special:
                    if added_token.content in tokens:
                        token_ids.append(added_token_id)
                else:
                    last_special_token_id = added_token_id
            if len(token_ids) == 0:
                raise ValueError("No tokens to remove found")
        else:
            tokens = []
            for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
                if not added_token.special:
                    token_ids.append(added_token_id)
                    tokens.append(added_token.content)
                else:
                    last_special_token_id = added_token_id

        # Delete from tokenizer
        for token_id, token_to_remove in zip(token_ids, tokens):
            del tokenizer._added_tokens_decoder[token_id]
            del tokenizer._added_tokens_encoder[token_to_remove]

        # Make all token ids sequential in tokenizer
        key_id = 1
        for token_id in tokenizer.added_tokens_decoder:
            if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
                token = tokenizer._added_tokens_decoder[token_id]
                tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
                del tokenizer._added_tokens_decoder[token_id]
                tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
                key_id += 1
        tokenizer._update_trie()

        # Delete from text encoder
        text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
        temp_text_embedding_weights = text_encoder.get_input_embeddings().weight
        text_embedding_weights = temp_text_embedding_weights[: last_special_token_id + 1]
        to_append = []
        for i in range(last_special_token_id + 1, temp_text_embedding_weights.shape[0]):
            if i not in token_ids:
                to_append.append(temp_text_embedding_weights[i].unsqueeze(0))
        if len(to_append) > 0:
            to_append = torch.cat(to_append, dim=0)
            text_embedding_weights = torch.cat([text_embedding_weights, to_append], dim=0)
        text_embeddings_filtered = nn.Embedding(text_embedding_weights.shape[0], text_embedding_dim)
        text_embeddings_filtered.weight.data = text_embedding_weights
        text_encoder.set_input_embeddings(text_embeddings_filtered)
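
For readers trying the new method, here is a small usage sketch building on the docstring example above. It is illustrative, not part of the commit: it assumes a Stable Diffusion pipeline whose `tokenizer` is a `transformers` tokenizer exposing the `added_tokens_decoder` mapping that the method iterates, and it reuses the concept repository from the docstring.

```py
# Sketch (not part of the commit): inspect which added tokens unload_textual_inversion()
# would remove, then unload them. The repo id comes from the docstring example above.
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
tokenizer_size_before = len(pipeline.tokenizer)

pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")

# added_tokens_decoder maps token id -> AddedToken; the non-special entries are the
# textual inversion tokens that unload_textual_inversion() targets.
for token_id, added_token in pipeline.tokenizer.added_tokens_decoder.items():
    if not added_token.special:
        print(token_id, added_token.content)

pipeline.unload_textual_inversion()
# Expected True: the tokenizer should be back to its pre-load size.
print(len(pipeline.tokenizer) == tokenizer_size_before)
```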

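The least obvious step in the diff is the "make all token ids sequential" loop: once entries have been deleted from the tokenizer's added-token maps, the surviving added tokens can leave gaps in their ids, so they are shifted down to sit directly after the last special token, keeping the decoder and encoder maps consistent with the compacted embedding matrix. The following toy sketch is my reading of that loop, using plain dicts in place of the tokenizer internals and made-up ids:

```py
# Toy stand-in for the tokenizer's added-token maps after "<gta5-artwork>" (id 49408)
# has been deleted; ids and token strings are made up for illustration.
added_tokens_decoder = {49406: "<|startoftext|>", 49407: "<|endoftext|>", 49409: "<moe-bius>"}
added_tokens_encoder = {token: token_id for token_id, token in added_tokens_decoder.items()}
last_special_token_id = 49407  # highest id among the special tokens

# Same compaction idea as in the diff: shift remaining non-special ids down so they
# sit directly after the last special token. Iterate over a sorted copy of the keys
# so the dict can be mutated safely inside the loop.
key_id = 1
for token_id in sorted(added_tokens_decoder):
    if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
        token = added_tokens_decoder.pop(token_id)
        added_tokens_decoder[last_special_token_id + key_id] = token
        added_tokens_encoder[token] = last_special_token_id + key_id
        key_id += 1

print(added_tokens_decoder)  # {49406: '<|startoftext|>', 49407: '<|endoftext|>', 49408: '<moe-bius>'}
print(added_tokens_encoder)  # {'<|startoftext|>': 49406, '<|endoftext|>': 49407, '<moe-bius>': 49408}
```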
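
On the text-encoder side, the method rebuilds the input embedding matrix by keeping every row up to the last special token, re-packing only the surviving added rows, and wrapping the result in a fresh `nn.Embedding`. Here is a self-contained sketch of that row filtering with made-up shapes and ids (no real text encoder involved):

```py
import torch
import torch.nn as nn

# Pretend embedding table: ids 0..7 are the base vocabulary (7 is the last special
# token id), ids 8-10 are added textual inversion rows; we drop the row for id 9.
old_embeddings = nn.Embedding(11, 4)
last_special_token_id = 7
token_ids_to_remove = [9]

weights = old_embeddings.weight
kept = [weights[: last_special_token_id + 1]]  # everything up to the last special token stays
for i in range(last_special_token_id + 1, weights.shape[0]):
    if i not in token_ids_to_remove:
        kept.append(weights[i].unsqueeze(0))  # surviving added rows are re-packed contiguously
new_weights = torch.cat(kept, dim=0)

new_embeddings = nn.Embedding(new_weights.shape[0], old_embeddings.embedding_dim)
new_embeddings.weight.data = new_weights.data

print(old_embeddings.weight.shape, "->", new_embeddings.weight.shape)
# torch.Size([11, 4]) -> torch.Size([10, 4])
```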