# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse

import accelerate
import torch
from safetensors.torch import load_file
from transformers import (
    AutoTokenizer,
    CLIPConfig,
    CLIPImageProcessor,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

from diffusers import (
    DDPMWuerstchenScheduler,
    StableCascadeCombinedPipeline,
    StableCascadeDecoderPipeline,
    StableCascadePriorPipeline,
)
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel


parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, default="../StableCascade", help="Location of Stable Cascade weights")
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
parser.add_argument("--use_safetensors", action="store_true", help="Load the stage checkpoints as safetensors files")
parser.add_argument("--save_org", type=str, default="diffusers", help="Hub organization to save the pipelines to")
parser.add_argument("--push_to_hub", action="store_true", help="Push the converted pipelines to the Hub")

args = parser.parse_args()
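# Example invocation (script filename assumed), converting local safetensors checkpoints and
# pushing the result to the Hub:
#   python convert_stable_cascade.py --model_path ../StableCascade --use_safetensors --push_to_hub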
model_path = args.model_path

device = "cpu"

# Set paths to the stage C (prior) and stage B (decoder) checkpoints
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"

# CLIP text encoder and tokenizer
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
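# Copy the top-level CLIP projection_dim into the text sub-config so the loaded text projection
# layer matches the 1280-dim pooled text embedding Stable Cascade conditions on (the nested text
# config may not carry this value on its own).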
config.text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
)
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

# CLIP image processor and image encoder (used by the prior for optional image-prompt conditioning)
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# Prior (stage C)
if args.use_safetensors:
    orig_state_dict = load_file(prior_checkpoint_path, device=device)
else:
    orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)

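# Remap the original attention parameters to diffusers naming: the fused in_proj weight/bias is
# split into separate to_q/to_k/to_v tensors, out_proj becomes to_out.0, and every other key is
# copied through unchanged.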
state_dict = {}
for key in orig_state_dict.keys():
    if key.endswith("in_proj_weight"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
        state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
        state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
    elif key.endswith("in_proj_bias"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
        state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
        state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
    elif key.endswith("out_proj.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
    elif key.endswith("out_proj.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
    else:
        state_dict[key] = orig_state_dict[key]


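# Instantiate the prior UNet on the meta device (no weight memory allocated), then load the
# converted state dict directly into it.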
with accelerate.init_empty_weights():
    prior_model = StableCascadeUNet(
        in_channels=16,
        out_channels=16,
        timestep_ratio_embedding_dim=64,
        patch_size=1,
        conditioning_dim=2048,
        block_out_channels=[2048, 2048],
        num_attention_heads=[32, 32],
        down_num_layers_per_block=[8, 24],
        up_num_layers_per_block=[24, 8],
        down_blocks_repeat_mappers=[1, 1],
        up_blocks_repeat_mappers=[1, 1],
        block_types_per_layer=[
            ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
        ],
        clip_text_in_channels=1280,
        clip_text_pooled_in_channels=1280,
        clip_image_in_channels=768,
        clip_seq=4,
        kernel_size=3,
        dropout=[0.1, 0.1],
        self_attn=True,
        timestep_conditioning_type=["sca", "crp"],
        switch_level=[False],
    )
load_model_dict_into_meta(prior_model, state_dict)

# Scheduler shared by the prior and decoder pipelines
scheduler = DDPMWuerstchenScheduler()

# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
    prior=prior_model,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    image_encoder=image_encoder,
    scheduler=scheduler,
    feature_extractor=feature_extractor,
)
prior_pipeline.save_pretrained(f"{args.save_org}/StableCascade-prior", push_to_hub=args.push_to_hub)

# Decoder (stage B)
if args.use_safetensors:
    orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
    orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)

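# Same attention remapping as for the prior; the decoder checkpoint additionally stores the pooled
# text mapper under "clip_mapper", which diffusers names "clip_txt_pooled_mapper".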
state_dict = {}
for key in orig_state_dict.keys():
    if key.endswith("in_proj_weight"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
        state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
        state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
    elif key.endswith("in_proj_bias"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
        state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
        state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
    elif key.endswith("out_proj.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
    elif key.endswith("out_proj.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
    # rename clip_mapper to clip_txt_pooled_mapper
    elif key.endswith("clip_mapper.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
    elif key.endswith("clip_mapper.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
    else:
        state_dict[key] = orig_state_dict[key]

with accelerate.init_empty_weights():
    decoder = StableCascadeUNet(
        in_channels=4,
        out_channels=4,
        timestep_ratio_embedding_dim=64,
        patch_size=2,
        conditioning_dim=1280,
        block_out_channels=[320, 640, 1280, 1280],
        down_num_layers_per_block=[2, 6, 28, 6],
        up_num_layers_per_block=[6, 28, 6, 2],
        down_blocks_repeat_mappers=[1, 1, 1, 1],
        up_blocks_repeat_mappers=[3, 3, 2, 2],
        num_attention_heads=[0, 0, 20, 20],
        block_types_per_layer=[
            ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
            ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
            ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
        ],
        clip_text_pooled_in_channels=1280,
        clip_seq=4,
        effnet_in_channels=16,
        pixel_mapper_in_channels=3,
        kernel_size=3,
        dropout=[0, 0, 0.1, 0.1],
        self_attn=True,
        timestep_conditioning_type=["sca"],
    )
load_model_dict_into_meta(decoder, state_dict)

# VQGAN (stage A), reused from Wuerstchen v2, decodes latents back to pixels
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")

# Decoder pipeline
decoder_pipeline = StableCascadeDecoderPipeline(
    decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.save_pretrained(f"{args.save_org}/StableCascade-decoder", push_to_hub=args.push_to_hub)

# Stable Cascade combined pipeline: wraps the prior and decoder so text-to-image runs in a single call
stable_cascade_pipeline = StableCascadeCombinedPipeline(
    # Decoder
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    decoder=decoder,
    scheduler=scheduler,
    vqgan=vqmodel,
    # Prior
    prior_text_encoder=text_encoder,
    prior_tokenizer=tokenizer,
    prior_prior=prior_model,
    prior_scheduler=scheduler,
    prior_image_encoder=image_encoder,
    prior_feature_extractor=feature_extractor,
)
stable_cascade_pipeline.save_pretrained(f"{args.save_org}/StableCascade", push_to_hub=args.push_to_hub)
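
# Rough usage sketch (not executed by this script). The repo ids below assume --save_org diffusers
# with --push_to_hub; otherwise point from_pretrained at the local output folders. Device/dtype
# handling is omitted:
#
#   prior = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior")
#   decoder = StableCascadeDecoderPipeline.from_pretrained("diffusers/StableCascade-decoder")
#   prompt = "an astronaut riding a horse"
#   prior_output = prior(prompt=prompt)
#   image = decoder(image_embeddings=prior_output.image_embeddings, prompt=prompt).images[0]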