Skip to content

Commit e5bab69

Browse files
committed
Fix alpha mask without disk cache; closes kohya-ss#1351, ref kohya-ss#1339
1 parent 0d96e10 commit e5bab69

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

library/train_util.py

+6-4
Original file line number | Diff line number | Diff line change
@@ -1265,7 +1265,8 @@ def __getitem__(self, index):
12651265
if subset.alpha_mask:
12661266
if img.shape[2] == 4:
12671267
alpha_mask = img[:, :, 3] # [H,W]
1268-
alpha_mask = transforms.ToTensor()(alpha_mask) # 0-255 -> 0-1
1268+
alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0
1269+
alpha_mask = torch.FloatTensor(alpha_mask)
12691270
else:
12701271
alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32)
12711272
else:
@@ -2211,7 +2212,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
22112212
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
22122213
def load_latents_from_disk(
22132214
npz_path,
2214-
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
2215+
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
22152216
npz = np.load(npz_path)
22162217
if "latents" not in npz:
22172218
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
@@ -2229,7 +2230,7 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli
22292230
if flipped_latents_tensor is not None:
22302231
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
22312232
if alpha_mask is not None:
2232-
kwargs["alpha_mask"] = alpha_mask # ndarray
2233+
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
22332234
np.savez(
22342235
npz_path,
22352236
latents=latents_tensor.float().cpu().numpy(),
@@ -2496,8 +2497,9 @@ def cache_batch_latents(
24962497
if image.shape[2] == 4:
24972498
alpha_mask = image[:, :, 3] # [H,W]
24982499
alpha_mask = alpha_mask.astype(np.float32) / 255.0
2500+
alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
24992501
else:
2500-
alpha_mask = np.ones_like(image[:, :, 0], dtype=np.float32)
2502+
alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W]
25012503
else:
25022504
alpha_mask = None
25032505
alpha_masks.append(alpha_mask)

0 commit comments

Comments
 (0)