[Refactor] Stable Diffusion UNet Refactor Proposal #6455
Conversation
)
self.encoder_hid_proj = None

if class_embed_type == "timestep":
Used in Stable UnCLIP but IMO can be removed.
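For illustration, a minimal sketch (all names hypothetical, not existing diffusers code) of how a slimmed-down SD UNet could fail loudly if this branch were dropped, so Stable UnCLIP configs don't load silently without their conditioning:

# Hypothetical guard: reject configs that rely on the removed "timestep"
# class-embedding path instead of silently ignoring them.
SUPPORTED_CLASS_EMBED_TYPES = (None,)

def check_class_embed_type(class_embed_type):
    if class_embed_type not in SUPPORTED_CLASS_EMBED_TYPES:
        raise ValueError(
            f"class_embed_type={class_embed_type!r} is not supported by this UNet; "
            "keep using UNet2DConditionModel (e.g. for Stable UnCLIP) instead."
        )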
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
    # This would be a good case for the `match` statement (Python 3.10+)
    is_mps = sample.device.type == "mps"
    if isinstance(timestep, float):
        dtype = torch.float32 if is_mps else torch.float64
    else:
        dtype = torch.int32 if is_mps else torch.int64
    timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
    timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])

t_emb = self.time_proj(timesteps)

# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)

emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None

if self.class_embedding is not None:
    if class_labels is None:
        raise ValueError("class_labels should be provided when class_embedding_type is not None")

    if self.config.class_embed_type == "timestep":
        class_labels = self.time_proj(class_labels)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # there might be better ways to encapsulate this.
        class_labels = class_labels.to(dtype=sample.dtype)

    class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
    emb = emb + class_emb

if self.config.addition_embed_type == "text_time":
    # SDXL - style
    if "text_embeds" not in added_cond_kwargs:
        raise ValueError(
            f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
        )
    text_embeds = added_cond_kwargs.get("text_embeds")
    if "time_ids" not in added_cond_kwargs:
        raise ValueError(
            f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
        )
    time_ids = added_cond_kwargs.get("time_ids")
    time_embeds = self.add_time_proj(time_ids.flatten())
    time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
    add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
    add_embeds = add_embeds.to(emb.dtype)
    aug_emb = self.add_embedding(add_embeds)

emb = emb + aug_emb if aug_emb is not None else emb
IMO this could be cleaned up a bit.
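One possible cleanup, a sketch that reuses the `emb`/`aug_emb` names from the snippet above and keeps the behaviour of the conditional expression:

# equivalent, but reads as a plain statement instead of a conditional expression
if aug_emb is not None:
    emb = emb + aug_emb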
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if USE_PEFT_BACKEND:
    # weight the lora layers by setting `lora_scale` for each PEFT layer
    scale_lora_layers(self, lora_scale)
Checks will be removed. We can assume that this UNet should only use PEFT.
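A minimal sketch of what that simplification might look like, reusing the names from the snippet above and assuming the PEFT backend is always available:

# assumption: PEFT is the only LoRA backend, so the USE_PEFT_BACKEND check disappears
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# weight the LoRA layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)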
I am not okay with this philosophy. SSD-1B is a direct distillation of SDXL, so family-wise they share 99% of their similarities. I don't think SSD-1B should get its own UNet; it might create a bit too much modularity, which we might not want.
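As a rough illustration of that argument, the asymmetry could stay a config concern inside the shared SDXL UNet class; the keys and values below are illustrative only, not the exact diffusers config:

# Hypothetical SSD-1B-style config: same class as SDXL, but per-block lists describe
# the thinner, asymmetric transformer stacks instead of requiring a separate UNet subclass.
ssd_1b_like_config = {
    "block_out_channels": (320, 640, 1280),
    "transformer_layers_per_block": (1, 2, 4),                          # down path
    "reverse_transformer_layers_per_block": ((4, 4), (2, 1), (1, 1)),   # asymmetric up path
}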
This works for me, but we have to do it in a way that doesn't compromise the linearity of how models are built in diffusers.
100 percent. GLIGEN has lots of different components and deserves its own UNet. For points 4 and 5, we still need to find a way to support them, even if via a separate UNet, right? Perhaps by judging which parameters are less used and clubbing them under a particular config to get a dedicated class?
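A hypothetical sketch of that idea (all names invented for illustration): inspect which rarely-used config options a checkpoint enables and route it to a dedicated class.

# Illustrative only: group rarely-used options and dispatch on them.
RARE_FEATURE_KEYS = ("attention_type", "class_embed_type", "encoder_hid_dim_type")

def pick_unet_class(config: dict) -> str:
    if any(config.get(key) not in (None, "default") for key in RARE_FEATURE_KEYS):
        return "SpecializedUNet"  # e.g. GLIGEN- or UnCLIP-style variants
    return "SDUNet"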
Since SD models support ControlNet, T2I Adapters, and IP Adapters, it might be difficult to fully remove them, no?
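For context, a heavily simplified skeleton (not the real signature, just the parts relevant here) of how those adapter hooks enter the UNet forward pass:

from typing import Dict, Optional, Tuple

import torch
from torch import nn


class AdapterAwareUNetSketch(nn.Module):
    """Illustrative-only skeleton: the arguments below are the coupling points to
    ControlNet, T2I-Adapter, and IP-Adapter that would be hard to strip out."""

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        added_cond_kwargs: Optional[Dict] = None,                                    # IP-Adapter image embeds
        down_block_additional_residuals: Optional[Tuple[torch.Tensor, ...]] = None,  # ControlNet down-block residuals
        mid_block_additional_residual: Optional[torch.Tensor] = None,                # ControlNet mid-block residual
    ) -> torch.Tensor:
        # the real forward threads these residuals into the down/mid/up blocks
        return sample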
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class SDCrossAttnDownBlock2D(nn.Module):
Do we want to keep these blocks in this file? I thought the higher-level UNet implementations would only collate lower-level blocks, with the blocks themselves residing elsewhere.
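A sketch of the layout being suggested, assuming the blocks live in a shared module (the import path below is illustrative):

# Hypothetical layout: the SD-specific UNet file only assembles blocks that are
# defined elsewhere, e.g. in a shared unet_2d_blocks module.
from diffusers.models.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    UpBlock2D,
)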
Generally looks good to me. Some comments:
I don't think the asymmetric UNet should have its own class; it should be supported here IMO. It's a common use case by now, and we have multiple powerful asymmetric checkpoints from Segmind.
Agree - do we have a checkpoint of a model that doesn't have a midblock?
Agree.
Is it just Stable UnCLIP? Did you also check all the reasonably often-used upscaling UNets (StableLatentUpscaling, StableUpscaling, ...)?
The Latent Upscaler uses different block types as well as a different class embedding. The StableDiffusionUpscaler does use it.
What does this PR do?
The UNet2DConditionModel has become quite large and is starting to become difficult to understand and maintain. In order to simplify it a bit, we can try to break it up into smaller, model-specific UNets. This is a draft of the new proposed design.
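A minimal sketch of the kind of split being proposed (class names hypothetical): each model family gets a thin UNet rather than one class covering every config permutation.

from torch import nn


class SDUNet2DConditionModel(nn.Module):
    """Hypothetical: covers SD 1.x / 2.x checkpoints (text cross-attention only)."""


class SDXLUNet2DConditionModel(nn.Module):
    """Hypothetical: adds the SDXL 'text_time' additional embeddings."""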
Some design considerations:
Known issues with this design:
TODOs:
Fixes # (issue)
Before submitting
Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.