Skip to content

Commit 49c35d0

Browse files
authored
Merge pull request #7 from blepping/feat_ancestral
Add ancestral sampling and advanced node
2 parents ece7be1 + 2ad92d1 commit 49c35d0

File tree

2 files changed

+195
-52
lines changed

2 files changed

+195
-52
lines changed

__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .custom_samplers import SamplerDistanceAdvanced
1+
from .custom_samplers import SamplerDistance, SamplerDistanceAdvanced
22
from .presets_to_add import extra_samplers
33

44
def add_samplers():
@@ -21,5 +21,6 @@ def add_samplers():
2121
add_samplers()
2222

2323
NODE_CLASS_MAPPINGS = {
24-
"SamplerDistance": SamplerDistanceAdvanced,
25-
}
24+
"SamplerDistance": SamplerDistance,
25+
"SamplerDistanceAdvanced": SamplerDistanceAdvanced,
26+
}

custom_samplers.py

Lines changed: 191 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from comfy.k_diffusion.sampling import trange, to_d
33
import comfy.model_patcher
44
import comfy.samplers
5+
from comfy.k_diffusion import sampling
6+
from comfy import model_sampling
57
from math import pi
68
mmnorm = lambda x: (x - x.min()) / (x.max() - x.min())
79
selfnorm = lambda x: x / x.norm()
@@ -70,12 +72,60 @@ def normalize_adjust(a,b,strength=1):
7072
a[~torch.isfinite(a)] = c[~torch.isfinite(a)]
7173
return a
7274

75+
def get_ancestral_step_ext(sigma, sigma_next, eta=1.0, is_rf=False):
    """Compute (sigma_down, sigma_up, x_coeff) for one ancestral step.

    Extends k-diffusion's get_ancestral_step with a rectified-flow (is_rf)
    variant in which x must additionally be rescaled by x_coeff. With
    eta == 0 or on the final step the result degenerates to a plain
    (non-ancestral) Euler step.
    """
    if eta == 0 or sigma_next == 0:
        # Deterministic step: no noise added, no rescaling of x.
        return sigma_next, sigma_next * 0.0, 1.0
    if not is_rf:
        # EPS/V-style models: defer to the stock helper; x_coeff stays 1.
        down, up = sampling.get_ancestral_step(sigma, sigma_next, eta=eta)
        return down, up, 1.0
    # Rectified-flow branch (referenced from ComfyUI).
    ratio = 1.0 + (sigma_next / sigma - 1.0) * eta
    sigma_down = sigma_next * ratio
    alpha_next = 1.0 - sigma_next
    alpha_down = 1.0 - sigma_down
    scale = alpha_next / alpha_down
    sigma_up = (sigma_next ** 2 - (sigma_down * scale) ** 2) ** 0.5
    return sigma_down, sigma_up, scale
87+
88+
def internal_step(x, d, dt, sigma, sigma_next, sigma_up, x_coeff, noise_sampler):
    """Advance x by an Euler step d * dt, then optionally mix in ancestral noise.

    When sigma_up is nonzero and a noise sampler is supplied, fresh noise
    scaled by sigma_up is added; x is first rescaled by x_coeff (relevant
    for rectified-flow models, where x_coeff != 1).
    """
    stepped = x + d * dt
    if noise_sampler is None or sigma_up == 0:
        # Deterministic (non-ancestral) step.
        return stepped
    if x_coeff != 1:
        # Flow models rescale x before the noise injection.
        stepped *= x_coeff
    return stepped.add_(noise_sampler(sigma, sigma_next).mul_(sigma_up))
97+
98+
def fix_step_range(steps, start, end):
    """Normalize a (start, end) step range to 0-based indices in [0, steps - 1].

    Negative values count from the end, Python-style. Both endpoints are
    clamped into the valid range and the pair is returned in ascending order.
    """
    def _norm(idx):
        # Resolve from-the-end indices, then clamp into [0, steps - 1].
        if idx < 0:
            idx += steps
        return min(max(idx, 0), steps - 1)

    lo, hi = _norm(start), _norm(end)
    return (hi, lo) if lo > hi else (lo, hi)
106+
73107
# Euler and CFGpp part taken from comfy_extras/nodes_advanced_samplers
74-
def distance_wrap(resample,resample_end=-1,cfgpp=False,sharpen=False,use_softmax=False,first_only=False,use_slerp=False,perp_step=False,smooth=False,use_negative=False):
108+
def distance_wrap(
109+
resample, resample_end=-1, cfgpp=False, sharpen=False, use_softmax=False,
110+
distance_first=0, distance_last=-1, eta_first=0, eta_last=-1, distance_eta_first=0, distance_eta_last=-1,
111+
use_slerp=False, perp_step=False, smooth=False, use_negative=False, eta=0.0, s_noise=1.0,
112+
distance_step_eta=0.0, distance_step_s_noise=1.0, distance_step_seed_offset=42,
113+
):
75114
@torch.no_grad()
76-
def sample_distance_advanced(model, x, sigmas, extra_args=None, callback=None, disable=None):
115+
def sample_distance_advanced(model, x, sigmas, eta=eta, s_noise=s_noise, noise_sampler=None, distance_step_noise_sampler=None, extra_args=None, callback=None, disable=None):
116+
nonlocal distance_first, distance_last, eta_first, eta_last, distance_eta_first, distance_eta_last
117+
77118
extra_args = {} if extra_args is None else extra_args
119+
seed = extra_args.get("seed")
120+
dstep_noise_sampler = None if distance_step_eta == 0 else distance_step_noise_sampler or noise_sampler or sampling.default_noise_sampler(x, seed=seed + distance_step_seed_offset if seed is not None else None)
121+
noise_sampler = None if eta == 0 else noise_sampler or sampling.default_noise_sampler(x, seed=seed)
122+
is_rf = isinstance(model.inner_model.inner_model.model_sampling, model_sampling.CONST)
78123
uncond = None
124+
steps = len(sigmas) - 1
125+
126+
distance_first, distance_last = fix_step_range(steps, distance_first, distance_last)
127+
eta_first, eta_last = fix_step_range(steps, eta_first, eta_last)
128+
distance_eta_first, distance_eta_last = fix_step_range(steps, distance_eta_first, distance_eta_last)
79129

80130
if cfgpp or use_negative:
81131
uncond = None
@@ -96,58 +146,66 @@ def post_cfg_function(args):
96146
current_resample = resample
97147
total = 0
98148
s_in = x.new_ones([x.shape[0]])
99-
for i in trange(len(sigmas) - 1, disable=disable):
100-
sigma_hat = sigmas[i]
149+
for i in trange(steps, disable=disable):
150+
use_distance = distance_first <= i <= distance_last
151+
use_eta = eta_first <= i <= eta_last
152+
use_distance_eta = distance_eta_first <= i <= distance_eta_last
153+
sigma, sigma_next = sigmas[i:i + 2]
154+
sigma_down, sigma_up, x_coeff = get_ancestral_step_ext(sigma, sigma_next, eta=eta if use_eta else 0.0, is_rf=is_rf)
155+
sigma_up *= s_noise
156+
dstep_sigma_down, dstep_sigma_up, dstep_x_coeff = get_ancestral_step_ext(sigma, sigma_next, eta=distance_step_eta if use_distance_eta else 0.0, is_rf=is_rf)
157+
dstep_sigma_up *= distance_step_s_noise
101158

102-
res_mul = progression(sigma_hat)
159+
res_mul = progression(sigma)
103160
if resample_end >= 0:
104161
resample_steps = max(min(current_resample,resample_end),min(max(current_resample,resample_end),int(current_resample * res_mul + resample_end * (1 - res_mul))))
105162
else:
106163
resample_steps = current_resample
107164

108-
denoised = model(x, sigma_hat * s_in, **extra_args)
165+
denoised = model(x, sigma * s_in, **extra_args)
109166
total += 1
110167

111168
if cfgpp and torch.any(uncond):
112-
d = to_d(x - denoised + uncond, sigmas[i], denoised)
169+
d = to_d(x - denoised + uncond, sigma, denoised)
113170
else:
114-
d = to_d(x, sigma_hat, denoised)
171+
d = to_d(x, sigma, denoised)
115172

116173
if callback is not None:
117-
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
118-
dt = sigmas[i + 1] - sigma_hat
174+
callback({'x': x, 'i': i, 'sigma': sigmas, 'sigma_hat': sigma, 'denoised': denoised})
175+
dt = sigma_down - sigma
176+
dstep_dt = dstep_sigma_down - sigma
119177

120-
if sigmas[i + 1] == 0 or resample_steps == 0 or (i > 0 and first_only):
178+
if sigma_next == 0 or resample_steps == 0 or not use_distance:
121179
# Euler method
122-
x = x + d * dt
123-
else:
124-
# not Euler method
125-
x_n = [d]
126-
for re_step in range(resample_steps):
127-
x_new = x + d * dt
128-
new_denoised = model(x_new, sigmas[i + 1] * s_in, **extra_args)
129-
if smooth:
130-
new_denoised = new_denoised.abs().pow(1 / new_denoised.std().sqrt()) * new_denoised.sign()
131-
new_denoised = new_denoised.div(new_denoised.std().sqrt())
132-
total += 1
133-
if cfgpp and torch.any(uncond):
134-
new_d = to_d(x_new - new_denoised + uncond, sigmas[i + 1], new_denoised)
135-
else:
136-
new_d = to_d(x_new, sigmas[i + 1] * s_in, new_denoised)
137-
x_n.append(new_d)
138-
if re_step == 0:
139-
d = (new_d + d) / 2
140-
else:
141-
u = uncond if (use_negative and uncond is not None and torch.any(uncond)) else None
142-
d = fast_distance_weights(torch.stack(x_n), use_softmax=use_softmax, use_slerp=use_slerp, uncond=u)
143-
if sharpen or perp_step:
144-
if sharpen and d_prev is not None:
145-
d = normalize_adjust(d, d_prev, 1)
146-
elif perp_step and d_prev is not None:
147-
d = diff_step(d, d_prev, 0.5)
148-
d_prev = d.clone()
149-
x_n.append(d)
150-
x = x + d * dt
180+
x = internal_step(x, d, dt, sigma, sigma_next, sigma_up, x_coeff, noise_sampler)
181+
continue
182+
# not Euler method
183+
x_n = [d]
184+
for re_step in trange(resample_steps, initial=1, disable=disable or resample_steps < 2, leave=False, desc=" Distance"):
185+
x_new = internal_step(x, d, dstep_dt, sigma, sigma_next, dstep_sigma_up, dstep_x_coeff, dstep_noise_sampler)
186+
new_denoised = model(x_new, sigma_next * s_in, **extra_args)
187+
if smooth:
188+
new_denoised = new_denoised.abs().pow(1 / new_denoised.std().sqrt()) * new_denoised.sign()
189+
new_denoised = new_denoised.div(new_denoised.std().sqrt())
190+
total += 1
191+
if cfgpp and torch.any(uncond):
192+
new_d = to_d(x_new - new_denoised + uncond, sigma_next, new_denoised)
193+
else:
194+
new_d = to_d(x_new, sigma_next * s_in, new_denoised)
195+
x_n.append(new_d)
196+
if re_step == 0:
197+
d = (new_d + d) / 2
198+
continue
199+
u = uncond if (use_negative and uncond is not None and torch.any(uncond)) else None
200+
d = fast_distance_weights(torch.stack(x_n), use_softmax=use_softmax, use_slerp=use_slerp, uncond=u)
201+
if sharpen or perp_step:
202+
if sharpen and d_prev is not None:
203+
d = normalize_adjust(d, d_prev, 1)
204+
elif perp_step and d_prev is not None:
205+
d = diff_step(d, d_prev, 0.5)
206+
d_prev = d.clone()
207+
x_n.append(d)
208+
x = internal_step(x, d, dt, sigma, sigma_next, sigma_up, x_coeff, noise_sampler)
151209
return x
152210
return sample_distance_advanced
153211

@@ -202,19 +260,103 @@ def simplified_euler(model, x, sigmas, extra_args=None, callback=None, disable=N
202260
x = x + d * dt
203261
return x
204262

205-
class SamplerDistanceAdvanced:
263+
class SamplerDistanceBase:
    """Shared ComfyUI node implementation for the distance samplers.

    Subclasses choose which parameters to expose by setting _DISTANCE_OPTIONS
    to a tuple of keys from _DISTANCE_PARAMS; None exposes every parameter.
    """

    # None means "expose all parameters".
    _DISTANCE_OPTIONS = None

    # Full catalogue of node inputs; subclasses pick a subset via _DISTANCE_OPTIONS.
    _DISTANCE_PARAMS = {
        "resample": ("INT", {"default": 3, "min": -1, "max": 32, "step": 1, "tooltip": "0 all along gives Euler. 1 gives Heun.\nAnything starting from 2 will use the distance method.\n-1 will do remaining steps + 1 as the resample value. This can be pretty slow."}),
        "resample_end": ("INT", {"default": -1, "min": -1, "max": 32, "step": 1, "tooltip": "How many resamples for the end. -1 means constant."}),
        "cfgpp": ("BOOLEAN", {"default": True, "tooltip": "Controls whether to use CFG++ sampling. When enabled, you should set CFG to a fairly low value."}),
        "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 32.0, "step": 0.01, "tooltip": "Controls the ancestralness of the main sampler steps. 0.0 means to use non-ancestral sampling. Note: May not work well with some of the other options."}),
        "s_noise": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "Scale factor for ancestral noise added during sampling. Generally should be left at 1.0 and only has an effect when ancestral sampling is used."}),
        "distance_step_eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 32.0, "step": 0.01, "tooltip": "Experimental option that allows using ancestral sampling for the internal distance steps. When used, should generally be a fairly low value such as 0.25. 0.0 means to use non-ancestral sampling for the internal distance steps."}),
        "distance_step_s_noise": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "Scale factor for ancestral noise added in the internal distance steps. Generally should be left at 1.0 and only has an effect when distance_step_eta is non-zero."}),
        "use_softmax": ("BOOLEAN", {"default": False, "tooltip": "Rather than using a min/max normalization and an exponent will use a softmax instead."}),
        "use_slerp": ("BOOLEAN", {"default": False, "tooltip": "Will SLERP the predictions instead of doing a weighted average. The difference is more obvious when using use_negative."}),
        "perp_step": ("BOOLEAN", {"default": False, "tooltip": "Experimental, not yet recommended."}),
        "use_negative": ("BOOLEAN", {"default": False, "tooltip": "Will use the negative prediction to prepare the distance scores. This tends to give images with less errors from my testing."}),
        "smooth": ("BOOLEAN", {"default": False, "tooltip": "Not recommended, will make everything brighter. Not smoother."}),
        "sharpen": ("BOOLEAN", {"default": False, "tooltip": "Not recommended, attempts to sharpen the results but instead tends to make things fuzzy."}),
        "distance_first": ("INT", {"default": 0, "min": -10000, "max": 10000, "step": 1, "tooltip": "First step to use distance sampling. You can use negative values to count from the end. Note: Steps are zero-based."}),
        "distance_last": ("INT", {"default": -1, "min": -10000, "max": 10000, "step": 1, "tooltip": "Last step to use distance sampling. You can use negative values to count from the end. Note: Steps are zero-based."}),
        "eta_first": ("INT", {"default": 0, "min": -10000, "max": 10000, "step": 1, "tooltip": "First step to use ancestral sampling. Only applies when ETA is non-zero. You can use negative values to count from the end. Note: Steps are zero-based."}),
        "eta_last": ("INT", {"default": -1, "min": -10000, "max": 10000, "step": 1, "tooltip": "Last step to use ancestral sampling. Only applies when ETA is non-zero. You can use negative values to count from the end. Note: Steps are zero-based."}),
        "distance_eta_first": ("INT", {"default": 0, "min": -10000, "max": 10000, "step": 1, "tooltip": "First step to use ancestral sampling for the distance steps. Only applies when distance ETA is non-zero. You can use negative values to count from the end. Note: Steps are zero-based."}),
        "distance_eta_last": ("INT", {"default": -1, "min": -10000, "max": 10000, "step": 1, "tooltip": "Last step to use ancestral sampling for the distance steps. Only applies when distance ETA is non-zero. You can use negative values to count from the end. Note: Steps are zero-based."}),
    }

    @classmethod
    def INPUT_TYPES(s):
        # Build a fresh dict each call so callers can't mutate the catalogue.
        keys = tuple(s._DISTANCE_PARAMS) if s._DISTANCE_OPTIONS is None else s._DISTANCE_OPTIONS
        return {"required": {key: s._DISTANCE_PARAMS[key] for key in keys}}

    RETURN_TYPES = ("SAMPLER",)
    CATEGORY = "sampling/custom_sampling/samplers"
    FUNCTION = "get_sampler"

    def get_sampler(self, **kwargs):
        """Build a KSAMPLER wrapping distance_wrap configured from the node inputs."""
        return (comfy.samplers.KSAMPLER(distance_wrap(**kwargs)),)
357+
358+
class SamplerDistance(SamplerDistanceBase):
    # Basic node: exposes only the core resampling options.
    _DISTANCE_OPTIONS = ("resample", "resample_end", "cfgpp")
360+
361+
class SamplerDistanceAdvanced(SamplerDistanceBase):
    # Advanced node: leaves _DISTANCE_OPTIONS as None, exposing every parameter.
    pass

0 commit comments

Comments
 (0)