
Commit 1bcda6d

WIP way to support multiple multi-dimensional latents. (#10456)
1 parent a1864c0

5 files changed: +158 -15 lines

comfy/model_base.py

Lines changed: 8 additions & 2 deletions
@@ -197,8 +197,14 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran
             extra_conds[o] = extra

         t = self.process_timestep(t, x=x, **extra_conds)
-        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
-        return self.model_sampling.calculate_denoised(sigma, model_output, x)
+        if "latent_shapes" in extra_conds:
+            xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
+
+        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
+        if len(model_output) > 1 and not torch.is_tensor(model_output):
+            model_output, _ = utils.pack_latents(model_output)
+
+        return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)

     def process_timestep(self, timestep, **kwargs):
         return timestep
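The diff above unpacks the flattened latent back into its per-latent shapes right before the diffusion model call, and re-packs a multi-tensor output before denoising. A minimal sketch of that round trip (not part of the commit, shapes are illustrative), assuming ComfyUI with this commit is importable:

import torch
import comfy.utils as utils

# Two latents with different spatial sizes, packed into one [B, 1, N] tensor,
# the way CFGGuider.sample does before calling into the model.
latents = [torch.randn(1, 4, 32, 32), torch.randn(1, 4, 16, 24)]
xc, latent_shapes = utils.pack_latents(latents)   # xc.shape == (1, 1, 5632)

# Inside _apply_model, "latent_shapes" arrives via extra_conds and the packed
# tensor is split back into the original shapes before the model call.
unpacked = utils.unpack_latents(xc, latent_shapes)
print([t.shape for t in unpacked])   # [torch.Size([1, 4, 32, 32]), torch.Size([1, 4, 16, 24])]

# If the diffusion model returns a list/tuple of tensors, they are packed back
# into a single tensor so calculate_denoised can operate on it.
repacked, _ = utils.pack_latents(unpacked)
assert repacked.shape == xc.shape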

comfy/nested_tensor.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+import torch
+
+class NestedTensor:
+    def __init__(self, tensors):
+        self.tensors = list(tensors)
+        self.is_nested = True
+
+    def _copy(self):
+        return NestedTensor(self.tensors)
+
+    def apply_operation(self, other, operation):
+        o = self._copy()
+        if isinstance(other, NestedTensor):
+            for i, t in enumerate(o.tensors):
+                o.tensors[i] = operation(t, other.tensors[i])
+        else:
+            for i, t in enumerate(o.tensors):
+                o.tensors[i] = operation(t, other)
+        return o
+
+    def __add__(self, b):
+        return self.apply_operation(b, lambda x, y: x + y)
+
+    def __sub__(self, b):
+        return self.apply_operation(b, lambda x, y: x - y)
+
+    def __mul__(self, b):
+        return self.apply_operation(b, lambda x, y: x * y)
+
+    # def __itruediv__(self, b):
+    #     return self.apply_operation(b, lambda x, y: x / y)
+
+    def __truediv__(self, b):
+        return self.apply_operation(b, lambda x, y: x / y)
+
+    def __getitem__(self, *args, **kwargs):
+        return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
+
+    def unbind(self):
+        return self.tensors
+
+    def to(self, *args, **kwargs):
+        o = self._copy()
+        for i, t in enumerate(o.tensors):
+            o.tensors[i] = t.to(*args, **kwargs)
+        return o
+
+    def new_ones(self, *args, **kwargs):
+        return self.tensors[0].new_ones(*args, **kwargs)
+
+    def float(self):
+        return self.to(dtype=torch.float)
+
+    def chunk(self, *args, **kwargs):
+        return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
+
+    def size(self):
+        return self.tensors[0].size()
+
+    @property
+    def shape(self):
+        return self.tensors[0].shape
+
+    @property
+    def ndim(self):
+        dims = 0
+        for t in self.tensors:
+            dims = max(t.ndim, dims)
+        return dims
+
+    @property
+    def device(self):
+        return self.tensors[0].device
+
+    @property
+    def dtype(self):
+        return self.tensors[0].dtype
+
+    @property
+    def layout(self):
+        return self.tensors[0].layout
+
+
+def cat_nested(tensors, *args, **kwargs):
+    cated_tensors = []
+    for i in range(len(tensors[0].tensors)):
+        tens = []
+        for j in range(len(tensors)):
+            tens.append(tensors[j].tensors[i])
+        cated_tensors.append(torch.cat(tens, *args, **kwargs))
+    return NestedTensor(cated_tensors)
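A minimal usage sketch of the new NestedTensor wrapper (not part of the commit), assuming ComfyUI with this commit is importable. Arithmetic is applied element-wise to each stored tensor, metadata properties delegate to the first tensor, and cat_nested concatenates entry-wise:

import torch
from comfy.nested_tensor import NestedTensor, cat_nested

a = NestedTensor([torch.ones(1, 4, 8, 8), torch.ones(1, 4, 4, 6)])

b = a * 2 + 1                                   # scalar ops broadcast to every tensor
print([t.mean().item() for t in b.unbind()])    # [3.0, 3.0]

c = a / b                                       # NestedTensor op NestedTensor: pairwise
print(c.shape)                                  # torch.Size([1, 4, 8, 8]), first tensor only

# cat_nested concatenates tensor i of each NestedTensor along the given dim.
d = cat_nested([a, a], dim=0)
print([t.shape for t in d.unbind()])            # [torch.Size([2, 4, 8, 8]), torch.Size([2, 4, 4, 6])]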

comfy/sample.py

Lines changed: 21 additions & 6 deletions
@@ -4,13 +4,9 @@
 import comfy.utils
 import numpy as np
 import logging
+import comfy.nested_tensor

-def prepare_noise(latent_image, seed, noise_inds=None):
-    """
-    creates random noise given a latent image and a seed.
-    optional arg skip can be used to skip and discard x number of noise generations for a given seed
-    """
-    generator = torch.manual_seed(seed)
+def prepare_noise_inner(latent_image, generator, noise_inds=None):
     if noise_inds is None:
         return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")

@@ -22,9 +18,28 @@ def prepare_noise(latent_image, seed, noise_inds=None):
             noises.append(noise)
     noises = [noises[i] for i in inverse]
     noises = torch.cat(noises, axis=0)
+
+def prepare_noise(latent_image, seed, noise_inds=None):
+    """
+    creates random noise given a latent image and a seed.
+    optional arg skip can be used to skip and discard x number of noise generations for a given seed
+    """
+    generator = torch.manual_seed(seed)
+
+    if latent_image.is_nested:
+        tensors = latent_image.unbind()
+        noises = []
+        for t in tensors:
+            noises.append(prepare_noise_inner(t, generator, noise_inds))
+        noises = comfy.nested_tensor.NestedTensor(noises)
+    else:
+        noises = prepare_noise_inner(latent_image, generator, noise_inds)
+
     return noises

 def fix_empty_latent_channels(model, latent_image):
+    if latent_image.is_nested:
+        return latent_image
     latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
     if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
         latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
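A short sketch of the reworked noise preparation (not part of the commit), assuming ComfyUI with this commit is importable: a nested latent yields one noise tensor per entry, drawn in sequence from a single seeded generator, while a plain tensor takes the old path:

import torch
import comfy.sample
from comfy.nested_tensor import NestedTensor

# Two latents with different ranks, e.g. an image latent and a video latent.
latent = NestedTensor([torch.zeros(1, 4, 32, 32), torch.zeros(1, 16, 8, 16, 16)])
noise = comfy.sample.prepare_noise(latent, seed=42)

print(noise.is_nested)                      # True
print([t.shape for t in noise.unbind()])    # matches the input shapes

# A plain tensor still returns a plain tensor.
plain = comfy.sample.prepare_noise(torch.zeros(1, 4, 32, 32), seed=42)
print(torch.is_tensor(plain))               # True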

comfy/samplers.py

Lines changed: 16 additions & 7 deletions
@@ -782,7 +782,7 @@ def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, *
     return KSAMPLER(sampler_function, extra_options, inpaint_options)


-def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
+def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
     for k in conds:
         conds[k] = conds[k][:]
         resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
@@ -792,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N

     if hasattr(model, 'extra_conds'):
         for k in conds:
-            conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
+            conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)

     #make sure each cond area has an opposite one with the same area
     for k in conds:
@@ -962,11 +962,11 @@ def outer_predict_noise(self, x, timestep, model_options={}, seed=None):
     def predict_noise(self, x, timestep, model_options={}, seed=None):
         return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)

-    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
+    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
         if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
             latent_image = self.inner_model.process_latent_in(latent_image)

-        self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
+        self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)

         extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
         extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
@@ -980,7 +980,7 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mas
         samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
         return self.inner_model.process_latent_out(samples.to(torch.float32))

-    def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
+    def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
         self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
         device = self.model_patcher.load_device

@@ -994,7 +994,7 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None,

         try:
             self.model_patcher.pre_run()
-            output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
+            output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
         finally:
             self.model_patcher.cleanup()

@@ -1007,6 +1007,12 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba
         if sigmas.shape[-1] == 0:
             return latent_image

+        if latent_image.is_nested:
+            latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
+            noise, _ = comfy.utils.pack_latents(noise.unbind())
+        else:
+            latent_shapes = [latent_image.shape]
+
         self.conds = {}
         for k in self.original_conds:
             self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
@@ -1026,14 +1032,17 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba
                 self,
                 comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
             )
-            output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
+            output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
         finally:
             cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
             self.model_options = orig_model_options
             self.model_patcher.hook_mode = orig_hook_mode
             self.model_patcher.restore_hook_patches()

         del self.conds
+
+        if len(latent_shapes) > 1:
+            output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))
         return output
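The bookkeeping that CFGGuider.sample now does around the sampler call can be shown standalone. A sketch (not part of the commit) with the inner sampling call stubbed out as an identity function; variable names mirror the diff:

import torch
import comfy.utils
from comfy.nested_tensor import NestedTensor

latent_image = NestedTensor([torch.zeros(1, 4, 32, 32), torch.zeros(1, 4, 16, 16)])

# Entry: nested latents are flattened into one [B, 1, N] tensor plus the list
# of original shapes, which rides along as latent_shapes.
if latent_image.is_nested:
    latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
else:
    latent_shapes = [latent_image.shape]

output = latent_image  # stand-in for the inner_sample/outer_sample call

# Exit: with more than one shape the packed result is split back into a
# NestedTensor so downstream nodes see per-latent tensors again.
if len(latent_shapes) > 1:
    output = NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))

print([t.shape for t in output.unbind()])   # the original per-latent shapes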

comfy/utils.py

Lines changed: 22 additions & 0 deletions
@@ -1106,3 +1106,25 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out)
         dim=1
     )
     return out
+
+def pack_latents(latents):
+    latent_shapes = []
+    tensors = []
+    for tensor in latents:
+        latent_shapes.append(tensor.shape)
+        tensors.append(tensor.reshape(tensor.shape[0], 1, -1))
+
+    latent = torch.cat(tensors, dim=-1)
+    return latent, latent_shapes
+
+def unpack_latents(combined_latent, latent_shapes):
+    if len(latent_shapes) > 1:
+        output_tensors = []
+        for shape in latent_shapes:
+            cut = math.prod(shape[1:])
+            tens = combined_latent[:, :, :cut]
+            combined_latent = combined_latent[:, :, cut:]
+            output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
+    else:
+        output_tensors = combined_latent
+    return output_tensors
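A worked example of the two new helpers (not part of the commit), assuming ComfyUI with this commit is importable. pack_latents flattens each latent to [batch, 1, -1] and concatenates along the last dim; unpack_latents slices the flat tensor apart again using math.prod of each recorded shape (batch dim excluded):

import torch
from comfy.utils import pack_latents, unpack_latents

a = torch.arange(2 * 4 * 3 * 3, dtype=torch.float32).reshape(2, 4, 3, 3)
b = torch.randn(2, 16, 2, 5, 5)            # a latent with a different rank

packed, shapes = pack_latents([a, b])
print(packed.shape)                        # torch.Size([2, 1, 836]): 36 + 800 flattened elements per batch item

restored = unpack_latents(packed, shapes)
print(torch.equal(restored[0], a), torch.equal(restored[1], b))   # True True

# With a single recorded shape the packed tensor is returned as-is
# (it is not reshaped back).
single, single_shapes = pack_latents([a])
print(unpack_latents(single, single_shapes).shape)   # torch.Size([2, 1, 36])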
