
Commit

Fix HunyuanVideo produces NaN on PyTorch<2.5 (#10482)
Co-authored-by: Sayak Paul <[email protected]>
hlky and sayakpaul authored Jan 7, 2025
1 parent 03bcf5a commit 01bd796
Showing 1 changed file with 4 additions and 4 deletions.
@@ -713,15 +713,15 @@ def forward(
         condition_sequence_length = encoder_hidden_states.shape[1]
         sequence_length = latent_sequence_length + condition_sequence_length
         attention_mask = torch.zeros(
-            batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
-        )  # [B, N, N]
+            batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
+        )  # [B, N]
 
         effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
         effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
 
         for i in range(batch_size):
-            attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
-        attention_mask = attention_mask.unsqueeze(1)  # [B, 1, N, N], for broadcasting across attention heads
+            attention_mask[i, : effective_sequence_length[i]] = True
+        attention_mask = attention_mask.unsqueeze(1)  # [B, 1, N], for broadcasting across attention heads
 
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
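Why the shape change removes the NaN: with the old [B, 1, N, N] mask, every query row past a sample's effective sequence length was entirely masked, and attention over a fully masked row has no valid entries, which is what surfaced as NaN on PyTorch < 2.5. The new [B, 1, N] mask only marks which key positions are valid, so after broadcasting every query row still attends to the real tokens. Below is a minimal sketch of the difference, not part of the commit, using made-up toy shapes and a 4-D [B, 1, 1, N] variant of the key mask so it broadcasts cleanly in torch.nn.functional.scaled_dot_product_attention (the actual mask handling happens in the attention processors):

import torch
import torch.nn.functional as F

# Toy sizes, not HunyuanVideo's real ones.
batch, heads, seq, dim = 1, 2, 4, 8
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)
effective_len = 2  # pretend the last two tokens are padding

# Old-style dense mask: query rows >= effective_len are all False, so those
# queries attend to nothing; such fully masked rows are what showed up as NaN
# on the PyTorch versions this commit targets (< 2.5).
dense_mask = torch.zeros(batch, 1, seq, seq, dtype=torch.bool)
dense_mask[:, :, :effective_len, :effective_len] = True
out_old = F.scaled_dot_product_attention(q, k, v, attn_mask=dense_mask)
print(out_old.isnan().any())  # True on affected versions/backends

# New-style key-padding mask broadcast over heads and the query dimension:
# every query row still sees the valid keys, so no row is fully masked.
key_mask = torch.zeros(batch, 1, 1, seq, dtype=torch.bool)
key_mask[..., :effective_len] = True
out_new = F.scaled_dot_product_attention(q, k, v, attn_mask=key_mask)
print(out_new.isnan().any())  # tensor(False)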

1 comment on commit 01bd796

@ITerydh

Your fix seems to have broken support for PyTorch 2.5.1; there was no problem before this change.
