diff --git a/docs/train_network_README-ja.md b/docs/train_network_README-ja.md
index 46085117c..55c80c4b0 100644
--- a/docs/train_network_README-ja.md
+++ b/docs/train_network_README-ja.md
@@ -102,6 +102,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
   * Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。
 * `--network_args`
   * 複数の引数を指定できます。後述します。
+* `--alpha_mask`
+  * 画像のアルファ値をマスクとして使用します。透過画像を学習する際に使用します。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)
 
 `--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。
diff --git a/docs/train_network_README-zh.md b/docs/train_network_README-zh.md
index ed7a0c4ef..830014f72 100644
--- a/docs/train_network_README-zh.md
+++ b/docs/train_network_README-zh.md
@@ -101,6 +101,8 @@ LoRA的模型将会被保存在通过`--output_dir`选项指定的文件夹中
   * 当在Text Encoder相关的LoRA模块中使用与常规学习率(由`--learning_rate`选项指定)不同的学习率时,应指定此选项。可能最好将Text Encoder的学习率稍微降低(例如5e-5)。
 * `--network_args`
   * 可以指定多个参数。将在下面详细说明。
+* `--alpha_mask`
+  * 使用图像的 Alpha 值作为遮罩。这在学习透明图像时使用。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)
 
 当未指定`--network_train_unet_only`和`--network_train_text_encoder_only`时(默认情况),将启用Text Encoder和U-Net的两个LoRA模块。
diff --git a/library/config_util.py b/library/config_util.py
index 59f5f86d2..82baab83e 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -78,6 +78,7 @@ class BaseSubsetParams:
     caption_tag_dropout_rate: float = 0.0
     token_warmup_min: int = 1
     token_warmup_step: float = 0
+    alpha_mask: bool = False
 
 
 @dataclass
@@ -538,6 +539,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
                 random_crop: {subset.random_crop}
                 token_warmup_min: {subset.token_warmup_min},
                 token_warmup_step: {subset.token_warmup_step},
+                alpha_mask: {subset.alpha_mask},
             """
                 ),
                 "  ",
diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index 406e0e36e..fad127405 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -479,9 +479,10 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
     return noise
 
 
-def apply_masked_loss(loss, batch):
+def apply_masked_loss(loss, mask_image):
     # mask image is -1 to 1. we need to convert it to 0 to 1
-    mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
+    # mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
+    mask_image = mask_image.to(dtype=loss.dtype)
 
     # resize to the same size as the loss
     mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
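Reviewer note: the `apply_masked_loss` change above decouples the function from the batch dict, so callers can pass any `[B, 1, H, W]` mask tensor. A minimal sketch of the visible part of the reworked function (hypothetical standalone name; per the comment in the hunk, conditioning masks arrive in -1..1 and still need rescaling to 0..1, but that step falls outside the shown context and is omitted here):

```python
import torch
import torch.nn.functional as F

def apply_masked_loss_sketch(loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # loss: per-pixel loss in latent space, [B, C, h, w]
    # mask: [B, 1, H, W]; an alpha mask is already in 0..1
    mask = mask.to(dtype=loss.dtype)
    # average-pool the mask down to the latent resolution
    mask = F.interpolate(mask, size=loss.shape[2:], mode="area")
    return loss * mask

# usage: pixels where mask == 0 contribute nothing to the loss
loss = torch.rand(2, 4, 64, 64)
mask = torch.ones(2, 1, 512, 512)
print(apply_masked_loss_sketch(loss, mask).shape)  # torch.Size([2, 4, 64, 64])
```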
diff --git a/library/train_util.py b/library/train_util.py
index 410471470..20f8055dc 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -159,6 +159,9 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,
         self.text_encoder_outputs1: Optional[torch.Tensor] = None
         self.text_encoder_outputs2: Optional[torch.Tensor] = None
         self.text_encoder_pool2: Optional[torch.Tensor] = None
+        self.alpha_mask: Optional[torch.Tensor] = None
+        self.alpha_mask_flipped: Optional[torch.Tensor] = None
+        self.use_alpha_mask: bool = False
 
 
 class BucketManager:
@@ -379,6 +382,7 @@ def __init__(
         caption_suffix: Optional[str],
         token_warmup_min: int,
         token_warmup_step: Union[float, int],
+        alpha_mask: bool,
     ) -> None:
         self.image_dir = image_dir
         self.num_repeats = num_repeats
@@ -403,6 +407,7 @@ def __init__(
 
         self.img_count = 0
+        self.alpha_mask = alpha_mask
 
 
 class DreamBoothSubset(BaseSubset):
     def __init__(
@@ -412,47 +417,13 @@ def __init__(
         class_tokens: Optional[str],
         caption_extension: str,
         cache_info: bool,
-        num_repeats,
-        shuffle_caption,
-        caption_separator: str,
-        keep_tokens,
-        keep_tokens_separator,
-        secondary_separator,
-        enable_wildcard,
-        color_aug,
-        flip_aug,
-        face_crop_aug_range,
-        random_crop,
-        caption_dropout_rate,
-        caption_dropout_every_n_epochs,
-        caption_tag_dropout_rate,
-        caption_prefix,
-        caption_suffix,
-        token_warmup_min,
-        token_warmup_step,
+        **kwargs,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
         super().__init__(
             image_dir,
-            num_repeats,
-            shuffle_caption,
-            caption_separator,
-            keep_tokens,
-            keep_tokens_separator,
-            secondary_separator,
-            enable_wildcard,
-            color_aug,
-            flip_aug,
-            face_crop_aug_range,
-            random_crop,
-            caption_dropout_rate,
-            caption_dropout_every_n_epochs,
-            caption_tag_dropout_rate,
-            caption_prefix,
-            caption_suffix,
-            token_warmup_min,
-            token_warmup_step,
+            **kwargs,
         )
 
         self.is_reg = is_reg
@@ -473,47 +444,13 @@ def __init__(
         self,
         image_dir,
         metadata_file: str,
-        num_repeats,
-        shuffle_caption,
-        caption_separator,
-        keep_tokens,
-        keep_tokens_separator,
-        secondary_separator,
-        enable_wildcard,
-        color_aug,
-        flip_aug,
-        face_crop_aug_range,
-        random_crop,
-        caption_dropout_rate,
-        caption_dropout_every_n_epochs,
-        caption_tag_dropout_rate,
-        caption_prefix,
-        caption_suffix,
-        token_warmup_min,
-        token_warmup_step,
+        **kwargs,
     ) -> None:
         assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
 
         super().__init__(
             image_dir,
-            num_repeats,
-            shuffle_caption,
-            caption_separator,
-            keep_tokens,
-            keep_tokens_separator,
-            secondary_separator,
-            enable_wildcard,
-            color_aug,
-            flip_aug,
-            face_crop_aug_range,
-            random_crop,
-            caption_dropout_rate,
-            caption_dropout_every_n_epochs,
-            caption_tag_dropout_rate,
-            caption_prefix,
-            caption_suffix,
-            token_warmup_min,
-            token_warmup_step,
+            **kwargs,
         )
 
         self.metadata_file = metadata_file
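Reviewer note: the subset classes above replace eighteen hand-copied positional parameters with `**kwargs` forwarding. A minimal sketch of the pattern with hypothetical class names (not the actual sd-scripts classes), to show why the new `alpha_mask` option only had to be added to the base class:

```python
# Sketch of **kwargs forwarding; names here are illustrative only.
class BaseSubsetSketch:
    def __init__(self, image_dir, num_repeats=1, flip_aug=False, alpha_mask=False):
        self.image_dir = image_dir
        self.num_repeats = num_repeats
        self.flip_aug = flip_aug
        self.alpha_mask = alpha_mask

class DreamBoothSubsetSketch(BaseSubsetSketch):
    def __init__(self, image_dir, is_reg, **kwargs):
        # the subclass keeps only its own parameters; everything shared is
        # forwarded untouched, so a new base option (like alpha_mask) no
        # longer requires editing every subclass signature
        super().__init__(image_dir, **kwargs)
        self.is_reg = is_reg

s = DreamBoothSubsetSketch("imgs", is_reg=False, num_repeats=2, alpha_mask=True)
print(s.alpha_mask)  # True
```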
@@ -531,47 +468,13 @@ def __init__(
         conditioning_data_dir: str,
         caption_extension: str,
         cache_info: bool,
-        num_repeats,
-        shuffle_caption,
-        caption_separator,
-        keep_tokens,
-        keep_tokens_separator,
-        secondary_separator,
-        enable_wildcard,
-        color_aug,
-        flip_aug,
-        face_crop_aug_range,
-        random_crop,
-        caption_dropout_rate,
-        caption_dropout_every_n_epochs,
-        caption_tag_dropout_rate,
-        caption_prefix,
-        caption_suffix,
-        token_warmup_min,
-        token_warmup_step,
+        **kwargs,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
         super().__init__(
             image_dir,
-            num_repeats,
-            shuffle_caption,
-            caption_separator,
-            keep_tokens,
-            keep_tokens_separator,
-            secondary_separator,
-            enable_wildcard,
-            color_aug,
-            flip_aug,
-            face_crop_aug_range,
-            random_crop,
-            caption_dropout_rate,
-            caption_dropout_every_n_epochs,
-            caption_tag_dropout_rate,
-            caption_prefix,
-            caption_suffix,
-            token_warmup_min,
-            token_warmup_step,
+            **kwargs,
         )
 
         self.conditioning_data_dir = conditioning_data_dir
@@ -985,6 +888,8 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
         for info in tqdm(image_infos):
             subset = self.image_to_subset[info.image_key]
 
+            info.use_alpha_mask = subset.alpha_mask
+
             if info.latents_npz is not None:  # fine tuning dataset
                 continue
@@ -1088,8 +993,8 @@ def cache_text_encoder_outputs(
     def get_image_size(self, image_path):
         return imagesize.get(image_path)
 
-    def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
-        img = load_image(image_path)
+    def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False):
+        img = load_image(image_path, alpha_mask)
 
         face_cx = face_cy = face_w = face_h = 0
         if subset.face_crop_aug_range is not None:
@@ -1166,6 +1071,7 @@ def __getitem__(self, index):
         input_ids_list = []
         input_ids2_list = []
         latents_list = []
+        alpha_mask_list = []
         images = []
         original_sizes_hw = []
         crop_top_lefts = []
@@ -1190,21 +1096,27 @@ def __getitem__(self, index):
                 crop_ltrb = image_info.latents_crop_ltrb  # calc values later if flipped
                 if not flipped:
                     latents = image_info.latents
+                    alpha_mask = image_info.alpha_mask
                 else:
                     latents = image_info.latents_flipped
-
+                    alpha_mask = image_info.alpha_mask_flipped
+
                 image = None
             elif image_info.latents_npz is not None:  # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
-                latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz)
+                latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(image_info.latents_npz)
                 if flipped:
                     latents = flipped_latents
+                    alpha_mask = flipped_alpha_mask
                     del flipped_latents
+                    del flipped_alpha_mask
                 latents = torch.FloatTensor(latents)
+                if alpha_mask is not None:
+                    alpha_mask = torch.FloatTensor(alpha_mask)
 
                 image = None
             else:
                 # 画像を読み込み、必要ならcropする
-                img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
+                img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path, subset.alpha_mask)
                 im_h, im_w = img.shape[0:2]
 
                 if self.enable_bucket:
@@ -1241,11 +1153,22 @@ def __getitem__(self, index):
                 if flipped:
                     img = img[:, ::-1, :].copy()  # copy to avoid negative stride problem
 
+                if subset.alpha_mask:
+                    if img.shape[2] == 4:
+                        alpha_mask = img[:, :, 3]  # [H,W]
+                    else:
+                        alpha_mask = np.full((im_h, im_w), 255, dtype=np.uint8)  # [H,W]
+                    alpha_mask = transforms.ToTensor()(alpha_mask)
+                else:
+                    alpha_mask = None
+
+                img = img[:, :, :3]  # remove alpha channel
+
                 latents = None
                 image = self.image_transforms(img)  # -1.0~1.0のtorch.Tensorになる
 
             images.append(image)
             latents_list.append(latents)
+            alpha_mask_list.append(alpha_mask)
 
             target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
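Reviewer note: in the `__getitem__` hunk above, the alpha channel is split off before the RGB image goes through `image_transforms` (the axis-order comments and `np.full` shape are fixed to `[H,W]` to match the HWC layout). A minimal sketch of that extraction step, assuming a uint8 HWC array as produced by `load_image`; the all-255 fallback covers RGB inputs without an alpha channel:

```python
import numpy as np
import torch
from torchvision import transforms

def alpha_to_mask(img: np.ndarray) -> torch.Tensor:
    # img: uint8 HWC array, either RGB [H,W,3] or RGBA [H,W,4]
    if img.shape[2] == 4:
        alpha = img[:, :, 3]  # [H,W]; 0 = transparent, 255 = opaque
    else:
        # no alpha channel: treat the whole image as opaque
        alpha = np.full(img.shape[:2], 255, dtype=np.uint8)
    # ToTensor scales uint8 to float in [0,1] and adds a channel dim -> [1,H,W]
    return transforms.ToTensor()(alpha)

mask = alpha_to_mask(np.zeros((512, 512, 4), dtype=np.uint8))
print(mask.shape, float(mask.max()))  # torch.Size([1, 512, 512]) 0.0
```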
@@ -1348,6 +1271,8 @@ def __getitem__(self, index):
 
         example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
 
+        example["alpha_mask"] = torch.stack(alpha_mask_list) if alpha_mask_list[0] is not None else None
+
         if self.debug_dataset:
             example["image_keys"] = bucket[image_index : image_index + self.batch_size]
         return example
@@ -2145,7 +2070,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
 # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
 def load_latents_from_disk(
     npz_path,
-) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
+) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
     npz = np.load(npz_path)
     if "latents" not in npz:
         raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
@@ -2154,13 +2079,19 @@ def load_latents_from_disk(
     original_size = npz["original_size"].tolist()
     crop_ltrb = npz["crop_ltrb"].tolist()
     flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
-    return latents, original_size, crop_ltrb, flipped_latents
+    alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
+    flipped_alpha_mask = npz["flipped_alpha_mask"] if "flipped_alpha_mask" in npz else None
+    return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask
 
 
-def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None):
+def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None):
     kwargs = {}
     if flipped_latents_tensor is not None:
         kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
+    if alpha_mask is not None:
+        kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
+    if flipped_alpha_mask is not None:
+        kwargs["flipped_alpha_mask"] = flipped_alpha_mask.float().cpu().numpy()
     np.savez(
         npz_path,
         latents=latents_tensor.float().cpu().numpy(),
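Reviewer note: the cache format stays backward compatible — `alpha_mask` and `flipped_alpha_mask` are written only when present, and `load_latents_from_disk` returns `None` for missing keys, so old caches still load. A minimal round-trip sketch of that optional-key layout (hypothetical helper names and file name; key names follow the diff above):

```python
import numpy as np

def save_cache(path, latents, alpha_mask=None):
    extra = {}
    if alpha_mask is not None:
        extra["alpha_mask"] = alpha_mask
    np.savez(path, latents=latents, **extra)

def load_cache(path):
    npz = np.load(path)
    latents = npz["latents"]
    # optional key: caches written without alpha data still load fine
    alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
    return latents, alpha_mask

save_cache("cache.npz", np.zeros((4, 64, 64), np.float32))
lat, am = load_cache("cache.npz")
print(lat.shape, am)  # (4, 64, 64) None
```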
@@ -2349,17 +2280,20 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
     return train_dataset_group
 
 
-def load_image(image_path):
+def load_image(image_path, alpha=False):
     image = Image.open(image_path)
     if not image.mode == "RGB":
-        image = image.convert("RGB")
+        if alpha:
+            image = image.convert("RGBA")
+        else:
+            image = image.convert("RGB")
     img = np.array(image, np.uint8)
     return img
 
 
 # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
 def trim_and_resize_if_required(
-    random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int]
+    random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
 ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
     image_height, image_width = image.shape[0:2]
     original_size = (image_width, image_height)  # size before resize
@@ -2403,10 +2337,18 @@ def cache_batch_latents(
     latents_original_size and latents_crop_ltrb are also set
     """
     images = []
+    alpha_masks = []
     for info in image_infos:
-        image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
+        image = load_image(info.absolute_path, info.use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
         # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
         image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
+        if info.use_alpha_mask:
+            if image.shape[2] == 4:
+                alpha_mask = image[:, :, 3]  # [H,W]
+                image = image[:, :, :3]
+            else:
+                alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8)  # [H,W]
+            alpha_masks.append(transforms.ToTensor()(alpha_mask))
         image = IMAGE_TRANSFORMS(image)
         images.append(image)
@@ -2419,25 +2361,37 @@ def cache_batch_latents(
     with torch.no_grad():
         latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
 
+    if info.use_alpha_mask:
+        alpha_masks = torch.stack(alpha_masks, dim=0).to("cpu")
+    else:
+        alpha_masks = [None] * len(image_infos)
+        flipped_alpha_masks = [None] * len(image_infos)
+
     if flip_aug:
         img_tensors = torch.flip(img_tensors, dims=[3])
         with torch.no_grad():
             flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
+        if info.use_alpha_mask:
+            flipped_alpha_masks = torch.flip(alpha_masks, dims=[3])
     else:
         flipped_latents = [None] * len(latents)
+        flipped_alpha_masks = [None] * len(image_infos)
 
-    for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
+    for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks):
         # check NaN
         if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
             raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
 
         if cache_to_disk:
-            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
+            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask, flipped_alpha_mask)
         else:
             info.latents = latent
             if flip_aug:
                 info.latents_flipped = flipped_latent
+            info.alpha_mask = alpha_mask
+            info.alpha_mask_flipped = flipped_alpha_mask
 
     if not HIGH_VRAM:
         clean_memory_on_device(vae.device)
@@ -3683,6 +3637,11 @@ def add_dataset_arguments(
         default=0,
         help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
     )
+    parser.add_argument(
+        "--alpha_mask",
+        action="store_true",
+        help="use alpha channel as mask for training / 画像のアルファチャンネルをlossのマスクに使用する",
+    )
     parser.add_argument(
         "--dataset_class",
diff --git a/sdxl_train.py b/sdxl_train.py
index 7c71a5133..dcd06766b 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -712,7 +712,9 @@ def optimizer_hook(parameter: torch.Tensor):
                     noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
                 )
                 if args.masked_loss:
-                    loss = apply_masked_loss(loss, batch)
+                    loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
+                if "alpha_mask" in batch and batch["alpha_mask"] is not None:
+                    loss = apply_masked_loss(loss, batch["alpha_mask"])
                 loss = loss.mean([1, 2, 3])
 
                 if args.min_snr_gamma:
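Reviewer note: the same two-step call site recurs in every training script below — `--masked_loss` applies the R channel of the conditioning image, and the alpha mask, when the batch carries one, is applied on top. A condensed sketch with hypothetical tensors (per the earlier hunk's comment, the real `apply_masked_loss` still has to rescale the -1..1 conditioning values to 0..1; that step is outside the shown hunks and omitted here):

```python
import torch
import torch.nn.functional as F

def apply_mask(loss, mask):
    # downscale the mask to latent resolution, then weight the per-pixel loss
    mask = F.interpolate(mask.to(loss.dtype), size=loss.shape[2:], mode="area")
    return loss * mask

loss = torch.rand(2, 4, 64, 64)
batch = {
    "conditioning_images": torch.rand(2, 3, 512, 512) * 2 - 1,
    "alpha_mask": torch.ones(2, 1, 512, 512),
}

masked_loss = True  # stands in for args.masked_loss
if masked_loss:
    # R channel of the conditioning image, kept as a [B, 1, H, W] tensor
    loss = apply_mask(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
if "alpha_mask" in batch and batch["alpha_mask"] is not None:
    loss = apply_mask(loss, batch["alpha_mask"])
print(loss.shape)  # torch.Size([2, 4, 64, 64])
```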
batch["alpha_mask"]) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_network.py b/train_network.py index 38e4888e8..cd1677ad2 100644 --- a/train_network.py +++ b/train_network.py @@ -903,7 +903,9 @@ def remove_model(old_ckpt_name): noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) if args.masked_loss: - loss = apply_masked_loss(loss, batch) + loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) + if "alpha_mask" in batch and batch["alpha_mask"] is not None: + loss = apply_masked_loss(loss, batch["alpha_mask"]) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 184607d1d..a9c2a1094 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -590,7 +590,9 @@ def remove_model(old_ckpt_name): loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) if args.masked_loss: - loss = apply_masked_loss(loss, batch) + loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) + if "alpha_mask" in batch and batch["alpha_mask"] is not None: + loss = apply_masked_loss(loss, batch["alpha_mask"]) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 8eed00fa1..959839cbb 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -475,7 +475,9 @@ def remove_model(old_ckpt_name): loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) if args.masked_loss: - loss = apply_masked_loss(loss, batch) + loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) + if "alpha_mask" in batch and batch["alpha_mask"] is not None: + loss = apply_masked_loss(loss, batch["alpha_mask"]) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight