1+ from typing_extensions import override
2+
13import torch
24
5+ from comfy_api .latest import ComfyExtension , io
6+
7+
38# https://github.com/WeichenFan/CFG-Zero-star
49def optimized_scale (positive , negative ):
510 positive_flat = positive .reshape (positive .shape [0 ], - 1 )
@@ -16,17 +21,20 @@ def optimized_scale(positive, negative):
1621
1722 return st_star .reshape ([positive .shape [0 ]] + [1 ] * (positive .ndim - 1 ))
1823
19- class CFGZeroStar :
24+ class CFGZeroStar (io .ComfyNode ):
25+ @classmethod
26+ def define_schema (cls ) -> io .Schema :
27+ return io .Schema (
28+ node_id = "CFGZeroStar" ,
29+ category = "advanced/guidance" ,
30+ inputs = [
31+ io .Model .Input ("model" ),
32+ ],
33+ outputs = [io .Model .Output (display_name = "patched_model" )],
34+ )
35+
2036 @classmethod
21- def INPUT_TYPES (s ):
22- return {"required" : {"model" : ("MODEL" ,),
23- }}
24- RETURN_TYPES = ("MODEL" ,)
25- RETURN_NAMES = ("patched_model" ,)
26- FUNCTION = "patch"
27- CATEGORY = "advanced/guidance"
28-
29- def patch (self , model ):
37+ def execute (cls , model ) -> io .NodeOutput :
3038 m = model .clone ()
3139 def cfg_zero_star (args ):
3240 guidance_scale = args ['cond_scale' ]
@@ -38,21 +46,24 @@ def cfg_zero_star(args):
3846
3947 return out + uncond_p * (alpha - 1.0 ) + guidance_scale * uncond_p * (1.0 - alpha )
4048 m .set_model_sampler_post_cfg_function (cfg_zero_star )
41- return ( m , )
49+ return io . NodeOutput ( m )
4250
43- class CFGNorm :
51+ class CFGNorm ( io . ComfyNode ) :
4452 @classmethod
45- def INPUT_TYPES (s ):
46- return {"required" : {"model" : ("MODEL" ,),
47- "strength" : ("FLOAT" , {"default" : 1.0 , "min" : 0.0 , "max" : 100.0 , "step" : 0.01 }),
48- }}
49- RETURN_TYPES = ("MODEL" ,)
50- RETURN_NAMES = ("patched_model" ,)
51- FUNCTION = "patch"
52- CATEGORY = "advanced/guidance"
53- EXPERIMENTAL = True
54-
55- def patch (self , model , strength ):
53+ def define_schema (cls ) -> io .Schema :
54+ return io .Schema (
55+ node_id = "CFGNorm" ,
56+ category = "advanced/guidance" ,
57+ inputs = [
58+ io .Model .Input ("model" ),
59+ io .Float .Input ("strength" , default = 1.0 , min = 0.0 , max = 100.0 , step = 0.01 ),
60+ ],
61+ outputs = [io .Model .Output (display_name = "patched_model" )],
62+ is_experimental = True ,
63+ )
64+
65+ @classmethod
66+ def execute (cls , model , strength ) -> io .NodeOutput :
5667 m = model .clone ()
5768 def cfg_norm (args ):
5869 cond_p = args ['cond_denoised' ]
@@ -64,9 +75,17 @@ def cfg_norm(args):
6475 return pred_text_ * scale * strength
6576
6677 m .set_model_sampler_post_cfg_function (cfg_norm )
67- return (m , )
78+ return io .NodeOutput (m )
79+
80+
81+ class CfgExtension (ComfyExtension ):
82+ @override
83+ async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
84+ return [
85+ CFGZeroStar ,
86+ CFGNorm ,
87+ ]
88+
6889
69- NODE_CLASS_MAPPINGS = {
70- "CFGZeroStar" : CFGZeroStar ,
71- "CFGNorm" : CFGNorm ,
72- }
90+ async def comfy_entrypoint () -> CfgExtension :
91+ return CfgExtension ()
0 commit comments