Skip to content

Commit 09a52cd

Browse files
authored
convert nodes_gits.py to V3 schema (comfyanonymous#9949)
1 parent 713598b commit 09a52cd

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

comfy_extras/nodes_gits.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
22
import numpy as np
33
import torch
4+
from typing_extensions import override
5+
from comfy_api.latest import ComfyExtension, io
46

57
def loglinear_interp(t_steps, num_steps):
68
"""
@@ -333,25 +335,28 @@ def loglinear_interp(t_steps, num_steps):
333335
],
334336
}
335337

336-
class GITSScheduler:
338+
class GITSScheduler(io.ComfyNode):
337339
@classmethod
338-
def INPUT_TYPES(s):
339-
return {"required":
340-
{"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}),
341-
"steps": ("INT", {"default": 10, "min": 2, "max": 1000}),
342-
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
343-
}
344-
}
345-
RETURN_TYPES = ("SIGMAS",)
346-
CATEGORY = "sampling/custom_sampling/schedulers"
340+
def define_schema(cls):
341+
return io.Schema(
342+
node_id="GITSScheduler",
343+
category="sampling/custom_sampling/schedulers",
344+
inputs=[
345+
io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05),
346+
io.Int.Input("steps", default=10, min=2, max=1000),
347+
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
348+
],
349+
outputs=[
350+
io.Sigmas.Output(),
351+
],
352+
)
347353

348-
FUNCTION = "get_sigmas"
349-
350-
def get_sigmas(self, coeff, steps, denoise):
354+
@classmethod
355+
def execute(cls, coeff, steps, denoise):
351356
total_steps = steps
352357
if denoise < 1.0:
353358
if denoise <= 0.0:
354-
return (torch.FloatTensor([]),)
359+
return io.NodeOutput(torch.FloatTensor([]))
355360
total_steps = round(steps * denoise)
356361

357362
if steps <= 20:
@@ -362,8 +367,16 @@ def get_sigmas(self, coeff, steps, denoise):
362367

363368
sigmas = sigmas[-(total_steps + 1):]
364369
sigmas[-1] = 0
365-
return (torch.FloatTensor(sigmas), )
370+
return io.NodeOutput(torch.FloatTensor(sigmas))
366371

367-
NODE_CLASS_MAPPINGS = {
368-
"GITSScheduler": GITSScheduler,
369-
}
372+
373+
class GITSSchedulerExtension(ComfyExtension):
374+
@override
375+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
376+
return [
377+
GITSScheduler,
378+
]
379+
380+
381+
async def comfy_entrypoint() -> GITSSchedulerExtension:
382+
return GITSSchedulerExtension()

0 commit comments

Comments
 (0)