@@ -782,7 +782,7 @@ def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, *
782782 return KSAMPLER (sampler_function , extra_options , inpaint_options )
783783
784784
785- def process_conds (model , noise , conds , device , latent_image = None , denoise_mask = None , seed = None ):
785+ def process_conds (model , noise , conds , device , latent_image = None , denoise_mask = None , seed = None , latent_shapes = None ):
786786 for k in conds :
787787 conds [k ] = conds [k ][:]
788788 resolve_areas_and_cond_masks_multidim (conds [k ], noise .shape [2 :], device )
@@ -792,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
792792
793793 if hasattr (model , 'extra_conds' ):
794794 for k in conds :
795- conds [k ] = encode_model_conds (model .extra_conds , conds [k ], noise , device , k , latent_image = latent_image , denoise_mask = denoise_mask , seed = seed )
795+ conds [k ] = encode_model_conds (model .extra_conds , conds [k ], noise , device , k , latent_image = latent_image , denoise_mask = denoise_mask , seed = seed , latent_shapes = latent_shapes )
796796
797797 #make sure each cond area has an opposite one with the same area
798798 for k in conds :
@@ -962,11 +962,11 @@ def outer_predict_noise(self, x, timestep, model_options={}, seed=None):
962962 def predict_noise (self , x , timestep , model_options = {}, seed = None ):
963963 return sampling_function (self .inner_model , x , timestep , self .conds .get ("negative" , None ), self .conds .get ("positive" , None ), self .cfg , model_options = model_options , seed = seed )
964964
965- def inner_sample (self , noise , latent_image , device , sampler , sigmas , denoise_mask , callback , disable_pbar , seed ):
965+ def inner_sample (self , noise , latent_image , device , sampler , sigmas , denoise_mask , callback , disable_pbar , seed , latent_shapes = None ):
966966 if latent_image is not None and torch .count_nonzero (latent_image ) > 0 : #Don't shift the empty latent image.
967967 latent_image = self .inner_model .process_latent_in (latent_image )
968968
969- self .conds = process_conds (self .inner_model , noise , self .conds , device , latent_image , denoise_mask , seed )
969+ self .conds = process_conds (self .inner_model , noise , self .conds , device , latent_image , denoise_mask , seed , latent_shapes = latent_shapes )
970970
971971 extra_model_options = comfy .model_patcher .create_model_options_clone (self .model_options )
972972 extra_model_options .setdefault ("transformer_options" , {})["sample_sigmas" ] = sigmas
@@ -980,7 +980,7 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mas
980980 samples = executor .execute (self , sigmas , extra_args , callback , noise , latent_image , denoise_mask , disable_pbar )
981981 return self .inner_model .process_latent_out (samples .to (torch .float32 ))
982982
983- def outer_sample (self , noise , latent_image , sampler , sigmas , denoise_mask = None , callback = None , disable_pbar = False , seed = None ):
983+ def outer_sample (self , noise , latent_image , sampler , sigmas , denoise_mask = None , callback = None , disable_pbar = False , seed = None , latent_shapes = None ):
984984 self .inner_model , self .conds , self .loaded_models = comfy .sampler_helpers .prepare_sampling (self .model_patcher , noise .shape , self .conds , self .model_options )
985985 device = self .model_patcher .load_device
986986
@@ -994,7 +994,7 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None,
994994
995995 try :
996996 self .model_patcher .pre_run ()
997- output = self .inner_sample (noise , latent_image , device , sampler , sigmas , denoise_mask , callback , disable_pbar , seed )
997+ output = self .inner_sample (noise , latent_image , device , sampler , sigmas , denoise_mask , callback , disable_pbar , seed , latent_shapes = latent_shapes )
998998 finally :
999999 self .model_patcher .cleanup ()
10001000
@@ -1007,6 +1007,12 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba
10071007 if sigmas .shape [- 1 ] == 0 :
10081008 return latent_image
10091009
1010+ if latent_image .is_nested :
1011+ latent_image , latent_shapes = comfy .utils .pack_latents (latent_image .unbind ())
1012+ noise , _ = comfy .utils .pack_latents (noise .unbind ())
1013+ else :
1014+ latent_shapes = [latent_image .shape ]
1015+
10101016 self .conds = {}
10111017 for k in self .original_conds :
10121018 self .conds [k ] = list (map (lambda a : a .copy (), self .original_conds [k ]))
@@ -1026,14 +1032,17 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba
10261032 self ,
10271033 comfy .patcher_extension .get_all_wrappers (comfy .patcher_extension .WrappersMP .OUTER_SAMPLE , self .model_options , is_model_options = True )
10281034 )
1029- output = executor .execute (noise , latent_image , sampler , sigmas , denoise_mask , callback , disable_pbar , seed )
1035+ output = executor .execute (noise , latent_image , sampler , sigmas , denoise_mask , callback , disable_pbar , seed , latent_shapes = latent_shapes )
10301036 finally :
10311037 cast_to_load_options (self .model_options , device = self .model_patcher .offload_device )
10321038 self .model_options = orig_model_options
10331039 self .model_patcher .hook_mode = orig_hook_mode
10341040 self .model_patcher .restore_hook_patches ()
10351041
10361042 del self .conds
1043+
1044+ if len (latent_shapes ) > 1 :
1045+ output = comfy .nested_tensor .NestedTensor (comfy .utils .unpack_latents (output , latent_shapes ))
10371046 return output
10381047
10391048
0 commit comments