Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TAESD(or more) options for all the VAE encode/decode operation #12311

Merged
merged 14 commits into from
Aug 5, 2023
8 changes: 8 additions & 0 deletions modules/generation_parameters_copypaste.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,12 @@ def parse_generation_parameters(x: str):
if "Schedule rho" not in res:
res["Schedule rho"] = 0

if "VAE Encoder" not in res:
res["VAE Encoder"] = "Full"

if "VAE Decoder" not in res:
res["VAE Decoder"] = "Full"

return res


Expand All @@ -332,6 +338,8 @@ def parse_generation_parameters(x: str):
('RNG', 'randn_source'),
('NGMS', 's_min_uncond'),
('Pad conds', 'pad_cond_uncond'),
('VAE Encoder', 'sd_vae_encode_method'),
('VAE Decoder', 'sd_vae_decode_method'),
]


Expand Down
19 changes: 9 additions & 10 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.paths as paths
Expand All @@ -30,7 +31,6 @@
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType

decode_first_stage = sd_samplers_common.decode_first_stage

# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
Expand Down Expand Up @@ -84,7 +84,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):

# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))

# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
Expand Down Expand Up @@ -203,7 +203,7 @@ def depth2img_image_conditioning(self, source_image):
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)

conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
size=conditioning_image.shape[2:],
Expand All @@ -216,7 +216,7 @@ def depth2img_image_conditioning(self, source_image):
return conditioning

def edit_image_conditioning(self, source_image):
conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))

return conditioning_image

Expand Down Expand Up @@ -795,6 +795,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
else:
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

x_samples_ddim = torch.stack(x_samples_ddim).float()
Expand Down Expand Up @@ -1135,11 +1136,10 @@ def save_intermediate(image, index):
batch_images.append(image)

decoded_samples = torch.from_numpy(np.array(batch_images))
decoded_samples = decoded_samples.to(shared.device)
decoded_samples = 2. * decoded_samples - 1.
decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)

samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))

image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

Expand Down Expand Up @@ -1374,10 +1374,9 @@ def init(self, all_prompts, all_seeds, all_subseeds):
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

image = torch.from_numpy(batch_images)
image = 2. * image - 1.
image = image.to(shared.device, dtype=devices.dtype_vae)

self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
devices.torch_gc()

if self.resize_mode == 3:
Expand Down
44 changes: 36 additions & 8 deletions modules/sd_samplers_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,29 @@ def setup_img2img_steps(p, steps=None):
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}


def single_sample_to_image(sample, approximation=None):
def samples_to_images_tensor(sample, approximation=None, model=None):
'''latents -> images [-1, 1]'''
if approximation is None:
approximation = approximation_indexes.get(opts.show_progress_type, 0)

if approximation == 2:
x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
x_sample = sd_vae_approx.cheap_approximation(sample)
elif approximation == 1:
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
elif approximation == 3:
x_sample = sample * 1.5
x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach()
x_sample = x_sample * 2 - 1
else:
x_sample = decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
if model is None:
model = shared.sd_model
x_sample = model.decode_first_stage(sample)

return x_sample


def single_sample_to_image(sample, approximation=None):
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5

x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
Expand All @@ -45,9 +55,9 @@ def single_sample_to_image(sample, approximation=None):


def decode_first_stage(model, x):
x = model.decode_first_stage(x.to(devices.dtype_vae))

return x
x = x.to(devices.dtype_vae)
approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
return samples_to_images_tensor(x, approx_index, model)


def sample_to_image(samples, index=0, approximation=None):
Expand All @@ -58,6 +68,24 @@ def samples_to_image_grid(samples, approximation=None):
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])


def images_tensor_to_samples(image, approximation=None, model=None):
'''image[0, 1] -> latent'''
if approximation is None:
approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)

if approximation == 3:
image = image.to(devices.device, devices.dtype)
x_latent = sd_vae_taesd.encoder_model()(image)
else:
if model is None:
model = shared.sd_model
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image * 2 - 1
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))

return x_latent


def store_latent(decoded):
state.current_latent = decoded

Expand Down
2 changes: 1 addition & 1 deletion modules/sd_vae_approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,6 @@ def cheap_approximation(sample):

coefs = torch.tensor(coeffs).to(sample.device)

x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
x_sample = torch.einsum("...lxy,lr -> ...rxy", sample, coefs)

return x_sample
52 changes: 44 additions & 8 deletions modules/sd_vae_taesd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,17 @@ def decoder():
)


class TAESD(nn.Module):
def encoder():
return nn.Sequential(
conv(3, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 4),
)


class TAESDDecoder(nn.Module):
latent_magnitude = 3
latent_shift = 0.5

Expand All @@ -55,21 +65,28 @@ def __init__(self, decoder_path="taesd_decoder.pth"):
self.decoder.load_state_dict(
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))

@staticmethod
def unscale_latents(x):
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)

class TAESDEncoder(nn.Module):
latent_magnitude = 3
latent_shift = 0.5

def __init__(self, encoder_path="taesd_encoder.pth"):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.encoder = encoder()
self.encoder.load_state_dict(
torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))


def download_model(model_path, model_url):
if not os.path.exists(model_path):
os.makedirs(os.path.dirname(model_path), exist_ok=True)

print(f'Downloading TAESD decoder to: {model_path}')
print(f'Downloading TAESD model to: {model_path}')
torch.hub.download_url_to_file(model_url, model_path)


def model():
def decoder_model():
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
loaded_model = sd_vae_taesd_models.get(model_name)

Expand All @@ -78,11 +95,30 @@ def model():
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)

if os.path.exists(model_path):
loaded_model = TAESD(model_path)
loaded_model = TAESDDecoder(model_path)
loaded_model.eval()
loaded_model.to(devices.device, devices.dtype)
sd_vae_taesd_models[model_name] = loaded_model
else:
raise FileNotFoundError('TAESD model not found')

return loaded_model.decoder


def encoder_model():
model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
loaded_model = sd_vae_taesd_models.get(model_name)

if loaded_model is None:
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)

if os.path.exists(model_path):
loaded_model = TAESDEncoder(model_path)
loaded_model.eval()
loaded_model.to(devices.device, devices.dtype)
sd_vae_taesd_models[model_name] = loaded_model
else:
raise FileNotFoundError('TAESD model not found')

return loaded_model.encoder
2 changes: 2 additions & 0 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ def list_samplers():
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
}))

options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
Expand Down