@@ -1265,7 +1265,8 @@ def __getitem__(self, index):
1265
1265
if subset .alpha_mask :
1266
1266
if img .shape [2 ] == 4 :
1267
1267
alpha_mask = img [:, :, 3 ] # [H,W]
1268
- alpha_mask = transforms .ToTensor ()(alpha_mask ) # 0-255 -> 0-1
1268
+ alpha_mask = alpha_mask .astype (np .float32 ) / 255.0 # 0.0~1.0
1269
+ alpha_mask = torch .FloatTensor (alpha_mask )
1269
1270
else :
1270
1271
alpha_mask = torch .ones ((img .shape [0 ], img .shape [1 ]), dtype = torch .float32 )
1271
1272
else :
@@ -2211,7 +2212,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
2211
2212
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
2212
2213
def load_latents_from_disk (
2213
2214
npz_path ,
2214
- ) -> Tuple [Optional [torch . Tensor ], Optional [List [int ]], Optional [List [int ]], Optional [np .ndarray ], Optional [np .ndarray ]]:
2215
+ ) -> Tuple [Optional [np . ndarray ], Optional [List [int ]], Optional [List [int ]], Optional [np .ndarray ], Optional [np .ndarray ]]:
2215
2216
npz = np .load (npz_path )
2216
2217
if "latents" not in npz :
2217
2218
raise ValueError (f"error: npz is old format. please re-generate { npz_path } " )
@@ -2229,7 +2230,7 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli
2229
2230
if flipped_latents_tensor is not None :
2230
2231
kwargs ["latents_flipped" ] = flipped_latents_tensor .float ().cpu ().numpy ()
2231
2232
if alpha_mask is not None :
2232
- kwargs ["alpha_mask" ] = alpha_mask # ndarray
2233
+ kwargs ["alpha_mask" ] = alpha_mask . float (). cpu (). numpy ()
2233
2234
np .savez (
2234
2235
npz_path ,
2235
2236
latents = latents_tensor .float ().cpu ().numpy (),
@@ -2496,8 +2497,9 @@ def cache_batch_latents(
2496
2497
if image .shape [2 ] == 4 :
2497
2498
alpha_mask = image [:, :, 3 ] # [H,W]
2498
2499
alpha_mask = alpha_mask .astype (np .float32 ) / 255.0
2500
+ alpha_mask = torch .FloatTensor (alpha_mask ) # [H,W]
2499
2501
else :
2500
- alpha_mask = np .ones_like (image [:, :, 0 ], dtype = np .float32 )
2502
+ alpha_mask = torch .ones_like (image [:, :, 0 ], dtype = torch .float32 ) # [H,W]
2501
2503
else :
2502
2504
alpha_mask = None
2503
2505
alpha_masks .append (alpha_mask )
0 commit comments