[WIP] test prepare_latents for ltx0.95 #10976
```diff
@@ -437,8 +437,8 @@ def check_inputs(
         )

     @staticmethod
-    # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
-    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+    # adapted from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
+    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1, device: torch.device = None) -> torch.Tensor:
         # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
         # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
         # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
```
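To make the shape bookkeeping in that comment concrete, here is a minimal, self-contained sketch of the pack step with toy dimensions (illustration only, not the pipeline code):

```python
import torch

# Toy dimensions; p_t and p are the temporal/spatial patch sizes (1 by default for LTX).
B, C, F, H, W = 2, 128, 8, 16, 16
p_t, p = 1, 1

latents = torch.randn(B, C, F, H, W)
# [B, C, F, H, W] -> [B, C, F//p_t, p_t, H//p, p, W//p, p]
latents = latents.reshape(B, C, F // p_t, p_t, H // p, p, W // p, p)
# Move the patch dims next to the channel dim and collapse:
# -> [B, F//p_t * H//p * W//p, C * p_t * p * p]
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
assert latents.shape == (B, (F // p_t) * (H // p) * (W // p), C * p_t * p * p)
```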
```diff
@@ -447,6 +447,17 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
         post_patch_num_frames = num_frames // patch_size_t
         post_patch_height = height // patch_size
         post_patch_width = width // patch_size
+
+        latent_sample_coords = torch.meshgrid(
+            torch.arange(0, num_frames, patch_size_t, device=device),
+            torch.arange(0, height, patch_size, device=device),
+            torch.arange(0, width, patch_size, device=device),
+            indexing="ij",
+        )
+        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
+        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+        latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width)
+
         latents = latents.reshape(
             batch_size,
             -1,
```
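The new grid records, for every packed token, its (frame, row, column) position in the latent volume. A quick standalone check of the shapes, assuming patch sizes of 1 (the values LTX uses, and the case where the final reshape is exact):

```python
import torch

batch_size, num_frames, height, width = 2, 8, 16, 16
patch_size_t = patch_size = 1

coords = torch.meshgrid(
    torch.arange(0, num_frames, patch_size_t),
    torch.arange(0, height, patch_size),
    torch.arange(0, width, patch_size),
    indexing="ij",
)
coords = torch.stack(coords, dim=0)                          # [3, F, H, W]
coords = coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)  # [B, 3, F, H, W]
coords = coords.reshape(batch_size, -1, num_frames * height * width)
assert coords.shape == (batch_size, 3, num_frames * height * width)
# coords[:, 0] holds each token's frame index, coords[:, 1] its row, coords[:, 2] its column.
```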
```diff
@@ -458,7 +469,7 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
             patch_size,
         )
         latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
-        return latents
+        return latents, latent_coords

     @staticmethod
     # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
```
```diff
@@ -503,10 +514,10 @@ def _prepare_non_first_frame_conditioning(
         frame_index: int,
         strength: float,
         num_prefix_latent_frames: int = 2,
-        prefix_latents_mode: str = "soft",
+        prefix_latents_mode: str = "concat",
         prefix_soft_conditioning_strength: float = 0.15,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        num_latent_frames = latents.size(2)
+        num_latent_frames = condition_latents.size(2)

         if num_latent_frames < num_prefix_latent_frames:
             raise ValueError(
```
```diff
@@ -544,6 +555,25 @@ def _prepare_non_first_frame_conditioning(

         return latents, condition_latents, condition_latent_frames_mask

+    def trim_conditioning_sequence(
+        self, start_frame: int, sequence_num_frames: int, target_num_frames: int
+    ):
+        """
+        Trim a conditioning sequence to the allowed number of frames.
+        Args:
+            start_frame (int): The target frame number of the first frame in the sequence.
+            sequence_num_frames (int): The number of frames in the sequence.
+            target_num_frames (int): The target number of frames in the generated video.
+        Returns:
+            int: updated sequence length
+        """
+        scale_factor = self.vae_temporal_compression_ratio
+        num_frames = min(sequence_num_frames, target_num_frames - start_frame)
+        # Trim down to a multiple of temporal_scale_factor frames plus 1
+        num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
+        return num_frames
+
     def prepare_latents(
         self,
         conditions: Union[LTXVideoCondition, List[LTXVideoCondition]],
```
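The rounding rule keeps the trimmed length on the VAE's temporal grid: a multiple of the compression ratio, plus one for the first frame. A standalone illustration, assuming a compression ratio of 8 as in the LTX VAE (the numbers are only examples):

```python
def trim_conditioning_sequence(start_frame: int, sequence_num_frames: int,
                               target_num_frames: int, scale_factor: int = 8) -> int:
    # Cap the sequence so it fits between start_frame and the end of the target video.
    num_frames = min(sequence_num_frames, target_num_frames - start_frame)
    # Round down to a multiple of scale_factor frames plus 1.
    return (num_frames - 1) // scale_factor * scale_factor + 1

assert trim_conditioning_sequence(0, 50, 121) == 49    # 50 rounds down to 6 * 8 + 1
assert trim_conditioning_sequence(100, 50, 121) == 17  # capped at 21, rounds to 2 * 8 + 1
```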
```diff
@@ -573,13 +603,17 @@ def prepare_latents(
         extra_conditioning_num_latents = (
             0  # Number of extra conditioning latents added (should be removed before decoding)
         )
-        condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=dtype)
+        condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)

         for condition in conditions:
             if condition.image is not None:
                 data = self.video_processor.preprocess(condition.image, height, width).unsqueeze(2)
             elif condition.video is not None:
-                data = self.video_processor.preprocess_video(condition.vide, height, width)
+                data = self.video_processor.preprocess_video(condition.video, height, width)
+                num_frames_input = data.size(2)
+                num_frames_output = self.trim_conditioning_sequence(condition.frame_index, num_frames_input, num_frames)
+                data = data[:, :, :num_frames_output]
+                data = data.to(device, dtype=dtype)
             else:
                 raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
```
```diff
@@ -599,6 +633,7 @@ def prepare_latents(
                     latents[:, :, :num_cond_frames], condition_latents, condition.strength
                 )
                 condition_latent_frames_mask[:, :num_cond_frames] = condition.strength
+
             else:
                 if num_data_frames > 1:
                     (
```
```diff
@@ -617,47 +652,39 @@ def prepare_latents(
                     noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
                     condition_latents = torch.lerp(noise, condition_latents, condition.strength)
                     c_nlf = condition_latents.shape[2]
-                    condition_latents = self._pack_latents(
-                        condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+                    condition_latents, rope_interpolation_scale = self._pack_latents(
+                        condition_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
                     )
+
+                    rope_interpolation_scale = (
+                        rope_interpolation_scale *
+                        torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=rope_interpolation_scale.device)[None, :, None]
+                    )
+                    rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)
+                    rope_interpolation_scale[:, 0] += condition.frame_index
+
                     conditioning_mask = torch.full(
                         condition_latents.shape[:2], condition.strength, device=device, dtype=dtype
                     )
-                    rope_interpolation_scale = [
-                        # TODO!!! This is incorrect: the frame index needs to added AFTER multiplying the interpolation
-                        # scale with the grid.
-                        (self.vae_temporal_compression_ratio + condition.frame_index) / frame_rate,
-                        self.vae_spatial_compression_ratio,
-                        self.vae_spatial_compression_ratio,
-                    ]
-                    rope_interpolation_scale = (
-                        torch.tensor(rope_interpolation_scale, device=device, dtype=dtype)
-                        .view(-1, 1, 1, 1, 1)
-                        .repeat(1, 1, c_nlf, latent_height, latent_width)
-                    )
```

Contributor comment on lines -627 to -633:

> @yiyixuxu Pardon my stupidity, but I can't seem to find if we're handling this. In the original code, this is what I was meaning to handle: https://github.com/Lightricks/LTX-Video/blob/496dc5058f4408dcb777282f3fb6377fb2da08e6/ltx_video/pipelines/pipeline_ltx_video.py#L1285

```diff
                     extra_conditioning_num_latents += condition_latents.size(1)

                     extra_conditioning_latents.append(condition_latents)
                     extra_conditioning_rope_interpolation_scales.append(rope_interpolation_scale)
                     extra_conditioning_mask.append(conditioning_mask)

-        latents = self._pack_latents(
-            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+        latents, rope_interpolation_scale = self._pack_latents(
+            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
         )
-        rope_interpolation_scale = [
-            self.vae_temporal_compression_ratio / frame_rate,
-            self.vae_spatial_compression_ratio,
-            self.vae_spatial_compression_ratio,
-        ]
-        rope_interpolation_scale = (
-            torch.tensor(rope_interpolation_scale, device=device, dtype=dtype)
-            .view(-1, 1, 1, 1, 1)
-            .repeat(1, 1, num_latent_frames, latent_height, latent_width)
+        conditioning_mask = condition_latent_frames_mask.gather(
+            1, rope_interpolation_scale[:, 0]
         )
-        conditioning_mask = self._pack_latents(
-            conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
-        )
+
+        rope_interpolation_scale = (
+            rope_interpolation_scale
+            * torch.tensor([self.vae_temporal_compression_ratio, self.vae_spatial_compression_ratio, self.vae_spatial_compression_ratio], device=rope_interpolation_scale.device)[None, :, None]
+        )
+        rope_interpolation_scale[:, 0] = (rope_interpolation_scale[:, 0] + 1 - self.vae_temporal_compression_ratio).clamp(min=0)

         if len(extra_conditioning_latents) > 0:
             latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
```
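To see what the new coordinate handling computes, here is the arithmetic on a toy grid. The compression ratios (temporal 8, spatial 32) are assumed here for illustration; in the pipeline they come from the VAE config. Note that the mask is gathered with the latent frame indices before the coordinates are rescaled to pixel space:

```python
import torch

latent_coords = torch.tensor([[[0, 1, 2],    # latent frame index per token
                               [0, 0, 0],    # latent row index
                               [0, 0, 0]]])  # latent column index; shape [B=1, 3, n_tokens=3]

# Per-latent-frame conditioning strengths, indexed by *latent* frame number.
frames_mask = torch.tensor([[1.0, 0.5, 0.0]])                   # [B, num_latent_frames]
conditioning_mask = frames_mask.gather(1, latent_coords[:, 0])  # -> [[1.0, 0.5, 0.0]]

# Rescale to pixel coordinates; [8, 32, 32][None, :, None] broadcasts over batch and tokens.
scale = torch.tensor([8, 32, 32])[None, :, None]
coords = latent_coords * scale                      # frame axis becomes 0, 8, 16
# Shift the temporal axis so latent frame 0 (which encodes a single pixel frame)
# maps to pixel frame 0; later frames land on the last pixel frame they cover.
coords[:, 0] = (coords[:, 0] + 1 - 8).clamp(min=0)  # frame axis becomes 0, 1, 9
# For conditioning latents, condition.frame_index is then added on top of coords[:, 0].
```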
```diff
@@ -864,7 +891,7 @@ def __call__(
             frame_rate,
             generator,
             device,
-            torch.float32,
+            prompt_embeds.dtype,
         )
         init_latents = latents.clone()
```
```diff
@@ -955,8 +982,8 @@ def __call__(
                     pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]

                     latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
-                    latents = self._pack_latents(
-                        latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
+                    latents, _ = self._pack_latents(
+                        latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, device
                     )

                 if callback_on_step_end is not None:
```
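A rough sketch of the per-step round trip in this hunk, with toy shapes (the unpack step that restores [B, C, F, H, W] happens earlier in the loop and is not shown in the diff):

```python
import torch

B, C, F, H, W = 1, 128, 9, 16, 16
latents = torch.randn(B, C, F, H, W)           # unpacked view of the current latents
pred_latents = torch.randn(B, C, F - 1, H, W)  # scheduler output for frames 1..F-1

# Keep latent frame 0 (the conditioning frame) untouched and splice in the rest.
latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
assert latents.shape == (B, C, F, H, W)
```

The `_` in `latents, _ = self._pack_latents(...)` discards the coordinate grid on each step; the coordinates appear to be needed only once, when the RoPE scales are built in `prepare_latents`.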

Uh oh!
There was an error while loading. Please reload this page.