Skip to content

Commit

Permalink
Fix prompt editing.
Browse files Browse the repository at this point in the history
For #107
  • Loading branch information
shiimizu committed Nov 8, 2024
1 parent b0a0272 commit f74df5b
Showing 1 changed file with 26 additions and 13 deletions.
39 changes: 26 additions & 13 deletions smZNodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import comfy
import math
import ctypes
from decimal import Decimal
from functools import partial
from random import getrandbits
import comfy.sdxl_clip
Expand Down Expand Up @@ -61,13 +62,26 @@ def find_nearest(a,b):

def get_area_and_mult(*args, **kwargs):
conds = args[0]
if 'start_step' in conds and 'end_step' in conds:
if 'start_perc' in conds and 'end_perc' in conds and "init_steps" in conds:
timestep_in = args[2]
sigmas = store.sigmas
if conds['init_steps'] == sigmas.shape[0] - 1:
total = Decimal(sigmas.shape[0] - 1)
else:
sigmas_ = store.sigmas.unique(sorted=True).sort(descending=True)[0]
if len(sigmas) == len(sigmas_):
# Sampler Custom with sigmas: no change
total = Decimal(sigmas.shape[0] - 1)
else:
# Sampler with restarts: dedup the sigmas and add one
sigmas = sigmas_
total = Decimal(sigmas.shape[0] + 1)
ts_in = find_nearest(timestep_in, sigmas)
cur_i = ss[0].item() if (ss:=(sigmas == ts_in).nonzero()).shape[0] != 0 else 0
cur = cur_i / (sigmas.shape[0] - 1)
if not (cur >= conds['start_step'] and cur < conds['end_step']):
cur = Decimal(cur_i) / total
start = conds['start_perc']
end = conds['end_perc']
if not (cur >= start and cur < end):
return None
return store.get_area_and_mult(*args, **kwargs)

Expand Down Expand Up @@ -231,18 +245,17 @@ def HijackClipComfy(clip):
except Exception: ...
del store_orig

def transform_schedules(schedules, weight=None, with_weight=False):
def transform_schedules(steps, schedules, weight=None, with_weight=False):
end_steps = [schedule.end_at_step for schedule in schedules]
start_end_pairs = list(zip([0] + end_steps[:-1], end_steps))
with_steps = len(schedules) > 1
steps = len(schedules)
with_prompt_editing = len(schedules) > 1

def process(schedule, start_step, end_step):
nonlocal with_steps, steps
nonlocal with_prompt_editing
d = schedule.cond.copy()
d.pop('cond', None)
if with_steps:
d |= {"start_step": start_step / steps, "end_step": end_step / steps}
if with_prompt_editing:
d |= {"start_perc": Decimal(start_step) / Decimal(steps), "end_perc": Decimal(end_step) / Decimal(steps), "init_steps": steps}
if weight is not None and with_weight:
d['weight'] = weight
return d
Expand All @@ -257,20 +270,20 @@ def process(schedule, start_step, end_step):
def flatten(nested_list):
return [item for sublist in nested_list for item in sublist]

def convert_schedules_to_comfy(schedules, multi=False):
def convert_schedules_to_comfy(schedules, steps, multi=False):
if multi:
out = [[transform_schedules(x.schedules, x.weight, len(batch)>1) for x in batch] for batch in schedules.batch]
out = [[transform_schedules(steps, x.schedules, x.weight, len(batch)>1) for x in batch] for batch in schedules.batch]
out = flatten(out)
else:
out = [transform_schedules(sublist) for sublist in schedules]
out = [transform_schedules(steps, sublist) for sublist in schedules]
return flatten(out)

def get_learned_conditioning(model, prompts, steps, multi=False, *args, **kwargs):
if multi:
schedules = prompt_parser.get_multicond_learned_conditioning(model, prompts, steps, *args, **kwargs)
else:
schedules = prompt_parser.get_learned_conditioning(model, prompts, steps, *args, **kwargs)
schedules_c = convert_schedules_to_comfy(schedules, multi)
schedules_c = convert_schedules_to_comfy(schedules, steps, multi)
return schedules_c

class CustomList(list):
Expand Down

0 comments on commit f74df5b

Please sign in to comment.