Conversation

@linoytsaban
Collaborator

Add text encoder training support for the CLIP encoders to the DreamBooth LoRA training script for SD3.


@linoytsaban linoytsaban marked this pull request as ready for review June 19, 2024 00:47
@r-aristov

When no individual prompts are provided, only the instance prompt, I got this:

Traceback (most recent call last):
  File "/root/train_dreambooth_lora_sd3.py", line 1836, in <module>
    main(args)
  File "/root/train_dreambooth_lora_sd3.py", line 1631, in main
    encoder_hidden_states=prompt_embeds,
UnboundLocalError: local variable 'prompt_embeds' referenced before assignment
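The error is the classic Python pattern of a variable assigned in only one branch and read unconditionally afterwards. A minimal sketch of the failure mode and its fix (hypothetical function names, not the actual script; string upper-casing stands in for CLIP encoding):

```python
def encode_prompt_broken(custom_prompts, instance_prompt):
    if custom_prompts:
        prompt_embeds = [p.upper() for p in custom_prompts]
    # BUG: when only an instance prompt is given, prompt_embeds is never
    # assigned, so reading it raises UnboundLocalError.
    return prompt_embeds


def encode_prompt_fixed(custom_prompts, instance_prompt):
    # Fix: every code path assigns prompt_embeds before it is read.
    if custom_prompts:
        prompt_embeds = [p.upper() for p in custom_prompts]
    else:
        prompt_embeds = [instance_prompt.upper()]
    return prompt_embeds
```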

@linoytsaban
Collaborator Author

thanks @r-aristov! I think it should be working now

@yiyixuxu yiyixuxu requested a review from sayakpaul June 19, 2024 22:38
Comment on lines +1637 to +1641
if text_encoder_lora_layers:
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))

if text_encoder_2_lora_layers:
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
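If pack_weights behaves as its usage here suggests (an assumption; the real diffusers helper may differ in detail), it namespaces each module's LoRA keys under a prefix so the transformer and both text encoders can share a single saved state dict. A toy sketch:

```python
def pack_weights(layers_state_dict, prefix):
    # Prefix every LoRA parameter key so that transformer, text_encoder,
    # and text_encoder_2 weights can coexist in one saved file.
    return {f"{prefix}.{key}": value for key, value in layers_state_dict.items()}


state_dict = {}
# Hypothetical LoRA state dicts with placeholder values.
text_encoder_lora_layers = {"lora_A.weight": [0.1], "lora_B.weight": [0.2]}
text_encoder_2_lora_layers = {"lora_A.weight": [0.3]}

if text_encoder_lora_layers:
    state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
if text_encoder_2_lora_layers:
    state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
```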
Member

Can we confirm via experiments that text encoder 3 training doesn't matter much? This can be done separately and won't block this PR.

Collaborator


Actually, we skipped it not because we think it doesn't matter, but because training the T5 is a different animal from the already well-known CLIP text encoder training (also on the VRAM consumption side). So we indeed left T5 training for a future PR to investigate!

Member


Slightly worried about the dynamics of this, so let's make sure we run ample experiments to see whether training two text encoders while keeping the third one fixed works as expected.
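The "two trained, one frozen" setup discussed here can be sketched with toy torch modules (hypothetical shapes, not the actual SD3 encoders):

```python
import torch.nn as nn

# Toy stand-ins for the three SD3 text encoders (hypothetical 8-dim layers).
clip_one = nn.Linear(8, 8)
clip_two = nn.Linear(8, 8)
t5 = nn.Linear(8, 8)

# Mirror the PR's setup: train both CLIP encoders, keep T5 frozen.
t5.requires_grad_(False)

# Only parameters with requires_grad=True are handed to the optimizer,
# so T5 contributes nothing to the backward pass or VRAM for gradients.
trainable = [
    p
    for module in (clip_one, clip_two, t5)
    for p in module.parameters()
    if p.requires_grad
]
```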

@sayakpaul sayakpaul left a comment (Member)


Thanks, I have left some comments. My main question is how much training the text encoders matters here, given that SD3 uses three of them. Could we see some concrete comparative examples?

Additionally, we need to add tests to https://github.com/huggingface/diffusers/blob/main/tests/lora/test_lora_layers_sd3.py and add a note about --train_text_encoder to the README.

@linoytsaban
Collaborator Author

Some results for comparison (image grid of 20 samples omitted).

Training config:

!accelerate launch train_dreambooth_lora_sd3.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-3-medium-diffusers" \
  --dataset_name="Norod78/Yarn-art-style" \
  --output_dir="dreambooth-sd3-lora" \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of TOK yarn art dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --train_text_encoder \
  --gradient_accumulation_steps=1 \
  --optimizer="prodigy" \
  --learning_rate=1.0 \
  --text_encoder_lr=1.0 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1500 \
  --repeats=1 \
  --rank=32 \
  --weighting_scheme="logit_normal" \
  --validation_epochs=100 \
  --seed="0" \
  --push_to_hub

The --train_text_encoder and --text_encoder_lr flags are the new options introduced by this PR.

@sayakpaul
Member

sayakpaul commented Jun 24, 2024

Cool, the results are stunning. So, the TODOs are:

@sayakpaul sayakpaul left a comment (Member)


Excellent work!

@sayakpaul sayakpaul merged commit c6e08ec into huggingface:main Jun 25, 2024
@linoytsaban linoytsaban deleted the sd3-dreambooth-lora branch July 1, 2024 09:42
sayakpaul added a commit that referenced this pull request Dec 23, 2024
…#8630)

* add clip text-encoder training

* no dora

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* add text_encoder layers to save_lora

* style

* fix imports

* style

* fix text encoder

* review changes

* review changes

* review changes

* minor change

* add lora tag

* style

* add readme notes

* add tests for clip encoders

* style

* typo

* fixes

* style

* Update tests/lora/test_lora_layers_sd3.py

Co-authored-by: Sayak Paul <[email protected]>

* Update examples/dreambooth/README_sd3.md

Co-authored-by: Sayak Paul <[email protected]>

* minor readme change

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

6 participants