support HunYuan DiT #1378

Draft · wants to merge 21 commits into base: dev

Conversation

@KohakuBlueleaf (Contributor) commented Jun 21, 2024

[WIP] This is a draft PR for contributors to check the progress and review the code.

This PR starts with my simple implementation of minimal inference plus some modifications:

  1. Modify the initialization method of HunYuanDiT to avoid the argparse requirement.
  2. Replace flash_attn with PyTorch SDP and xformers implementations.
  3. Implement gradient checkpointing to save tons of VRAM.
  4. Support the "CLIP concat" trick for long prompts (a sketch follows this list).
    • Needs review by the HunYuan team. It should behave the same as the original with max_length_clip=77.
  5. Add a test script for a quick inference check.
    • I didn't follow the xxx_minimal_inference naming style, so I called it hunyuan_test.py, but it can be seen as a minimal inference script.
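
As a reference for item 4, here is a minimal sketch of the "CLIP concat" idea, assuming a Hugging Face tokenizer and a BertModel-style CLIP text encoder already on the target device; the function name encode_clip_long is illustrative, not the actual name used in this PR. The prompt is split into chunks of at most max_length_clip tokens, each chunk is encoded separately, and the hidden states are concatenated along the sequence axis, so a short prompt with max_length_clip=77 reduces to a single chunk and should match the original single-pass encoding.

    import torch
    from transformers import AutoTokenizer, BertModel

    def encode_clip_long(prompt, tokenizer, encoder, max_length_clip=77, device="cuda"):
        # Tokenize without truncation so long prompts keep all of their tokens.
        ids = tokenizer(prompt, truncation=False)["input_ids"]
        # Split into chunks of at most max_length_clip tokens.
        chunks = [ids[i:i + max_length_clip] for i in range(0, len(ids), max_length_clip)]
        hidden_states, masks = [], []
        for chunk in chunks:
            pad = max_length_clip - len(chunk)
            input_ids = torch.tensor([chunk + [tokenizer.pad_token_id] * pad], device=device)
            mask = torch.tensor([[1] * len(chunk) + [0] * pad], device=device)
            out = encoder(input_ids=input_ids, attention_mask=mask)
            hidden_states.append(out.last_hidden_state)
            masks.append(mask)
        # Concatenate along the sequence axis; a single chunk equals the original behavior.
        return torch.cat(hidden_states, dim=1), torch.cat(masks, dim=1)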

Notes about loading the model

The directory structure I used is:

model/
  clip/
  denoiser/
  mt5/
  vae/

Basically, download the files from the t2i folder of HunYuanDiT, put the contents of clip_text_encoder and tokenizer into clip, put mt5 into mt5, put model into denoiser, and put sdxl-vae-fp16-fix into vae.

This spec can be changed if needed.
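
For illustration, loading from this layout might look like the minimal sketch below, assuming the DiT_g_2 and MT5Embedder classes from this PR's library/hunyuan_models.py and the file names from the original t2i folder; the root path and dtypes are up to you.

    import os
    import torch
    from transformers import AutoTokenizer, BertModel
    from diffusers import AutoencoderKL
    from library.hunyuan_models import DiT_g_2, MT5Embedder  # classes from this PR

    root = "model"  # the directory layout described above

    # CLIP tokenizer and text encoder
    clip_tokenizer = AutoTokenizer.from_pretrained(os.path.join(root, "clip"))
    clip_encoder = BertModel.from_pretrained(os.path.join(root, "clip")).half().cuda()

    # mT5 text encoder
    mt5_embedder = MT5Embedder(os.path.join(root, "mt5"), torch_dtype=torch.float16, max_length=256)

    # VAE (sdxl-vae-fp16-fix)
    vae = AutoencoderKL.from_pretrained(os.path.join(root, "vae")).half().cuda()

    # Denoiser (the DiT itself)
    denoiser, patch_size, head_dim = DiT_g_2(input_size=(128, 128))
    denoiser.load_state_dict(torch.load(os.path.join(root, "denoiser/pytorch_model_module.pt")))
    denoiser = denoiser.half().cuda()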

TODO List

  • Examine the current implementation of the modified parts
  • Bundle format support (if possible)
  • Training utils (if needed)
    • *Tokenizer/TE-related things for Dataset
  • Training script (modified from sdxl_train.py)
  • *LoRA/LyCORIS training script (modified from sdxl_train_network.py)
    • Initial support
    • Unique training arg region
    • Implementation check
  • *LoRA module support
    • kohya LoRA
    • LyCORIS

Low Priority TODO List

  • cache TE embeddings.

Notification to contributors

  • You can assume that the create_network method from the imported network module will work correctly.
    • Kohya and I will ensure that.
  • Check sdxl_train.py, sdxl_train_network.py, and the dataset-related code carefully before starting development. It is very likely that we only need a few modifications to make things work. Try to avoid any "full rework".
  • If you want to contribute to this PR, open another PR into this branch: https://github.com/KohakuBlueleaf/sd-scripts/tree/HunYuanDiT
    • I will check all related PRs/issues frequently this week.

@KohakuBlueleaf (Contributor Author)

For anyone who wants to try HunYuan but doesn't want to download the original 44GB files:
https://huggingface.co/KBlueLeaf/HunYuanDiT-V1.1-fp16-pruned

It is only 8GB.
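
If you prefer to fetch it programmatically, one option (assuming huggingface_hub is installed; the local_dir name is arbitrary) is:

    from huggingface_hub import snapshot_download

    # Download the ~8GB pruned HunYuanDiT V1.1 repo into ./model
    snapshot_download("KBlueLeaf/HunYuanDiT-V1.1-fp16-pruned", local_dir="model")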

@KohakuBlueleaf (Contributor Author)

After commit fb3e8a7, LyCORIS/LoRA training is usable (some implementation details still need checking, but it works).
[image]

Some functionality is not usable at this moment and will be fixed in the future.

@KohakuBlueleaf (Contributor Author)

Requirements for LoRA/LyCORIS training on HunYuan:

  • lycoris-lora>=3.0.0.dev10
  • 9GB VRAM to train the U-Net only (i.e., the DiT only)
  • 12GB VRAM to train the U-Net (DiT) plus the text encoders
    • The two figures above assume cached latents, gradient checkpointing, and batch size 1
    • TE embedding caching is not enabled
    • full bf16/fp16

@KohakuBlueleaf (Contributor Author)

Currently, fp16 (mixed precision) causes NaN loss.
I will check which part goes wrong.

@KohakuBlueleaf (Contributor Author)

FP16 fixed

@KohakuBlueleaf (Contributor Author)

I used the umamusume dataset (with Danbooru tag prompt format) to train HunYuan DiT V1.1 with batch size 8 for 600 steps (DiT only).
Looks like my implementation is OK.

Original vs. with LoKr (12MB, bs8, 600 steps):
[images: test-no-lokr, test-lokr]

@KohakuBlueleaf (Contributor Author) commented Jun 24, 2024

For those who want to try HunYuan training, here is an example script:
[image of example training script]

@sdbds (Contributor) commented Jun 26, 2024

Left: original dataset. Right: test result.
[images]

xljh0520 and others added 5 commits June 28, 2024 22:58

  • support hunyuan lora in CLIP text encoder and DiT blocks
  • add hunyuan lora test script
  • append lora blocks in target module
  • Support HunYuanDiT v1.1 and v1.2 lora
  • add use_extra_cond for hy_train_network
  • change model version
  • Update hunyuan_train_network.py
  • Update hunyuan_train.py

Co-authored-by: leoriojhli <[email protected]>
@tristanwqy

If HunyuanDiT also uses the VAE from SDXL, does that mean the prepare-bucket-latents step can reuse the data from SDXL?

@KohakuBlueleaf (Contributor Author)

If HunyuanDiT also uses the VAE from SDXL, does that mean the prepare-bucket-latents step can reuse the data from SDXL?

Yes, and kohya's latent caching only checks the size, so you can use a dataset folder that already has cached latents.

@tristanwqy

If HunyuanDiT also uses the VAE from SDXL, does that mean the prepare-bucket-latents step can reuse the data from SDXL?

Yes, and kohya's latent caching only checks the size, so you can use a dataset folder that already has cached latents.

OK, thanks.

@tristanwqy

I was trying to run the code below in library/hunyuan_models

    root = "/workspace/models/hunyuan/HunYuanDiT-V1.2-fp16-pruned/"
    denoiser, patch_size, head_dim = DiT_g_2(input_size=(128, 128))
    sd = torch.load(os.path.join(root, "denoiser/pytorch_model_module.pt"))
    denoiser.load_state_dict(sd)
    denoiser.half().cuda()
    denoiser.enable_gradient_checkpointing()

    clip_tokenizer = AutoTokenizer.from_pretrained(os.path.join(root, "clip"))
    clip_encoder = BertModel.from_pretrained(os.path.join(root, "clip")).half().cuda()

    mt5_embedder = MT5Embedder(os.path.join(root, "mt5"), torch_dtype=torch.float16, max_length=256)

    vae = AutoencoderKL.from_pretrained(os.path.join(root, "vae")).half().cuda()

    print(sum(p.numel() for p in denoiser.parameters()) / 1e6)
    print(sum(p.numel() for p in mt5_embedder.parameters()) / 1e6)
    print(sum(p.numel() for p in clip_encoder.parameters()) / 1e6)
    print(sum(p.numel() for p in vae.parameters()) / 1e6)

but failed with

    Use xformers attention implementation.
    Number of tokens: 4096
Traceback (most recent call last):
  File "/home/ubuntu/sd-scripts/library/hunyuan_models.py", line 1287, in <module>
    denoiser.load_state_dict(sd)
  File "/home/ubuntu/miniconda3/envs/training/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for HunYuanDiT:
	Unexpected key(s) in state_dict: "style_embedder.weight". 
	size mismatch for extra_embedder.0.weight: copying a param with shape torch.Size([5632, 3968]) from checkpoint, the shape in current model is torch.Size([5632, 1024]).

@sdbds (Contributor) commented Jul 4, 2024

V1.2 deleted style_embedder.weight; wait for them to fix it.

@KohakuBlueleaf (Contributor Author)

I think the problem is that we need the --extra_cond arg for v1.0/v1.1 to enable the extra conditions.

I'm not sure whether you have implemented this in the train network script yet, but full training should work with that arg.
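
As a rough illustration of the version difference (a hypothetical heuristic, not code from this PR): v1.2 checkpoints dropped style_embedder.weight, so its presence in the state dict is one way to tell whether the extra conditions need to be enabled.

    import torch

    # Hypothetical heuristic: v1.0/v1.1 checkpoints still contain the style embedder
    # weights that v1.2 removed, so inspect the state dict before deciding whether
    # to build the DiT with extra conditions enabled.
    sd = torch.load("model/denoiser/pytorch_model_module.pt", map_location="cpu")
    use_extra_cond = "style_embedder.weight" in sd
    print("extra cond needed (v1.0/v1.1 checkpoint)" if use_extra_cond
          else "extra cond should be disabled (v1.2 checkpoint)")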

@tristanwqy commented Jul 5, 2024

v1.2 with extra cond enabled works, but the generated images look a little weird, which doesn't make sense.
On the other hand, v1.1 works perfectly fine.

@KohakuBlueleaf (Contributor Author)

You should disable extra cond for v1.2
