11# from https://github.com/zju-pi/diff-sampler/tree/main/gits-main
22import numpy as np
33import torch
4+ from typing_extensions import override
5+ from comfy_api .latest import ComfyExtension , io
46
57def loglinear_interp (t_steps , num_steps ):
68 """
@@ -333,25 +335,28 @@ def loglinear_interp(t_steps, num_steps):
333335 ],
334336}
335337
336- class GITSScheduler :
338+ class GITSScheduler ( io . ComfyNode ) :
337339 @classmethod
338- def INPUT_TYPES (s ):
339- return {"required" :
340- {"coeff" : ("FLOAT" , {"default" : 1.20 , "min" : 0.80 , "max" : 1.50 , "step" : 0.05 }),
341- "steps" : ("INT" , {"default" : 10 , "min" : 2 , "max" : 1000 }),
342- "denoise" : ("FLOAT" , {"default" : 1.0 , "min" : 0.0 , "max" : 1.0 , "step" : 0.01 }),
343- }
344- }
345- RETURN_TYPES = ("SIGMAS" ,)
346- CATEGORY = "sampling/custom_sampling/schedulers"
340+ def define_schema (cls ):
341+ return io .Schema (
342+ node_id = "GITSScheduler" ,
343+ category = "sampling/custom_sampling/schedulers" ,
344+ inputs = [
345+ io .Float .Input ("coeff" , default = 1.20 , min = 0.80 , max = 1.50 , step = 0.05 ),
346+ io .Int .Input ("steps" , default = 10 , min = 2 , max = 1000 ),
347+ io .Float .Input ("denoise" , default = 1.0 , min = 0.0 , max = 1.0 , step = 0.01 ),
348+ ],
349+ outputs = [
350+ io .Sigmas .Output (),
351+ ],
352+ )
347353
348- FUNCTION = "get_sigmas"
349-
350- def get_sigmas (self , coeff , steps , denoise ):
354+ @classmethod
355+ def execute (cls , coeff , steps , denoise ):
351356 total_steps = steps
352357 if denoise < 1.0 :
353358 if denoise <= 0.0 :
354- return (torch .FloatTensor ([]), )
359+ return io . NodeOutput (torch .FloatTensor ([]))
355360 total_steps = round (steps * denoise )
356361
357362 if steps <= 20 :
@@ -362,8 +367,16 @@ def get_sigmas(self, coeff, steps, denoise):
362367
363368 sigmas = sigmas [- (total_steps + 1 ):]
364369 sigmas [- 1 ] = 0
365- return (torch .FloatTensor (sigmas ), )
370+ return io . NodeOutput (torch .FloatTensor (sigmas ))
366371
367- NODE_CLASS_MAPPINGS = {
368- "GITSScheduler" : GITSScheduler ,
369- }
372+
373+ class GITSSchedulerExtension (ComfyExtension ):
374+ @override
375+ async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
376+ return [
377+ GITSScheduler ,
378+ ]
379+
380+
381+ async def comfy_entrypoint () -> GITSSchedulerExtension :
382+ return GITSSchedulerExtension ()
0 commit comments