Skip to content

Commit

Permalink
Merge pull request #628 from KohakuBlueleaf/full_bf16
Browse files Browse the repository at this point in the history
Full bf16 support
  • Loading branch information
kohya-ss authored Jul 9, 2023
2 parents 256ff5b + d974959 commit 3579b45
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
1 change: 1 addition & 0 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2422,6 +2422,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
)
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
parser.add_argument(
"--clip_skip",
type=int,
Expand Down
8 changes: 8 additions & 0 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
unet.to(weight_dtype)
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
unet.to(weight_dtype)
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)

# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
Expand Down

0 comments on commit 3579b45

Please sign in to comment.