Commit
bmaltais committed Jul 10, 2023
2 parents b762ed2 + f54b784 commit 1ba606c
Showing 5 changed files with 809 additions and 468 deletions.
30 changes: 26 additions & 4 deletions README.md
@@ -49,18 +49,31 @@ SDXL training is now available in the sdxl branch as an experimental feature.
Summary of the feature:

- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth datasets.
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
  - This option enables full bfloat16 training (gradients included), which is useful for reducing GPU memory usage.
  - However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer.
  - I could not find a bitsandbytes version newer than 0.35.0 that works correctly on Windows.
  - In addition, full bfloat16 training might be unstable. Please use it at your own risk.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts have the following additional options:
  - `--cache_text_encoder_outputs`: Cache the outputs of the text encoders. This is useful for reducing GPU memory usage. It cannot be combined with the caption shuffling or dropout options.
  - `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. The SDXL VAE seems to produce NaNs in some cases; this option helps avoid them.
- Image generation during training is now available. However, the SDXL VAE seems to produce NaNs in some cases when using `fp16`, resulting in black images. Currently these NaNs cannot be avoided even with the `--no_half_vae` option; generation works with `bf16` or without mixed precision.

- The `--weighted_captions` option is not yet supported by either script.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options restrict the range of timesteps used to train the U-Net. The default values are 0 and 1000 (see the sketch after this list).

- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
  - `--cache_text_encoder_outputs` is not supported.
  - `token_string` must currently consist of alphabetic characters only, due to a limitation of the open-clip tokenizer.
  - There are two options for captions:
    1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
    2. Use the `--use_object_template` or `--use_style_template` option. Captions are generated from the template, and any existing captions are ignored.
  - See below for the format of the embeddings.

- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for usage.
  - Textual Inversion is supported, but the embedding name used in the caption is reduced to alphabetic characters only. For example, `neg_hand_v1.safetensors` is activated with `neghandv`.
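
As a rough sketch of how that activation word is derived (mirroring the sanitization in the `sdxl_gen_img.py` diff further below; the file name here is just an example):

```python
import os

# strip the embedding file name down to alphabetic characters to get the activation word
embeds_file = "neg_hand_v1.safetensors"  # example file name
token_string = os.path.splitext(os.path.basename(embeds_file))[0]  # "neg_hand_v1"
token_string = "".join(c for c in token_string if c.isalpha())  # "neghandv"
print(token_string)  # neghandv
```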

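A minimal sketch of what `--min_timestep` and `--max_timestep` control, as referenced in the list above (the sampling call is an assumption based on typical diffusion training loops, not a verbatim excerpt from the scripts):

```python
import torch

# restrict the sampled training timesteps to [min_timestep, max_timestep)
min_timestep, max_timestep = 0, 1000  # the documented defaults
batch_size = 4
timesteps = torch.randint(min_timestep, max_timestep, (batch_size,), dtype=torch.long)
print(timesteps)
```
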
`requirements.txt` is updated to support SDXL training.

@@ -78,21 +91,30 @@ Summary of the feature:
- `--bucket_reso_steps` can be set to 32 instead of the default value of 64. Values smaller than 32 will not work for SDXL training.

Example of the optimizer settings for Adafactor with a fixed learning rate:
```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
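
With `relative_step=False`, Adafactor does not derive its own step size, so the `learning_rate` above is used as a true fixed learning rate (after the 100 warmup steps of `constant_with_warmup`).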

### Format of Textual Inversion embeddings

```python
from safetensors.torch import save_file

state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
save_file(state_dict, file)
```
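
For reference, a saved embedding can be read back the same way (a minimal sketch; the file name is hypothetical):

```python
from safetensors.torch import load_file

data = load_file("my_embedding.safetensors")  # hypothetical file name
embs_for_text_encoder_768 = data["clip_l"]   # shape (num_vectors, 768), text encoder 1
embs_for_text_encoder_1280 = data["clip_g"]  # shape (num_vectors, 1280), text encoder 2
```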

### TODO

- [ ] Support conversion of Diffusers SDXL models.
- [ ] Support `--weighted_captions` option.
- [ ] Change `--output_config` option to continue the training.
- [ ] Extend `--full_bf16` for all the scripts.
- [x] Support Textual Inversion training.

## About requirements.txt

41 changes: 39 additions & 2 deletions library/sdxl_train_util.py
@@ -78,12 +78,13 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype

class WrapperTokenizer:
# open clipのtokenizerをHuggingFaceのtokenizerと同じ形で使えるようにする
# make open clip tokenizer compatible with HuggingFace tokenizer
def __init__(self):
open_clip_tokenizer = open_clip.tokenizer._tokenizer
self.model_max_length = 77
self.bos_token_id = open_clip_tokenizer.all_special_ids[0]
self.eos_token_id = open_clip_tokenizer.all_special_ids[1]
self.pad_token_id = 0  # 結果から推定している / inferred from the tokenizer's output

def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.tokenize(*args, **kwds)
@@ -107,6 +108,42 @@ def tokenize(self, text, padding=False, truncation=None, max_length=None, return
input_ids = input_ids[: eos_index + 1] # include eos
return SimpleNamespace(**{"input_ids": input_ids})

# for Textual Inversion
# わりと面倒くさいな……これWeb UIとかでどうするんだろう / this is a bit annoying... how do Web UIs handle this?

def encode(self, text, add_special_tokens=False):
assert not add_special_tokens
input_ids = open_clip.tokenizer._tokenizer.encode(text)
return input_ids

def add_tokens(self, new_tokens):
tokens_to_add = []
for token in new_tokens:
token = token.lower()
if token + "</w>" not in open_clip.tokenizer._tokenizer.encoder:
tokens_to_add.append(token)

# open clipのtokenizerに直接追加する / add tokens to open clip tokenizer
for token in tokens_to_add:
open_clip.tokenizer._tokenizer.encoder[token + "</w>"] = len(open_clip.tokenizer._tokenizer.encoder)
open_clip.tokenizer._tokenizer.decoder[len(open_clip.tokenizer._tokenizer.decoder)] = token + "</w>"
open_clip.tokenizer._tokenizer.vocab_size += 1

# open clipのtokenizerのcacheに直接設定することで、bpeとかいうやつに含まれていなくてもtokenizeできるようにする
# めちゃくちゃ乱暴なので、open clipのtokenizerの仕様が変わったら動かなくなる
# set cache of open clip tokenizer directly to enable tokenization even if the token is not included in bpe
# this is very rough, so it will not work if the specification of open clip tokenizer changes
open_clip.tokenizer._tokenizer.cache[token] = token + "</w>"

return len(tokens_to_add)

def convert_tokens_to_ids(self, tokens):
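# open clip stores word-final tokens with a "</w>" suffix, so it is appended for the vocabulary lookup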
input_ids = [open_clip.tokenizer._tokenizer.encoder[token + "</w>"] for token in tokens]
return input_ids

def __len__(self):
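# report the live vocab size; callers use len(tokenizer) - 1 to verify that newly added token ids sit at the end of the vocabulary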
return open_clip.tokenizer._tokenizer.vocab_size


def load_tokenizers(args: argparse.Namespace):
print("prepare tokenizers")
@@ -392,7 +429,7 @@ def verify_sdxl_training_args(args: argparse.Namespace):
print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")

assert (
not hasattr(args, "weighted_captions") or not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"


33 changes: 22 additions & 11 deletions sdxl_gen_img.py
@@ -320,7 +320,7 @@ def __init__(
self.scheduler = scheduler
self.safety_checker = None

# Textual Inversion
self.token_replacements_list = []
for _ in range(len(self.text_encoders)):
self.token_replacements_list.append({})
@@ -341,6 +341,10 @@ def get_token_replacer(self, tokenizer):
token_replacements = self.token_replacements_list[tokenizer_index]

def replace_tokens(tokens):
# print("replace_tokens", tokens, "=>", token_replacements)
if isinstance(tokens, torch.Tensor):
tokens = tokens.tolist()

new_tokens = []
for token in tokens:
if token in token_replacements:
@@ -1594,19 +1598,26 @@ def __getattr__(self, item):

if "string_to_param" in data:
data = data["string_to_param"]
embeds1 = data["clip_l"]
embeds2 = data["clip_g"]

embeds1 = data["clip_l"] # text encoder 1
embeds2 = data["clip_g"] # text encoder 2

num_vectors_per_token = embeds1.size()[0]
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]

# remove non-alphabet characters to avoid splitting by tokenizer
# TODO make random alphabet string
token_string = "".join([c for c in token_string if c.isalpha()])

token_strings = [token_string] + [f"{token_string}{chr(ord('a') + i)}" for i in range(num_vectors_per_token - 1)]

# add new word to tokenizer, count is num_vectors_per_token
num_added_tokens1 = tokenizer1.add_tokens(token_strings)
num_added_tokens2 = tokenizer2.add_tokens(token_strings)
assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, (
f"tokenizer already contains a token for the token string (derived from the file name); note that non-alphabetic characters are removed: {embeds_file}"
+ f" / 指定した名前(ファイル名)のトークンが既に存在します。アルファベット以外の文字は削除されます: {embeds_file}"
)

token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
@@ -1617,11 +1628,11 @@ def __getattr__(self, item):
assert (
min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1
), f"token ids2 is not ordered"
assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}"
assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}"

if num_vectors_per_token > 1:
pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ...
pipe.add_token_replacement(1, token_ids2[0], token_ids2)

token_ids_embeds1.append((token_ids1, embeds1))
142 changes: 142 additions & 0 deletions sdxl_train_textual_inversion.py
@@ -0,0 +1,142 @@
import argparse
import os

import regex
import torch
import open_clip
from library import sdxl_model_util, sdxl_train_util, train_util

import train_textual_inversion


class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
def __init__(self):
super().__init__()
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR

def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)

def load_target_model(self, args, weight_dtype, accelerator):
(
load_stable_diffusion_format,
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)

self.load_stable_diffusion_format = load_stable_diffusion_format
self.logit_scale = logit_scale
self.ckpt_info = ckpt_info

return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet

def load_tokenizer(self, args):
tokenizer = sdxl_train_util.load_tokenizers(args)
return tokenizer

def assert_token_string(self, token_string, tokenizers):
# tokenizer 1 seems to be OK

# count words for token string: regular expression from open_clip
pat = regex.compile(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", regex.IGNORECASE)
words = regex.findall(pat, token_string)
word_count = len(words)
assert word_count == 1, (
f"token string {token_string} contain {word_count} words, please don't use digits, punctuation, or special characters"
+ f" / トークン文字列 {token_string} には{word_count}個の単語が含まれています。数字、句読点、特殊文字は使用しないでください"
)

def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
with torch.enable_grad():
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = sdxl_train_util.get_hidden_states(
args,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
)
return encoder_hidden_states1, encoder_hidden_states2, pool2

def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

# get size embeddings
orig_size = batch["original_sizes_hw"]
crop_size = batch["crop_top_lefts"]
target_size = batch["target_sizes_hw"]
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)

# concat embeddings
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
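# the two hidden states are concatenated along the feature dimension (768 + 1280 = 2048), the context dim expected by SDXL's U-Net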

noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred

def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
sdxl_train_util.sample_images(
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
)

def save_weights(self, file, updated_embs, save_dtype):
state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]}

if save_dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v

if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file

save_file(state_dict, file)
else:
torch.save(state_dict, file)

def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file

data = load_file(file)
else:
data = torch.load(file, map_location="cpu")

emb_l = data.get("clib_l", None) # ViT-L text encoder 1
emb_g = data.get("clib_g", None) # BiG-G text encoder 2

assert (
emb_l is not None or emb_g is not None
), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}"

return [emb_l, emb_g]


def setup_parser() -> argparse.ArgumentParser:
parser = train_textual_inversion.setup_parser()
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
# sdxl_train_util.add_sdxl_training_arguments(parser)
return parser


if __name__ == "__main__":
parser = setup_parser()

args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)

trainer = SdxlTextualInversionTrainer()
trainer.train(args)