@@ -453,3 +453,91 @@ def load_textual_inversion(
453453 self .enable_sequential_cpu_offload ()
454454
455455 # / Unsafe Code >
456+
def unload_textual_inversion(
    self,
    tokens: Optional[Union[str, List[str]]] = None,
):
    r"""
    Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`].

    Removes the requested added tokens from the tokenizer, re-packs the remaining
    textual-inversion token ids so they stay sequential right after the last special
    token, and rebuilds the text encoder's input-embedding matrix without the
    corresponding rows (kept rows preserve their relative order).

    Args:
        tokens (`str` or `List[str]`, *optional*):
            The token(s) to remove. If `None` (default), every non-special added
            token — i.e. every loaded textual-inversion token — is removed.

    Raises:
        ValueError:
            If the pipeline has no `tokenizer`/`text_encoder`, if `tokens` is given
            but none of them are loaded, or if the tokenizer has no special added
            tokens (then the boundary of the base vocabulary cannot be determined).

    Example:
    ```py
    from diffusers import AutoPipelineForText2Image
    import torch

    pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")

    # Example 1
    pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
    pipeline.load_textual_inversion("sd-concepts-library/moeb-style")

    # Remove all token embeddings
    pipeline.unload_textual_inversion()

    # Example 2
    pipeline.load_textual_inversion("sd-concepts-library/moeb-style")
    pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")

    # Remove just one token
    pipeline.unload_textual_inversion("<moe-bius>")
    ```
    """
    tokenizer = getattr(self, "tokenizer", None)
    text_encoder = getattr(self, "text_encoder", None)
    if tokenizer is None or text_encoder is None:
        # Fail loudly instead of raising an opaque AttributeError below.
        raise ValueError(
            f"{self.__class__.__name__} requires both a `tokenizer` and a `text_encoder` to unload textual inversion embeddings."
        )

    # Collect ids of the tokens to remove, and remember the id of the last *special*
    # added token: every id above it belongs to a textual-inversion token.
    token_ids = []
    last_special_token_id = None

    if tokens:
        if isinstance(tokens, str):
            tokens = [tokens]
        for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
            if not added_token.special:
                if added_token.content in tokens:
                    token_ids.append(added_token_id)
            else:
                last_special_token_id = added_token_id
        if len(token_ids) == 0:
            raise ValueError("No tokens to remove found")
    else:
        tokens = []
        for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
            if not added_token.special:
                token_ids.append(added_token_id)
                tokens.append(added_token.content)
            else:
                last_special_token_id = added_token_id

    if last_special_token_id is None:
        # Without a special added token we cannot tell where the base vocabulary
        # ends, so the renumbering/truncation below would be ill-defined.
        raise ValueError(
            "Tokenizer has no special added tokens; cannot unload textual inversion embeddings."
        )

    # Delete from tokenizer. Look the content up via the decoder entry itself so that
    # ids and token strings can never be mis-paired (a plain zip over the caller's
    # token list breaks when some requested tokens are absent or ordered differently
    # from the vocabulary).
    for token_id in token_ids:
        content_to_remove = tokenizer._added_tokens_decoder[token_id].content
        del tokenizer._added_tokens_decoder[token_id]
        del tokenizer._added_tokens_encoder[content_to_remove]

    # Make all remaining textual-inversion token ids sequential again, starting at
    # last_special_token_id + 1, preserving their relative order. `key_id` advances
    # for every surviving post-special token — including ones already in place — so
    # a later token is never moved onto an occupied slot.
    key_id = 1
    for token_id in sorted(tokenizer.added_tokens_decoder):
        if token_id > last_special_token_id:
            target_id = last_special_token_id + key_id
            if token_id > target_id:
                token = tokenizer._added_tokens_decoder[token_id]
                tokenizer._added_tokens_decoder[target_id] = token
                del tokenizer._added_tokens_decoder[token_id]
                tokenizer._added_tokens_encoder[token.content] = target_id
            key_id += 1
    tokenizer._update_trie()

    # Delete from text encoder: rebuild the embedding matrix keeping the base
    # vocabulary rows plus the surviving textual-inversion rows, in order — this
    # mirrors the id re-packing done on the tokenizer above.
    input_embeddings = text_encoder.get_input_embeddings()
    text_embedding_dim = input_embeddings.embedding_dim
    temp_text_embedding_weights = input_embeddings.weight
    text_embedding_weights = temp_text_embedding_weights[: last_special_token_id + 1]
    removed_ids = set(token_ids)  # O(1) membership instead of list scans
    to_append = [
        temp_text_embedding_weights[i].unsqueeze(0)
        for i in range(last_special_token_id + 1, temp_text_embedding_weights.shape[0])
        if i not in removed_ids
    ]
    if len(to_append) > 0:
        text_embedding_weights = torch.cat([text_embedding_weights] + to_append, dim=0)
    text_embeddings_filtered = nn.Embedding(text_embedding_weights.shape[0], text_embedding_dim)
    # .data assignment keeps the original tensor's device/dtype and avoids autograd history.
    text_embeddings_filtered.weight.data = text_embedding_weights
    text_encoder.set_input_embeddings(text_embeddings_filtered)
0 commit comments