Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion comfy/ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,9 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
except model_management.OOM_EXCEPTION as e:
except Exception as e:
if not model_management.is_oom_exception(e):
raise
if first_op_done == False:
model_management.soft_empty_cache(True)
if cleared_cache == False:
Expand Down
8 changes: 6 additions & 2 deletions comfy/ldm/modules/diffusionmodules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,9 @@ def slice_attention(q, k, v):
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except model_management.OOM_EXCEPTION as e:
except Exception as e:
if not model_management.is_oom_exception(e):
raise
model_management.soft_empty_cache(True)
steps *= 2
if steps > 128:
Expand Down Expand Up @@ -287,7 +289,9 @@ def pytorch_attention(q, k, v):
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
except Exception as ex:
if not model_management.is_oom_exception(ex):
raise
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
Expand Down
4 changes: 3 additions & 1 deletion comfy/ldm/modules/sub_quadratic_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def _get_attention_scores_no_kv_chunking(
try:
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
except model_management.OOM_EXCEPTION:
except Exception as ex:
if not model_management.is_oom_exception(ex):
raise
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
torch.exp(attn_scores, out=attn_scores)
Expand Down
8 changes: 8 additions & 0 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ def mac_version():
except:
OOM_EXCEPTION = Exception


def is_oom_exception(ex):
if isinstance(ex, OOM_EXCEPTION):
return True
# handle also other kinds of oom, e.g. "HIP error: out of memory"
msg = str(ex)
return "out of memory" in msg

XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True
if args.disable_xformers:
Expand Down
9 changes: 6 additions & 3 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,9 @@ def decode(self, samples_in, vae_options={}):
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
except Exception as ex:
if not model_management.is_oom_exception(ex):
raise
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
Expand Down Expand Up @@ -640,8 +642,9 @@ def encode(self, pixel_samples):
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples[x:x + batch_number] = out

except model_management.OOM_EXCEPTION:
except Exception as ex:
if not model_management.is_oom_exception(ex):
raise
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
if self.latent_dim == 3:
tile = 256
Expand Down
4 changes: 3 additions & 1 deletion comfy_extras/nodes_upscale_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def upscale(self, upscale_model, image):
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
except model_management.OOM_EXCEPTION as e:
except Exception as e:
if not model_management.is_oom_exception(e):
raise
tile //= 2
if tile < 128:
raise e
Expand Down
2 changes: 1 addition & 1 deletion execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def pre_execute_cb(call_index):
logging.error(traceback.format_exc())
tips = ""

if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
if comfy.model_management.is_oom_exception(ex):
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
Expand Down