[Cog] some minor fixes and nits (#9466)
* fix positional arguments in check_inputs().

* add video and latents to check_inputs().

* prep latents_in_channels.

* quality

* multiple fixes.

* fix
sayakpaul authored Sep 23, 2024
1 parent aa73072 commit ba5af5a
Showing 3 changed files with 65 additions and 61 deletions.
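For context on the first two commit bullets: the pipelines' `__call__` methods now pass keyword arguments to `check_inputs()`. A rough, hypothetical sketch (toy validator, not the real pipeline signature) of why keyword calls are safer once optional parameters such as `video` and `latents` are added:

```python
def check_inputs(prompt, height, width, negative_prompt=None, video=None, latents=None):
    """Toy validator in the spirit of the pipelines' check_inputs()."""
    if video is not None and latents is not None:
        raise ValueError("Only one of `video` or `latents` should be provided")


# Positional call: once `video` is inserted before `latents`, a value that used
# to land in `latents` silently lands in `video` instead.
check_inputs("a panda", 480, 720, None, None)

# Keyword call: every value is bound by name, so adding or reordering optional
# parameters cannot shift arguments into the wrong slot.
check_inputs(prompt="a panda", height=480, width=720, negative_prompt=None, latents=None)
```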
26 changes: 14 additions & 12 deletions src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -188,6 +188,9 @@ def __init__(
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.vae_scaling_factor_image = (
self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
)

self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

@@ -317,18 +320,19 @@ def encode_prompt(
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

shape = (
batch_size,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
num_channels_latents,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -341,7 +345,7 @@ def prepare_latents(

def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae.config.scaling_factor * latents
latents = 1 / self.vae_scaling_factor_image * latents

frames = self.vae.decode(latents).sample
return frames
@@ -510,10 +514,10 @@ def __call__(
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The height in pixels of the generated image. This is set to 480 by default for the best results.
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The width in pixels of the generated image. This is set to 720 by default for the best results.
num_frames (`int`, defaults to `48`):
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
@@ -587,8 +591,6 @@ def __call__(
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1

# 1. Check inputs. Raise error if not correct
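The hunks above also introduce `self.vae_scaling_factor_image`, cached once in `__init__` with a 0.7 fallback when no VAE is registered, and `decode_latents()` now reads it instead of `self.vae.config.scaling_factor`. A minimal standalone sketch of that pattern (illustrative class, not the actual pipeline):

```python
class TinyCogPipeline:
    def __init__(self, vae=None):
        self.vae = vae
        # Read the scaling factor once at construction time; the 0.7 fallback
        # keeps the pipeline constructible even when `vae` is None.
        self.vae_scaling_factor_image = (
            vae.config.scaling_factor if vae is not None else 0.7
        )

    def decode_latents(self, latents):
        # Undo the scaling applied when the latents were encoded.
        return latents / self.vae_scaling_factor_image


print(TinyCogPipeline().decode_latents(0.7))  # 1.0
```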
52 changes: 25 additions & 27 deletions src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -207,6 +207,9 @@ def __init__(
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.vae_scaling_factor_image = (
self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
)

self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

@@ -348,6 +351,12 @@ def prepare_latents(
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
shape = (
batch_size,
            num_frames,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)

if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

image = image.unsqueeze(2) # [B, C, F, H, W]

if isinstance(generator, list):
@@ -373,7 +376,7 @@ def prepare_latents(
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]

image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
image_latents = self.vae.config.scaling_factor * image_latents
image_latents = self.vae_scaling_factor_image * image_latents

padding_shape = (
batch_size,
Expand All @@ -397,7 +400,7 @@ def prepare_latents(
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae.config.scaling_factor * latents
latents = 1 / self.vae_scaling_factor_image * latents

frames = self.vae.decode(latents).sample
return frames
@@ -438,7 +441,6 @@ def check_inputs(
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
video=None,
latents=None,
prompt_embeds=None,
negative_prompt_embeds=None,
@@ -494,9 +496,6 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)

if video is not None and latents is not None:
raise ValueError("Only one of `video` or `latents` should be provided")

# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
def fuse_qkv_projections(self) -> None:
r"""Enables fused QKV projections."""
@@ -584,18 +583,18 @@ def __call__(
Args:
image (`PipelineImageInput`):
The input video to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The height in pixels of the generated image. This is set to 480 by default for the best results.
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The width in pixels of the generated image. This is set to 720 by default for the best results.
num_frames (`int`, defaults to `48`):
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
@@ -665,20 +664,19 @@ def __call__(
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1

# 1. Check inputs. Raise error if not correct
self.check_inputs(
image,
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
image=image,
prompt=prompt,
height=height,
width=width,
negative_prompt=negative_prompt,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
latents=latents,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._interrupt = False
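The `num_frames` docstring and `prepare_latents()` share the same temporal-compression arithmetic. A small worked check of that formula (plain Python; 4 is assumed here as the default temporal compression ratio):

```python
def latent_frame_count(num_frames: int, vae_scale_factor_temporal: int = 4) -> int:
    # Same formula as prepare_latents(): one conditioning frame plus groups of
    # `vae_scale_factor_temporal` pixel frames compressed into single latent frames.
    return (num_frames - 1) // vae_scale_factor_temporal + 1


# 48 requested frames + 1 extra conditioning frame = 49 pixel frames -> 13 latent frames.
print(latent_frame_count(49))  # 13
```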
48 changes: 26 additions & 22 deletions src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -204,12 +204,16 @@ def __init__(
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)

self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.vae_scaling_factor_image = (
self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
)

self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

@@ -351,6 +355,12 @@ def prepare_latents(
latents: Optional[torch.Tensor] = None,
timestep: Optional[torch.Tensor] = None,
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)

shape = (
            batch_size,
            num_frames,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)

if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

if latents is None:
if isinstance(generator, list):
if len(generator) != batch_size:
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]

init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
init_latents = self.vae.config.scaling_factor * init_latents
init_latents = self.vae_scaling_factor_image * init_latents

noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.add_noise(init_latents, noise, timestep)
@@ -396,7 +400,7 @@ def prepare_latents(
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae.config.scaling_factor * latents
latents = 1 / self.vae_scaling_factor_image * latents

frames = self.vae.decode(latents).sample
return frames
@@ -589,10 +593,10 @@ def __call__(
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The height in pixels of the generated image. This is set to 480 by default for the best results.
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The width in pixels of the generated image. This is set to 720 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
@@ -658,20 +662,20 @@ def __call__(
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1

# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
strength,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
prompt=prompt,
height=height,
width=width,
strength=strength,
negative_prompt=negative_prompt,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
video=video,
latents=latents,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
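Across all three files the generator-length check now runs at the top of `prepare_latents()`, before any shapes are computed, so a mismatched list of generators fails fast. A hedged usage sketch (illustrative latent shape only) of the per-sample-generator pattern that check protects, using the same `randn_tensor` helper the pipelines call:

```python
import torch
from diffusers.utils.torch_utils import randn_tensor

batch_size = 2
# Illustrative latent shape: (batch, latent_frames, channels, height, width).
shape = (batch_size, 13, 16, 60, 90)

# One generator per sample keeps each video independently reproducible;
# prepare_latents() raises early when len(generator) != batch_size.
generators = [torch.Generator().manual_seed(seed) for seed in (0, 1)]
latents = randn_tensor(shape, generator=generators, device=torch.device("cpu"), dtype=torch.float32)
print(latents.shape)  # torch.Size([2, 13, 16, 60, 90])
```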
