
Tiled decoding works create weird seams #8

Closed
Isotr0py opened this issue Aug 12, 2023 · 3 comments


Isotr0py commented Aug 12, 2023

When I split decoding into tiles without overlap, strange seams appear along the boundaries of each tile in the decoded image.

However, according to the README, these seams shouldn't appear, since TAESD has a bounded receptive field.

Although overlapping the tiles solves it, I wonder whether this is a bug.
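For a plain stack of convolutions, "bounded receptive field" can be made concrete: each layer widens the region of input pixels that one output pixel depends on by a fixed amount, so the total stays finite. Below is a minimal sketch of the standard per-axis calculation. The helper name and the layer lists are hypothetical examples, not TAESD's architecture, and the helper only handles convolutions given as `(kernel_size, stride)` pairs (no upsampling or attention layers).

```python
from typing import List, Tuple


def conv_stack_receptive_field(layers: List[Tuple[int, int]]) -> int:
    """Receptive field, in input pixels along one axis, of a single output pixel
    of a plain conv stack given as (kernel_size, stride) pairs, input to output.
    Hypothetical helper for illustration only."""
    rf = 1    # before any layer, one pixel "sees" exactly one input pixel
    jump = 1  # spacing, in input pixels, between neighbouring pixels of the current layer
    for kernel_size, stride in layers:
        # each extra kernel tap reaches `jump` more input pixels
        rf += (kernel_size - 1) * jump
        jump *= stride
    return rf


print(conv_stack_receptive_field([(3, 1)] * 3))               # three stride-1 3x3 convs -> 7
print(conv_stack_receptive_field([(3, 2), (3, 1), (3, 1)]))   # stride-2 conv first -> 11
```

Any finite stack of such layers yields a finite number; a layer that mixes all positions, such as global self-attention, has no such bound.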

Here is the tiling-decoding code modified from taesd.py:

```python
# Modified from taesd.py (torch and TAESD come from that file).
def tiled_decode(decoder, x: torch.FloatTensor) -> torch.FloatTensor:
    tile_latent_min_size = 32
    # split x into non-overlapping tiles along height (dim 2) and width (dim 3)
    tiles = list(x.split(tile_latent_min_size, dim=2))
    tiles = [list(tile.split(tile_latent_min_size, dim=3)) for tile in tiles]

    # decode each tile independently
    for i, row in enumerate(tiles):
        for j, tile in enumerate(row):
            tiles[i][j] = decoder(tile)

    # merge tiles: concatenate each row along width, then the rows along height
    tiles = [torch.cat(tile, dim=3) for tile in tiles]
    return torch.cat(tiles, dim=2)


@torch.no_grad()
def main():
    from PIL import Image
    import sys
    import torchvision.transforms.functional as TF
    dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    print("Using device", dev)
    taesd = TAESD().to(dev)
    for im_path in sys.argv[1:]:
        im = TF.to_tensor(Image.open(im_path).convert("RGB")).unsqueeze(0).to(dev)

        # encode image, quantize, and save to file
        im_enc = taesd.scale_latents(taesd.encoder(im)).mul_(255).round_().byte()
        enc_path = im_path + ".encoded.png"
        TF.to_pil_image(im_enc[0]).save(enc_path)
        print(f"Encoded {im_path} to {enc_path}")

        # load the saved file, dequantize, and decode tile by tile
        im_enc = taesd.unscale_latents(TF.to_tensor(Image.open(enc_path)).unsqueeze(0).to(dev))
        im_dec = tiled_decode(taesd.decoder, im_enc).clamp(0, 1)
        dec_path = im_path + ".decoded.png"
        print(f"Decoded {enc_path} to {dec_path}")
        TF.to_pil_image(im_dec[0]).save(dec_path)


if __name__ == "__main__":
    main()
```

Original image:

[image: cat_512]

Decoded:

[image: cat_512 jpg decoded]

Tiled decoding:

[image: cat_512_tiled_decoded]
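A minimal sketch of why non-overlapping tiles produce seams even when the receptive field is bounded, using a hypothetical toy decoder (three zero-padded 3x3 convolutions, not TAESD). Near a tile edge, each convolution sees zero padding instead of the neighbouring tile's latents, so outputs within the receptive-field radius (3 pixels here) of the boundary differ from a full decode. A bounded receptive field only confines the error to a finite band; it does not make non-overlapping tiles exact. `tiled_no_overlap` is an illustrative helper mirroring `tiled_decode` above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy decoder (not TAESD): three zero-padded 3x3 convs,
# so each output pixel depends on latents within 3 pixels of it.
toy_decoder = nn.Sequential(
    nn.Conv2d(4, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 3, 3, padding=1),
).double()


def tiled_no_overlap(decoder, x, tile):
    """Decode non-overlapping tiles independently and stitch them together."""
    rows = []
    for row in x.split(tile, dim=2):
        rows.append(torch.cat([decoder(t) for t in row.split(tile, dim=3)], dim=3))
    return torch.cat(rows, dim=2)


with torch.no_grad():
    x = torch.randn(1, 4, 64, 64, dtype=torch.float64)
    full = toy_decoder(x)
    tiled = tiled_no_overlap(toy_decoder, x, 32)
    # per-pixel max abs difference over batch and channels, shape (64, 64)
    err = (full - tiled).abs().amax(dim=(0, 1))

# rows/cols 29..34 straddle the tile boundary at 32 (radius 3 on each side)
print("seam band differs:", err[:, 29:35].max().item() > 1e-6)
print("far from seam matches:", max(err[:29, :29].max().item(), err[35:, 35:].max().item()) < 1e-10)
```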

madebyollin (Owner) commented

Hmm, I think the README was overselling things (my fault!).

Here's what the receptive field looks like for TAESD decoding:

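Receptive field test code

```python
import torch
import torchvision.transforms.functional as TF
from IPython.display import display  # notebook helper for showing the result
from taesd import TAESD


@torch.no_grad()
def test_receptive_field():
    dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    taesd = TAESD().to(dev)

    # decode an all-zero latent
    im = torch.zeros(1, 4, 128, 128, device=dev)
    im_dec = taesd.decoder(im)

    # perturb a single latent pixel and decode again
    im[..., 64, 64] = 1
    im_dec_2 = taesd.decoder(im)

    # white pixels = decoded pixels affected by that single latent pixel
    display(TF.to_pil_image((im_dec != im_dec_2).float()[0]))


test_receptive_field()
```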

[image: TAESD Receptive Field]

vs. the SD VAE:

[image: SD-VAE Receptive Field]

So the TAESD receptive field is bounded, and the SD-VAE receptive field isn't. For tiled decoding to be perfect (identical to non-tiled decoding), the tile overlap must cover the entire receptive field. So with enough tile overlap, TAESD can give results identical to non-tiled decoding, whereas SD-VAE will (in principle) always have tiling artifacts... or so I figured while writing the README.
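A minimal sketch of the "enough overlap" argument. It uses a hypothetical toy decoder (a 3x3 conv, a 2x nearest upsample, then two more 3x3 convs; not TAESD) whose output pixels depend on latents within 2 pixels of their source latent. Each tile is decoded with `margin` latent pixels of extra context on every side, and only its centre is kept. With a margin at least as large as that radius, the result matches full decoding to floating-point precision; with no margin it does not. This crop-and-discard scheme is a swapped-in illustration, not the blending approach of `tiled_decode_with_overlap`, and `tiled_decode_with_margin` is a hypothetical name.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy decoder (not TAESD). Zero-padding errors from a tile edge reach
# at most 2 latent pixels inward, so a margin of 2 covers the receptive field.
toy_decoder = nn.Sequential(
    nn.Conv2d(4, 8, 3, padding=1), nn.ReLU(),
    nn.Upsample(scale_factor=2, mode="nearest"),
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 3, 3, padding=1),
).double()
SCALE = 2  # spatial scale of the toy decoder's output relative to its input


def tiled_decode_with_margin(decoder, x, tile=16, margin=2, scale=SCALE):
    """Decode x tile by tile, giving each tile `margin` latent pixels of context
    on every side (clamped at the image border) and keeping only its centre."""
    _, _, h, w = x.shape
    out = None
    for i0 in range(0, h, tile):
        for j0 in range(0, w, tile):
            i1, j1 = min(i0 + tile, h), min(j0 + tile, w)
            # context window around the tile, clamped to the latent bounds
            a0, b0 = max(i0 - margin, 0), max(j0 - margin, 0)
            a1, b1 = min(i1 + margin, h), min(j1 + margin, w)
            dec = decoder(x[..., a0:a1, b0:b1])
            if out is None:
                out = dec.new_zeros(x.shape[0], dec.shape[1], h * scale, w * scale)
            # keep only the decoded pixels that belong to this tile
            crop = dec[..., (i0 - a0) * scale:(i1 - a0) * scale,
                            (j0 - b0) * scale:(j1 - b0) * scale]
            out[..., i0 * scale:i1 * scale, j0 * scale:j1 * scale] = crop
    return out


with torch.no_grad():
    x = torch.randn(1, 4, 48, 48, dtype=torch.float64)
    full = toy_decoder(x)
    exact = tiled_decode_with_margin(toy_decoder, x, tile=16, margin=2)
    seamy = tiled_decode_with_margin(toy_decoder, x, tile=16, margin=0)

print(torch.allclose(full, exact, atol=1e-10))  # True: margin covers the receptive field
print(torch.allclose(full, seamy, atol=1e-10))  # False: seams at tile borders
```

The comparison uses a small tolerance rather than exact equality because different input sizes may change floating-point summation order.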

In practice, it looks like you can get away with a lot less than full-receptive-field tile overlap for both TAESD and SD-VAE, so the "bounded-but-large receptive field" vs. "infinite receptive field" distinction doesn't have much practical benefit.

Here's the TAESD tiled decode output:

Tiled decoding test code

```python
import math

import torch


def tiled_decode_with_overlap(
    decoder: torch.nn.Module,
    x: torch.FloatTensor,
    tile_size: int = 32,
    decoder_spatial_scale_factor: int = 8,
) -> torch.FloatTensor:
    # scale of decoder output relative to input
    sf = decoder_spatial_scale_factor
    # number of tiles - plus one, for overlap
    nti = math.ceil(x.shape[-2] / tile_size) + 1
    ntj = math.ceil(x.shape[-1] / tile_size) + 1
    # number of input pixels to traverse between tiles
    sti = (x.shape[-2] - tile_size) / (nti - 1)
    stj = (x.shape[-1] - tile_size) / (ntj - 1)
    # number of pixels to blend
    blend_i = int(tile_size - sti)
    blend_j = int(tile_size - stj)
    # masks for blending: linear ramps from 0 to 1 over the blended region
    blend_masks = torch.stack(
        torch.meshgrid(
            torch.arange(tile_size * sf) / (blend_i * sf - 1),
            torch.arange(tile_size * sf) / (blend_j * sf - 1),
            indexing="ij",
        ),
        0,
    )
    blend_masks = blend_masks.clamp(0, 1).to(x.device)
    # output array (height from x.shape[-2], width from x.shape[-1])
    out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
    for i in range(nti):
        for j in range(ntj):
            ti, tj = round(sti * i), round(stj * j)
            # tile in / out regions
            tile_in = x[..., ti:ti + tile_size, tj:tj + tile_size]
            tile_out = out[..., ti * sf:(ti + tile_size) * sf, tj * sf:(tj + tile_size) * sf]
            # tile result
            tile = decoder(tile_in)
            # blend tile result into output
            blend_mask_i = 1 if i == 0 else blend_masks[0]
            blend_mask_j = 1 if j == 0 else blend_masks[1]
            blend_mask = blend_mask_i * blend_mask_j
            tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
    return out
```
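The tile geometry in `tiled_decode_with_overlap` can be checked by hand. The sketch below, a hypothetical helper named `tile_layout`, repeats its per-axis arithmetic. Assuming a 512x512 image, i.e. a 64x64 latent at the decoder's 8x scale factor, it gives 3 tiles per axis at latent offsets 0, 16 and 32 with 16 latent pixels blended, consistent with the "3x3 32x32 tiles, 16 (latent) pixels of overlap" captions.

```python
import math


def tile_layout(latent_size: int, tile_size: int = 32):
    """Per-axis tile placement, following tiled_decode_with_overlap:
    number of tiles, tile start offsets, and blend width (in latent pixels)."""
    n_tiles = math.ceil(latent_size / tile_size) + 1    # one extra tile for overlap
    stride = (latent_size - tile_size) / (n_tiles - 1)  # latent pixels between tile starts
    blend = int(tile_size - stride)                     # overlap that gets blended
    offsets = [round(stride * i) for i in range(n_tiles)]
    return n_tiles, offsets, blend


print(tile_layout(64))  # (3, [0, 16, 32], 16)
```

In the original function, each blend mask is then a linear ramp from 0 to 1 over the first `blend * sf` output pixels of a tile, so each new tile fades in over the region it shares with the previous one.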

[image: TAESD Tiled Decode, 3x3 32x32 tiles, 16 (latent) pixels of overlap]

vs. SD-VAE tiled decode output:

[image: SD-VAE Tiled Decode, 3x3 32x32 tiles, 16 (latent) pixels of overlap]

To me, they both look free of perceptible tiling artifacts once you add the overlap, so I'll update the README.

Isotr0py (Author) commented

Got it! Thanks for the detailed explanation!

swamidass commented

What about the encoder? Does it have any issues with seams? What is its receptive field?
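This question isn't answered in the thread. One way to probe it empirically is the impulse test from the receptive field test code, applied to an encoder: perturb one input pixel and see which latent pixels change. Below is a sketch on a hypothetical toy encoder (a stride-2 3x3 conv between two stride-1 3x3 convs; not TAESD's encoder). `impulse_footprint` is an illustrative helper. Tanh is used instead of ReLU because a ReLU that is inactive at the zero input blocks small perturbations, which can make the measured footprint smaller than the architecture's bound.

```python
import torch
import torch.nn as nn


def impulse_footprint(module, in_shape, pos, tol=1e-12):
    """Bounding box (row_min, row_max, col_min, col_max) of the output pixels that
    change when the single input pixel at `pos` is perturbed. Hypothetical helper."""
    p = next(module.parameters())
    with torch.no_grad():
        x = torch.zeros(in_shape, dtype=p.dtype, device=p.device)
        base = module(x)
        x[..., pos[0], pos[1]] = 1.0
        # any channel changed -> (H, W) boolean map for the first batch item
        changed = ((module(x) - base).abs() > tol).any(dim=1)[0]
    rows = changed.any(dim=1).nonzero().flatten()
    cols = changed.any(dim=0).nonzero().flatten()
    return rows.min().item(), rows.max().item(), cols.min().item(), cols.max().item()


torch.manual_seed(0)
# Hypothetical toy encoder (not TAESD): 32x32 image -> 16x16 latent
toy_encoder = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.Tanh(),
    nn.Conv2d(8, 8, 3, stride=2, padding=1), nn.Tanh(),
    nn.Conv2d(8, 4, 3, padding=1),
).double()

print(impulse_footprint(toy_encoder, (1, 3, 32, 32), (16, 16)))  # (6, 10, 6, 10)
```

For TAESD itself, the same helper could be pointed at `taesd.encoder` with a `(1, 3, H, W)` input shape; the footprint there depends on its actual layers and activations.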
