Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions comfy/ldm/wan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def block_wrap(args):
x = self.unpatchify(x, grid_sizes)
return x

def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
Expand All @@ -601,10 +601,22 @@ def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=No
if steps_w is None:
steps_w = w_len

h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = t_len * rope_options.get("scale_t", 1.0)
h_len = h_len * rope_options.get("scale_y", 1.0)
w_len = w_len * rope_options.get("scale_x", 1.0)

t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)

img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])

freqs = self.rope_embedder(img_ids).movedim(1, 2)
Expand All @@ -630,7 +642,7 @@ def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, tr
if self.ref_conv is not None and "reference_latent" in kwargs:
t_len += 1

freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]

def unpatchify(self, x, grid_sizes):
Expand Down
13 changes: 13 additions & 0 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,19 @@ def set_model_double_block_patch(self, patch):
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")

def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x
rope_options["scale_y"] = scale_y
rope_options["scale_t"] = scale_t

rope_options["shift_x"] = shift_x
rope_options["shift_y"] = shift_y
rope_options["shift_t"] = shift_t

self.model_options["transformer_options"]["rope_options"] = rope_options


def add_object_patch(self, name, obj):
self.object_patches[name] = obj

Expand Down
47 changes: 47 additions & 0 deletions comfy_extras/nodes_rope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override


class ScaleROPE(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ScaleROPE",
category="advanced/model_patches",
description="Scale and shift the ROPE of the model.",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1),

io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1),

io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1),


],
outputs=[
io.Model.Output(),
],
)

@classmethod
def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput:
m = model.clone()
m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t)
return io.NodeOutput(m)


class RopeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ScaleROPE
]


async def comfy_entrypoint() -> RopeExtension:
return RopeExtension()
1 change: 1 addition & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2329,6 +2329,7 @@ async def init_builtin_extra_nodes():
"nodes_model_patch.py",
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_rope.py",
]

import_failed = []
Expand Down