Skip to content

Commit

Permalink
Fix loading of SDXL state dicts that do not contain logit_scale
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Jul 5, 2023
1 parent 3060eb5 commit 3d0375d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions library/sdxl_model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def convert_key(key):
new_sd["text_model.embeddings.position_ids"] = position_ids

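# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す (logit_scale is not part of the Diffusers format, but it is returned separately so it can be restored when saving)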
logit_scale = checkpoint[SDXL_KEY_PREFIX + "logit_scale"]
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)

return new_sd, logit_scale

Expand Down Expand Up @@ -222,7 +222,7 @@ def convert_key(key):
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
elif ".token_embedding" in key:
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
elif "text_projection" in key: # no dot in key
elif "text_projection" in key: # no dot in key
key = key.replace("text_projection.weight", "text_projection")
elif "final_layer_norm" in key:
key = key.replace("final_layer_norm", "ln_final")
Expand Down Expand Up @@ -253,7 +253,8 @@ def convert_key(key):
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
new_sd[new_key] = value

new_sd["logit_scale"] = logit_scale
if logit_scale is not None:
new_sd["logit_scale"] = logit_scale

return new_sd

Expand Down
2 changes: 1 addition & 1 deletion sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def train(args):
logit_scale,
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
# logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)

# verify load/save model formats
if load_stable_diffusion_format:
Expand Down

0 comments on commit 3d0375d

Please sign in to comment.