
Commit 40aa47b

kashif, dome272, pabloppp, sayakpaul, and patrickvonplaten authored
[Pipeline] Wuerstchen v3 aka Stable Cascade pipeline (#6487)
* initial diffNext v3
* move to v3 folder
* imports
* dry up the unets
* no switch_level
* fix init
* add switch_level tp config
* Fixed some things
* Added pooled text embeddings
* Initial work on adding image encoder
* changes from @dome272
* Stuff for the image encoder processing and variable naming in decoder
* fix arg name
* inference fixes
* inference fixes
* default TimestepBlock without conds
* c_skip=0 by default
* fix bfloat16 to cpu
* use config
* undo temp change
* fix gen_c_embeddings args
* change text encoding
* text encoding
* undo print
* undo .gitignore change
* Allow WuerstchenV3PriorPipeline to use the base DDPM & DDIM schedulers
* use WuerstchenV3Unet in both pipelines
* fix imports
* initial failing tests
* cleanup
* use scheduler.timesterps
* some fixes to the tests, still not fully working
* fix tests
* fix prior tests
* add dropout to the model_kwargs
* more tests passing
* update expected_slice
* initial rename
* rename tests
* rename class names
* make fix-copies
* initial docs
* autodocs
* typos
* fix arg docs
* add text_encoder info
* combined pipeline has optional image arg
* fix documentation
* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py Co-authored-by: YiYi Xu <[email protected]>
* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py Co-authored-by: Patrick von Platen <[email protected]>
* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py Co-authored-by: YiYi Xu <[email protected]>
* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py Co-authored-by: YiYi Xu <[email protected]>
* use self.config
* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py Co-authored-by: YiYi Xu <[email protected]>
* c_in -> in_channels
* removed kwargs from unet's forward
* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py Co-authored-by: YiYi Xu <[email protected]>
* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py Co-authored-by: Patrick von Platen <[email protected]>
* remove older callback api
* removed kwargs and fixed decoder guidance > 1
* decoder takes emeds
* check and use image_embeds
* fixed all but one decoder test
* fix decoder tests
* update callback api
* fix some more combined tests
* push combined pipeline
* initial docs
* fix doc_string
* update combined api
* no test_callback_inputs test for combined pipeline
* add optional components
* fix ordering of components
* fix combined tests
* update convert script
* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py Co-authored-by: YiYi Xu <[email protected]>
* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py Co-authored-by: YiYi Xu <[email protected]>
* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py Co-authored-by: YiYi Xu <[email protected]>
* fix imports
* move effnet out of deniosing loop
* prompt_embeds_pooled only when doing guidance
* Fix repeat shape
* move StableCascadeUnet to models/unets/
* more descriptive names
* converted when numpy()
* StableCascadePriorPipelineOutput docs
* rename StableCascadeUNet
* add slow tests
* fix slow tests
* update
* update
* updated model_path
* add args for weights
* set push_to_hub to false
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update

---------

Co-authored-by: Dominic Rampas <[email protected]>
Co-authored-by: Pablo Pernias <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: 99991 <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>
1 parent 1bc0d37 commit 40aa47b

22 files changed: +3,214 -25 lines

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -318,6 +318,8 @@
       title: Semantic Guidance
     - local: api/pipelines/shap_e
       title: Shap-E
+    - local: api/pipelines/stable_cascade
+      title: Stable Cascade
     - sections:
       - local: api/pipelines/stable_diffusion/overview
         title: Overview
```
docs/source/en/api/pipelines/stable_cascade.md

Lines changed: 88 additions & 0 deletions (new file)
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Stable Cascade

This model is built on the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture, and its main
difference from other models like Stable Diffusion is that it works on a much smaller latent space. Why is this
important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes.
How small is the latent space? Stable Diffusion uses a compression factor of 8, so a 1024x1024 image is encoded
to 128x128. Stable Cascade achieves a compression factor of 42, meaning it is possible to encode a 1024x1024 image
to 24x24 while maintaining crisp reconstructions. The text-conditional model is then trained in this highly
compressed latent space. Previous versions of this architecture achieved a 16x cost reduction over Stable
Diffusion 1.5.

Therefore, this kind of model is well suited for use cases where efficiency is important. Furthermore, all known
extensions such as finetuning, LoRA, ControlNet, IP-Adapter, LCM, etc. are possible with this method as well.

The original codebase can be found at [Stability-AI/StableCascade](https://github.com/Stability-AI/StableCascade).

## Model Overview

Stable Cascade consists of three models: Stage A, Stage B, and Stage C, representing a cascade for generating images,
hence the name "Stable Cascade".

Stages A and B are used to compress images, similar to the job of the VAE in Stable Diffusion.
However, this setup achieves a much higher compression of images. While the Stable Diffusion models use a
spatial compression factor of 8, encoding an image with a resolution of 1024 x 1024 to 128 x 128, Stable Cascade
achieves a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24 while still being able to accurately
decode the image, which brings the great benefit of cheaper training and inference. Stage C is then responsible for
generating the small 24 x 24 latents given a text prompt.

## Uses

### Direct Use

The model is intended for research purposes for now. Possible research areas and tasks include

- Research on generative models.
- Safe deployment of models which have the potential to generate harmful content.
- Probing and understanding the limitations and biases of generative models.
- Generation of artworks and use in design and other artistic processes.
- Applications in educational or creative tools.

Excluded uses are described below.

### Out-of-Scope Use

The model was not trained to produce factual or true representations of people or events, so using it to generate
such content is out of scope for its abilities.
The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).

## Limitations and Bias

### Limitations

- Faces and people in general may not be generated properly.
- The autoencoding part of the model is lossy.

## StableCascadeCombinedPipeline

[[autodoc]] StableCascadeCombinedPipeline
- all
- __call__

## StableCascadePriorPipeline

[[autodoc]] StableCascadePriorPipeline
- all
- __call__

## StableCascadePriorPipelineOutput

[[autodoc]] pipelines.stable_cascade.pipeline_stable_cascade_prior.StableCascadePriorPipelineOutput

## StableCascadeDecoderPipeline

[[autodoc]] StableCascadeDecoderPipeline
- all
- __call__
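The prior and decoder pipelines documented above are meant to be chained: Stage C turns the prompt into compact image embeddings, and Stages B + A decode those embeddings back into pixels. The sketch below illustrates that flow under a few assumptions: the checkpoint ids follow the default save locations of the conversion script in this commit (`diffusers/StableCascade-prior` and `diffusers/StableCascade-decoder`), and the `image_embeddings` attribute and call arguments are assumed from the pipelines added here, so treat it as illustrative rather than canonical usage.

```python
import torch

from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

prompt = "an image of a shiba inu wearing a spacesuit"

# Stage C: encode the prompt into small 24x24 image embeddings
# (repo id assumed from the conversion script's default save location)
prior = StableCascadePriorPipeline.from_pretrained(
    "diffusers/StableCascade-prior", torch_dtype=torch.bfloat16
).to("cuda")
prior_output = prior(prompt=prompt)

# Stages B + A: decode the embeddings into a full-resolution image
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "diffusers/StableCascade-decoder", torch_dtype=torch.bfloat16
).to("cuda")
image = decoder(
    image_embeddings=prior_output.image_embeddings,  # attribute name assumed from StableCascadePriorPipelineOutput
    prompt=prompt,
).images[0]
image.save("stable_cascade.png")
```

`StableCascadeCombinedPipeline` wraps both stages behind a single call for convenience and accepts an optional image for the image-conditioned path.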

scripts/convert_stable_cascade.py

Lines changed: 215 additions & 0 deletions
```python
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse

import accelerate
import torch
from safetensors.torch import load_file
from transformers import (
    AutoTokenizer,
    CLIPConfig,
    CLIPImageProcessor,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

from diffusers import (
    DDPMWuerstchenScheduler,
    StableCascadeCombinedPipeline,
    StableCascadeDecoderPipeline,
    StableCascadePriorPipeline,
)
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel


parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, default="../StableCascade", help="Location of Stable Cascade weights")
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument("--save_org", type=str, default="diffusers", help="Hub organization to save the pipelines to")
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")

args = parser.parse_args()
model_path = args.model_path

device = "cpu"

# set paths to model weights
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"

# Clip Text encoder and tokenizer
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
config.text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
)
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

# image processor
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# Prior
if args.use_safetensors:
    orig_state_dict = load_file(prior_checkpoint_path, device=device)
else:
    orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)

state_dict = {}
for key in orig_state_dict.keys():
    if key.endswith("in_proj_weight"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
        state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
        state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
    elif key.endswith("in_proj_bias"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
        state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
        state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
    elif key.endswith("out_proj.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
    elif key.endswith("out_proj.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
    else:
        state_dict[key] = orig_state_dict[key]


with accelerate.init_empty_weights():
    prior_model = StableCascadeUNet(
        in_channels=16,
        out_channels=16,
        timestep_ratio_embedding_dim=64,
        patch_size=1,
        conditioning_dim=2048,
        block_out_channels=[2048, 2048],
        num_attention_heads=[32, 32],
        down_num_layers_per_block=[8, 24],
        up_num_layers_per_block=[24, 8],
        down_blocks_repeat_mappers=[1, 1],
        up_blocks_repeat_mappers=[1, 1],
        block_types_per_layer=[
            ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
        ],
        clip_text_in_channels=1280,
        clip_text_pooled_in_channels=1280,
        clip_image_in_channels=768,
        clip_seq=4,
        kernel_size=3,
        dropout=[0.1, 0.1],
        self_attn=True,
        timestep_conditioning_type=["sca", "crp"],
        switch_level=[False],
    )
load_model_dict_into_meta(prior_model, state_dict)

# scheduler for prior and decoder
scheduler = DDPMWuerstchenScheduler()

# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
    prior=prior_model,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    image_encoder=image_encoder,
    scheduler=scheduler,
    feature_extractor=feature_extractor,
)
prior_pipeline.save_pretrained(f"{args.save_org}/StableCascade-prior", push_to_hub=args.push_to_hub)

# Decoder
if args.use_safetensors:
    orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
    orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)

state_dict = {}
for key in orig_state_dict.keys():
    if key.endswith("in_proj_weight"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
        state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
        state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
    elif key.endswith("in_proj_bias"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
        state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
        state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
    elif key.endswith("out_proj.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
    elif key.endswith("out_proj.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
    # rename clip_mapper to clip_txt_pooled_mapper
    elif key.endswith("clip_mapper.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
    elif key.endswith("clip_mapper.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
    else:
        state_dict[key] = orig_state_dict[key]

with accelerate.init_empty_weights():
    decoder = StableCascadeUNet(
        in_channels=4,
        out_channels=4,
        timestep_ratio_embedding_dim=64,
        patch_size=2,
        conditioning_dim=1280,
        block_out_channels=[320, 640, 1280, 1280],
        down_num_layers_per_block=[2, 6, 28, 6],
        up_num_layers_per_block=[6, 28, 6, 2],
        down_blocks_repeat_mappers=[1, 1, 1, 1],
        up_blocks_repeat_mappers=[3, 3, 2, 2],
        num_attention_heads=[0, 0, 20, 20],
        block_types_per_layer=[
            ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
            ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
            ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
        ],
        clip_text_pooled_in_channels=1280,
        clip_seq=4,
        effnet_in_channels=16,
        pixel_mapper_in_channels=3,
        kernel_size=3,
        dropout=[0, 0, 0.1, 0.1],
        self_attn=True,
        timestep_conditioning_type=["sca"],
    )
load_model_dict_into_meta(decoder, state_dict)

# VQGAN from Wuerstchen-V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")

# Decoder pipeline
decoder_pipeline = StableCascadeDecoderPipeline(
    decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.save_pretrained(f"{args.save_org}/StableCascade-decoder", push_to_hub=args.push_to_hub)

# Stable Cascade combined pipeline
stable_cascade_pipeline = StableCascadeCombinedPipeline(
    # Decoder
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    decoder=decoder,
    scheduler=scheduler,
    vqgan=vqmodel,
    # Prior
    prior_text_encoder=text_encoder,
    prior_tokenizer=tokenizer,
    prior_prior=prior_model,
    prior_scheduler=scheduler,
    prior_image_encoder=image_encoder,
    prior_feature_extractor=feature_extractor,
)
stable_cascade_pipeline.save_pretrained(f"{args.save_org}/StableCascade", push_to_hub=args.push_to_hub)
```
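The heart of both conversion loops above is the attention remapping: the original checkpoints store each attention layer's query, key, and value projections fused into a single `attn.in_proj_weight`/`attn.in_proj_bias` tensor, while the diffusers `StableCascadeUNet` expects separate `to_q`/`to_k`/`to_v` parameters. Below is a minimal, self-contained sketch of that split; the layer name and hidden size are made up purely for illustration.

```python
import torch

dim = 8  # hypothetical hidden size, just for illustration
orig = {"blocks.0.attn.in_proj_weight": torch.randn(3 * dim, dim)}  # fused QKV projection

converted = {}
for key, value in orig.items():
    if key.endswith("in_proj_weight"):
        # split the stacked (3*dim, dim) matrix into three (dim, dim) blocks: query, key, value
        q, k, v = value.chunk(3, 0)
        converted[key.replace("attn.in_proj_weight", "to_q.weight")] = q
        converted[key.replace("attn.in_proj_weight", "to_k.weight")] = k
        converted[key.replace("attn.in_proj_weight", "to_v.weight")] = v

print(sorted(converted))  # ['blocks.0.to_k.weight', 'blocks.0.to_q.weight', 'blocks.0.to_v.weight']
```

The output projection (`attn.out_proj.*`) and, for the decoder, the `clip_mapper.*` keys are only renamed in the same pass, not reshaped.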

src/diffusers/__init__.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -86,6 +86,7 @@
         "MotionAdapter",
         "MultiAdapter",
         "PriorTransformer",
+        "StableCascadeUNet",
         "T2IAdapter",
         "T5FilmDecoder",
         "Transformer2DModel",
@@ -259,6 +260,9 @@
         "SemanticStableDiffusionPipeline",
         "ShapEImg2ImgPipeline",
         "ShapEPipeline",
+        "StableCascadeCombinedPipeline",
+        "StableCascadeDecoderPipeline",
+        "StableCascadePriorPipeline",
         "StableDiffusionAdapterPipeline",
         "StableDiffusionAttendAndExcitePipeline",
         "StableDiffusionControlNetImg2ImgPipeline",
@@ -626,6 +630,9 @@
         SemanticStableDiffusionPipeline,
         ShapEImg2ImgPipeline,
         ShapEPipeline,
+        StableCascadeCombinedPipeline,
+        StableCascadeDecoderPipeline,
+        StableCascadePriorPipeline,
         StableDiffusionAdapterPipeline,
         StableDiffusionAttendAndExcitePipeline,
         StableDiffusionControlNetImg2ImgPipeline,
```

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -47,6 +47,7 @@
     _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
     _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
     _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
+    _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
     _import_structure["unets.uvit_2d"] = ["UVit2DModel"]
     _import_structure["vq_model"] = ["VQModel"]

@@ -80,6 +81,7 @@
         I2VGenXLUNet,
         Kandinsky3UNet,
         MotionAdapter,
+        StableCascadeUNet,
         UNet1DModel,
         UNet2DConditionModel,
         UNet2DModel,
```

src/diffusers/models/unets/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -10,6 +10,7 @@
     from .unet_kandinsky3 import Kandinsky3UNet
     from .unet_motion_model import MotionAdapter, UNetMotionModel
     from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
+    from .unet_stable_cascade import StableCascadeUNet
     from .uvit_2d import UVit2DModel
```
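With these registrations in place, the new classes resolve from the package root as well as from the `models` subpackage. A quick import sanity check, assuming an environment with this commit installed:

```python
from diffusers import (
    StableCascadeCombinedPipeline,
    StableCascadeDecoderPipeline,
    StableCascadePriorPipeline,
    StableCascadeUNet,
)
from diffusers.models import StableCascadeUNet as StableCascadeUNetFromModels

# Both lazy import paths should resolve to the same class object.
assert StableCascadeUNet is StableCascadeUNetFromModels
print(StableCascadePriorPipeline.__name__, StableCascadeDecoderPipeline.__name__)
```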
