Add unload_textual_inversion method #6656
Changes from 2 commits
@@ -453,3 +453,91 @@ def load_textual_inversion(
            self.enable_sequential_cpu_offload()
        # / Unsafe Code >

    def unload_textual_inversion(
        self,
        tokens: Optional[Union[str, List[str]]] = None,
    ):
        r"""
        Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`]

        Example:
        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")

        # Example 1
        pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
        pipeline.load_textual_inversion("sd-concepts-library/moeb-style")

        # Remove all token embeddings
        pipeline.unload_textual_inversion()

        # Example 2
        pipeline.load_textual_inversion("sd-concepts-library/moeb-style")
        pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")

        # Remove just one token
        pipeline.unload_textual_inversion("<moe-bius>")
        ```
        """

        tokenizer = getattr(self, "tokenizer", None)
        text_encoder = getattr(self, "text_encoder", None)

        # Get textual inversion tokens and ids
        token_ids = []
        last_special_token_id = None

        if tokens:
            if isinstance(tokens, str):
                tokens = [tokens]
            for token_id, added_token in tokenizer.added_tokens_decoder.items():
                if not added_token.special:
                    if added_token.content in tokens:
                        token_ids.append(token_id)
                else:
                    last_special_token_id = token_id
            if len(token_ids) == 0:
                raise ValueError("No tokens to remove found")
        else:
            tokens = []
            for token_id, added_token in tokenizer.added_tokens_decoder.items():
                if not added_token.special:
                    token_ids.append(token_id)
                    tokens.append(added_token.content)
                else:
                    last_special_token_id = token_id

        # Delete from tokenizer
        for token_id, token_to_remove in zip(token_ids, tokens):
            del tokenizer._added_tokens_decoder[token_id]
            del tokenizer._added_tokens_encoder[token_to_remove]

        # Fix token ids in tokenizer
        key_id = 1
        for token_id in tokenizer.added_tokens_decoder:
Collaborator: Can you explain why we need this block?

Contributor (Author): I added this block to make all token ids sequential after one of the added tokens is removed.

Collaborator: Thanks for explaining this! @fabiorigano I'm not very familiar with the use case.

Collaborator: Thanks @fabiorigano! I looked a bit but I'm actually not quite sure why it's necessary to reorder 🤔

Contributor (Author): Hi @linoytsaban, the reordering block is useful to keep the token ids aligned with the indices of the text embeddings in the encoder, so multiple textual inversions can still be loaded and unloaded on the same pipeline without the ids drifting out of sync.

Collaborator: Gotcha! That makes total sense, thanks for explaining! 🤗
            if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
                token = tokenizer._added_tokens_decoder[token_id]
                tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
                del tokenizer._added_tokens_decoder[token_id]
                tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
            key_id += 1
        tokenizer._update_trie()

        # Delete from text encoder
        text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
        temp_text_embedding_weights = text_encoder.get_input_embeddings().weight
        text_embedding_weights = temp_text_embedding_weights[: last_special_token_id + 1]
        to_append = []
        for i in range(last_special_token_id + 1, temp_text_embedding_weights.shape[0]):
            if i not in token_ids:
                to_append.append(temp_text_embedding_weights[i].unsqueeze(0))
        if len(to_append) > 0:
            to_append = torch.cat(to_append, dim=0)
            text_embedding_weights = torch.cat([text_embedding_weights, to_append], dim=0)
        text_embeddings_filtered = nn.Embedding(text_embedding_weights.shape[0], text_embedding_dim)
        text_embeddings_filtered.weight.data = text_embedding_weights
        text_encoder.set_input_embeddings(text_embeddings_filtered)
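To make the review thread above concrete, here is a small, self-contained sketch of why the reindexing matters. It uses toy ids, a plain dict, and a random tensor rather than the real tokenizer or text encoder internals; the numbers and token names below are illustrative assumptions, not part of the PR.

```py
import torch

# Toy setup: the last special token has id 48, and three textual inversion
# tokens were appended after it (ids 49-51). The text encoder's embedding
# matrix has one row per token id.
last_special_token_id = 48
added = {49: "<gta5-artwork>", 50: "<moe-bius>", 51: "<other-concept>"}
weights = torch.randn(52, 4)  # toy embedding matrix: 52 ids, embedding dim 4

# Unloading "<moe-bius>" (id 50) leaves a gap in the ids: {49, 51}.
removed_ids = [50]
del added[50]

# Reassign the surviving ids sequentially, mirroring the "Fix token ids" loop.
key_id = 1
for token_id in sorted(added):
    target = last_special_token_id + key_id
    if token_id > target:
        added[target] = added.pop(token_id)
    key_id += 1
assert sorted(added) == [49, 50]  # ids are contiguous again

# The embedding matrix is filtered the same way: keep rows 0..last_special_token_id,
# then append only the rows of the surviving added tokens, so that
# "row index == token id" still holds for every remaining token.
surviving_rows = [
    weights[i].unsqueeze(0)
    for i in range(last_special_token_id + 1, weights.shape[0])
    if i not in removed_ids
]
filtered = torch.cat([weights[: last_special_token_id + 1]] + surviving_rows, dim=0)
assert filtered.shape[0] == weights.shape[0] - len(removed_ids)
# "<other-concept>" now maps to id 50, and row 50 of `filtered` is its old row 51.
```

Because the surviving token ids and embedding rows stay aligned, `load_textual_inversion` and `unload_textual_inversion` can be called repeatedly on the same pipeline without index mismatches.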