
Conversation

@DN6
Collaborator

@DN6 DN6 commented Jan 4, 2024

What does this PR do?

The UNet2DConditionModel has become quite large and is becoming difficult to understand and maintain. To simplify it, we can try to break it up into smaller, model-specific UNets. This is a draft of the new proposed design.

Some design considerations:

  1. This UNet is designed to work with the following model types: SD 1.5, SDXL (including inpainting), and the Stable UnCLIP models.
  2. Parameters that are not relevant to these model types have been removed. We can further reduce the number of parameters of this UNet if we decide that SDXL and Stable UnCLIP should get their own UNets.

Known issues with this design:

  1. Any SD pipeline using an asymmetric UNet, i.e. one that uses the reverse_transformer_layers_per_block parameter. Segmind UNets fall into this category, and I think they should get their own UNet.
  2. UNets that don't use a cross-attention mid block, or that have no mid block at all. I think we can add an option to support multiple mid block types / no mid block (see the sketch after this list).
  3. The GLIGEN UNet, which relies on additional model components. This should be its own UNet.
  4. UNets with dual cross attention (these have almost no downloads).
  5. UNets that use class_embedding_concat. There isn't much usage for this parameter, and class embeddings are mainly used in the Stable UnCLIP UNet, which just adds them to the time embedding.
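
For point 2, here is a minimal sketch (hypothetical class and argument names, with stand-in layers rather than the real diffusers blocks) of how a `mid_block_type` argument that accepts None could keep a single UNet class while covering both cases:

import torch.nn as nn


class ExampleUNet(nn.Module):
    # Hypothetical, illustrative only: shows the dispatch, not the real block wiring.
    def __init__(self, mid_block_type="UNetMidBlock2DCrossAttn", mid_channels=1280):
        super().__init__()
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            # cross-attention mid block (SD 1.5 / SDXL style); stand-in layer here
            self.mid_block = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1)
        elif mid_block_type is None:
            # no mid block: the down path feeds the up path directly
            self.mid_block = None
        else:
            raise ValueError(f"unknown mid_block_type: {mid_block_type}")

    def forward(self, sample):
        if self.mid_block is not None:
            sample = self.mid_block(sample)
        return sample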

TODOs:

  1. I think the time embedding logic can be cleaned up a bit. It's a little confusing.
  2. ControlNet and Adapter logic is still in this UNet. We have to decide how to handle it (whether to remove it or not).
  3. PEFT backend checks are in there but should be removed. Ideally this new design will be released after we deprecate the old lora backend.

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

)
self.encoder_hid_proj = None

if class_embed_type == "timestep":
Collaborator Author


Used in Stable UnCLIP but IMO can be removed.

Comment on lines 1116 to 1175
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
    # This would be a good case for the `match` statement (Python 3.10+)
    is_mps = sample.device.type == "mps"
    if isinstance(timestep, float):
        dtype = torch.float32 if is_mps else torch.float64
    else:
        dtype = torch.int32 if is_mps else torch.int64
    timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
    timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])

t_emb = self.time_proj(timesteps)

# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)

emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None

if self.class_embedding is not None:
    if class_labels is None:
        raise ValueError("class_labels should be provided when class_embedding_type is not None")

    if self.config.class_embed_type == "timestep":
        class_labels = self.time_proj(class_labels)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # there might be better ways to encapsulate this.
        class_labels = class_labels.to(dtype=sample.dtype)

    class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
    emb = emb + class_emb

if self.config.addition_embed_type == "text_time":
    # SDXL - style
    if "text_embeds" not in added_cond_kwargs:
        raise ValueError(
            f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
        )
    text_embeds = added_cond_kwargs.get("text_embeds")
    if "time_ids" not in added_cond_kwargs:
        raise ValueError(
            f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
        )
    time_ids = added_cond_kwargs.get("time_ids")
    time_embeds = self.add_time_proj(time_ids.flatten())
    time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
    add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
    add_embeds = add_embeds.to(emb.dtype)
    aug_emb = self.add_embedding(add_embeds)

emb = emb + aug_emb if aug_emb is not None else emb
Collaborator Author


IMO this could be cleaned up a bit.
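
One possible direction (a sketch only; `get_time_embed` / `get_class_embed` are hypothetical helper names, not part of this diff) would be to pull each embedding branch into its own small method so that `forward` just composes them:

def get_time_embed(self, sample, timestep, timestep_cond):
    # Same logic as the inline block above, just isolated.
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    timesteps = timesteps.expand(sample.shape[0])
    t_emb = self.time_proj(timesteps).to(dtype=sample.dtype)
    return self.time_embedding(t_emb, timestep_cond)

def get_class_embed(self, sample, class_labels):
    if self.class_embedding is None:
        return None
    if class_labels is None:
        raise ValueError("class_labels should be provided when class_embedding_type is not None")
    if self.config.class_embed_type == "timestep":
        class_labels = self.time_proj(class_labels).to(dtype=sample.dtype)
    return self.class_embedding(class_labels).to(dtype=sample.dtype)

# forward would then reduce to something like:
#     emb = self.get_time_embed(sample, timestep, timestep_cond)
#     class_emb = self.get_class_embed(sample, class_labels)
#     if class_emb is not None:
#         emb = emb + class_emb
#     aug_emb = self.get_aug_embed(emb, added_cond_kwargs)  # the SDXL "text_time" branch
#     if aug_emb is not None:
#         emb = emb + aug_emb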

Comment on lines 1181 to 1184
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if USE_PEFT_BACKEND:
    # weight the lora layers by setting `lora_scale` for each PEFT layer
    scale_lora_layers(self, lora_scale)
Collaborator Author


These checks will be removed. We can assume that this UNet only uses PEFT.
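
For reference, a sketch of what this block would collapse to once the old lora backend is deprecated and the `USE_PEFT_BACKEND` guard can be dropped:

lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)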

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul
Member

Any SD pipeline using an asymmetric UNet, i.e. one that uses the reverse_transformer_layers_per_block parameter. Segmind UNets fall into this category, and I think they should get their own UNet.

I am not okay with this philosophy. SSD-1B is a direct distillation of SDXL, so family-wise they share 99% of the similarities. I don't think SSD-1B should get its own UNet. It might create a bit too much modularity, which we might not want.

UNets that don't use a cross-attention mid block, or that have no mid block at all. I think we can add an option to support multiple mid block types / no mid block.

This works for me, but we have to do it in a way that doesn't compromise the linearity of how models are built in diffusers.

The GLIGEN UNet, which relies on additional model components. This should be its own UNet.

100 percent. GLIGEN has lots of different components and deserves its own UNet.

For points 4 and 5, we still need to find a way to support them, even if via separate UNets, right? Perhaps we could identify the less-used parameters and club them under a particular config that gets a dedicated class?

ControlNet and Adapter logic is still in this UNet. We have to decide how to handle it (whether to remove it or not).

Since SD models support ControlNet, T2I Adapters, and IP-Adapters, it might be difficult to fully remove them, no?
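
For context, the ControlNet hooks in `forward` amount to a couple of residual additions; a simplified sketch (condensed from the current UNet2DConditionModel, not a verbatim copy of this diff):

if down_block_additional_residuals is not None:
    # ControlNet: add one residual per down block output
    down_block_res_samples = tuple(
        res + ctrl_res
        for res, ctrl_res in zip(down_block_res_samples, down_block_additional_residuals)
    )

if mid_block_additional_residual is not None:
    # ControlNet: add the mid block residual after the mid block runs
    sample = sample + mid_block_additional_residual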

logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class SDCrossAttnDownBlock2D(nn.Module):
Member


Do we want to keep these blocks in this file? I thought the higher-level UNet implementations would only collate lower-level blocks while they reside elsewhere.
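
For example, the new UNet file could just import the existing blocks rather than redefining SD-specific copies locally (illustrative; assumes the current module layout):

# Collate existing lower-level blocks instead of redefining them in this file.
from diffusers.models.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    UpBlock2D,
)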

@github-actions
Contributor

github-actions bot commented Feb 3, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Feb 3, 2024
@DN6 DN6 removed the stale Issues that haven't received updates label Feb 5, 2024
@DN6 DN6 marked this pull request as ready for review February 5, 2024 14:27
@patrickvonplaten
Contributor

patrickvonplaten commented Feb 9, 2024

Generally looks good to me. Some comments:

Any SD pipeline using an asymmetric UNet, i.e. one that uses the reverse_transformer_layers_per_block parameter. Segmind UNets fall into this category, and I think they should get their own UNet.

I don't think the asymmetric UNet should have its own class; it should be supported here IMO. It's a common use case by now, and we have multiple powerful checkpoints by Segmind that are asymmetric.

  1. UNets that don't use a cross-attention mid block, or that have no mid block at all. I think we can add an option to support multiple mid block types / no mid block.

Agree - do we have a checkpoint of a model that doesn't have a midblock?

  1. The GLIGEN UNet, which relies on additional model components. This should be its own UNet.
  2. UNets with dual cross attention (these have almost no downloads).

Agree.

  1. UNets that use class_embedding_concat. There isn't much usage for this parameter, and class embeddings are mainly used in the Stable UnCLIP UNet, which just adds them to the time embedding.

Is it just Stable UnCLIP? Did you also check all of the reasonably often-used upscaling UNets (StableLatentUpscaling, StableUpscaling, ...)?

@DN6
Collaborator Author

DN6 commented Feb 15, 2024

Is it just Stable UnCLIP? Did you also check all of the reasonably often-used upscaling UNets (StableLatentUpscaling, StableUpscaling, ...)?

The Latent Upscaler uses different block types as well as a fourier time_embedding_type, and uses arguments like resnet_time_scale_shift and timestep_post_act_fn. So maybe it should be its own UNet?

The StableDiffusionUpscaler does use num_class_embeds as an argument, so I will add that back to this UNet design.
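
Concretely, that just means restoring the `nn.Embedding` branch in `__init__` (a simplified sketch of the existing UNet2DConditionModel behaviour):

if num_class_embeds is not None:
    # plain learned class embeddings, projected into the time embedding dimension
    self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
else:
    self.class_embedding = None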

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Mar 10, 2024
@yiyixuxu yiyixuxu added wip and removed stale Issues that haven't received updates wip labels Mar 11, 2024
@github-actions
Contributor

github-actions bot commented Apr 5, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Apr 5, 2024