Add an option to use the image's alpha channel as a loss mask (kohya-ss#1223)
* Add alpha_mask parameter and apply masked loss

* Fix type hint in trim_and_resize_if_required function

* Refactor code to use keyword arguments in train_util.py

* Fix alpha mask flipping logic

* Fix alpha mask initialization

* Fix alpha_mask transformation

* Cache alpha_mask

* Update alpha_masks to be on CPU

* Set flipped_alpha_masks to None if option disabled

* Check if alpha_mask is None

* Set alpha_mask to None if option disabled

* Add description of alpha_mask option to docs
u-haru authored May 19, 2024
1 parent febc5c5 commit db67529
Showing 10 changed files with 105 additions and 129 deletions.
2 changes: 2 additions & 0 deletions docs/train_network_README-ja.md
@@ -102,6 +102,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
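   * Specify this to use a learning rate for the LoRA modules associated with the Text Encoder that differs from the normal learning rate (set with the --learning_rate option). Some reports suggest that a slightly lower learning rate (such as 5e-5) works better for the Text Encoder.
 * `--network_args`
   * Multiple arguments can be specified. They are described later.
+* `--alpha_mask`
+  * Uses the image's alpha values as a mask. Use this when training on transparent images. [PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)

 When neither `--network_train_unet_only` nor `--network_train_text_encoder_only` is specified (the default), the LoRA modules for both the Text Encoder and the U-Net are enabled.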
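As a rough illustration of what `--alpha_mask` implies for a transparent image, the sketch below splits an RGBA array into its RGB pixels and a per-pixel weight in [0, 1]. The function name `split_alpha_mask` and the use of NumPy are assumptions made for illustration; the actual loading code in `train_util.py` is not shown in this excerpt and may differ.

```python
from typing import Tuple

import numpy as np


def split_alpha_mask(rgba: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Split an H x W x 4 uint8 RGBA array into RGB pixels and an alpha mask.

    Hypothetical helper, not part of sd-scripts.
    """
    if rgba.ndim != 3 or rgba.shape[2] != 4:
        raise ValueError("expected an H x W x 4 RGBA array")
    rgb = rgba[:, :, :3]
    # Scale 8-bit alpha to [0, 1]: 0 = fully transparent, 1 = fully opaque.
    alpha = rgba[:, :, 3].astype(np.float32) / 255.0
    return rgb, alpha
```

Under this reading, fully transparent pixels get weight 0 and drop out of the loss, while opaque pixels keep their full weight.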

2 changes: 2 additions & 0 deletions docs/train_network_README-zh.md
@@ -101,6 +101,8 @@ LoRA的模型将会被保存在通过`--output_dir`选项指定的文件夹中
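   * Specify this option to use a learning rate for the Text Encoder-related LoRA modules that differs from the normal learning rate (set with the `--learning_rate` option). It may be better to lower the Text Encoder learning rate slightly (for example, 5e-5).
 * `--network_args`
   * Multiple arguments can be specified. Details are given below.
+* `--alpha_mask`
+  * Uses the image's alpha values as a mask. This is used when training on transparent images. [PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)

 When neither `--network_train_unet_only` nor `--network_train_text_encoder_only` is specified (the default), the LoRA modules for both the Text Encoder and the U-Net are enabled.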

2 changes: 2 additions & 0 deletions library/config_util.py
@@ -78,6 +78,7 @@ class BaseSubsetParams:
     caption_tag_dropout_rate: float = 0.0
     token_warmup_min: int = 1
     token_warmup_step: float = 0
+    alpha_mask: bool = False


 @dataclass
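The hunk above adds `alpha_mask` as a per-subset field that defaults to `False`, so existing configurations keep their behavior unless the option is turned on. A minimal sketch of that pattern follows; the class `SubsetParamsSketch` and the helper `describe` are hypothetical stand-ins, not the real `BaseSubsetParams` or the logging code in `config_util.py`.

```python
from dataclasses import dataclass, fields


@dataclass
class SubsetParamsSketch:
    """Hypothetical, trimmed-down stand-in for a subset parameter class."""

    token_warmup_min: int = 1
    token_warmup_step: float = 0
    alpha_mask: bool = False  # off by default; opt in per subset


def describe(subset: SubsetParamsSketch) -> str:
    """Render every field as 'name: value', one per line, for a log summary."""
    return "\n".join(f"{f.name}: {getattr(subset, f.name)}" for f in fields(subset))


if __name__ == "__main__":
    print(describe(SubsetParamsSketch(alpha_mask=True)))
```

Giving the field a default keeps older subset definitions valid, and listing it in the summary makes it easy to confirm which subsets actually have the mask enabled.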
@@ -538,6 +539,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
 random_crop: {subset.random_crop}
 token_warmup_min: {subset.token_warmup_min},
 token_warmup_step: {subset.token_warmup_step},
+alpha_mask: {subset.alpha_mask},
 """
 ),
 " ",
5 changes: 3 additions & 2 deletions library/custom_train_functions.py
@@ -479,9 +479,10 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
     return noise


-def apply_masked_loss(loss, batch):
+def apply_masked_loss(loss, mask_image):
     # mask image is -1 to 1. we need to convert it to 0 to 1
-    mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
+    # mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
+    mask_image = mask_image.to(dtype=loss.dtype)

     # resize to the same size as the loss
     mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
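The hunk is cut off after the resize step, so the rest of the function is not visible here. The sketch below shows one way the whole operation could work, written with NumPy instead of PyTorch so it runs without a GPU stack. It assumes the mask arrives in the range [-1, 1] (as the in-code comment states), is averaged down to the loss resolution (a block-mean stand-in for `mode="area"`), rescaled to [0, 1], and multiplied into the loss. The names `area_downsample` and `apply_masked_loss_sketch`, and the final rescale-and-multiply steps, are assumptions.

```python
import numpy as np


def area_downsample(mask: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Average non-overlapping blocks of a (B, C, H, W) array.

    Simplified stand-in for area interpolation; requires H and W to be
    integer multiples of the output size.
    """
    b, c, h, w = mask.shape
    if h % out_h != 0 or w % out_w != 0:
        raise ValueError("input size must be a multiple of the output size")
    fh, fw = h // out_h, w // out_w
    return mask.reshape(b, c, out_h, fh, out_w, fw).mean(axis=(3, 5))


def apply_masked_loss_sketch(loss: np.ndarray, mask_image: np.ndarray) -> np.ndarray:
    """Weight a per-element loss (B, C, h, w) by a mask (B, 1, H, W) in [-1, 1]."""
    mask = mask_image.astype(loss.dtype)
    # Resize the mask to the spatial size of the loss.
    mask = area_downsample(mask, loss.shape[2], loss.shape[3])
    # Assumed step: map [-1, 1] to [0, 1] so that -1 means "ignore".
    mask = mask / 2 + 0.5
    # Assumed step: the single mask channel broadcasts over the loss channels.
    return loss * mask
```

Passing the mask in directly, rather than reading `batch["conditioning_images"]` inside the function, lets the caller choose the mask source, for example an alpha channel instead of a conditioning image.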