diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index fb7cd75861f8..8deda0d4acbf 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -37,7 +37,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: def apply_rope1(x: Tensor, freqs_cis: Tensor): x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1] + + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) + return x_out.reshape(*x.shape).type_as(x) def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 2dac5980cd99..54f61a80794e 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -237,6 +237,7 @@ def forward( freqs, transformer_options=transformer_options) x = torch.addcmul(x, y, repeat_e(e[2], x)) + del y # cross-attention & ffn x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)