From 6b58f6bd5ab0d876b912aaa9728ed29e4637dc04 Mon Sep 17 00:00:00 2001
From: Alex Butler
Date: Wed, 23 Apr 2025 21:30:30 +0100
Subject: [PATCH 1/3] vae handle HIP OOM exceptions

---
 comfy/model_management.py | 8 ++++++++
 comfy/sd.py               | 5 +++--
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 054291432b7b..9935558b482b 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -235,6 +235,14 @@ def mac_version():
 except:
     OOM_EXCEPTION = Exception
 
+
+def is_oom_exception(ex):
+    if isinstance(ex, OOM_EXCEPTION):
+        return True
+    # handle also other kinds of oom, e.g. "HIP error: out of memory"
+    msg = str(ex)
+    return "out of memory" in msg
+
 XFORMERS_VERSION = ""
 XFORMERS_ENABLED_VAE = True
 if args.disable_xformers:
diff --git a/comfy/sd.py b/comfy/sd.py
index cd13ab5f0bd1..eb8379cc788c 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -640,8 +640,9 @@ def encode(self, pixel_samples):
                 if samples is None:
                     samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                 samples[x:x + batch_number] = out
-
-        except model_management.OOM_EXCEPTION:
+        except Exception as ex:
+            if not model_management.is_oom_exception(ex):
+                raise
             logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
             if self.latent_dim == 3:
                 tile = 256

From e085cc478ca8887fb847351744773f68fe99294e Mon Sep 17 00:00:00 2001
From: Alex Butler
Date: Wed, 23 Apr 2025 21:33:58 +0100
Subject: [PATCH 2/3] vae decode handle HIP oom exceptions

---
 comfy/sd.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index eb8379cc788c..b96366e724df 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -577,7 +577,9 @@ def decode(self, samples_in, vae_options={}):
                 if pixel_samples is None:
                     pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                 pixel_samples[x:x+batch_number] = out
-        except model_management.OOM_EXCEPTION:
+        except Exception as ex:
+            if not model_management.is_oom_exception(ex):
+                raise
             logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             dims = samples_in.ndim - 2
             if dims == 1 or self.extra_1d_channel is not None:

From a19cb1a13b6ce094e4e63256063d2329093f9822 Mon Sep 17 00:00:00 2001
From: Alex Butler
Date: Sat, 21 Jun 2025 14:20:41 +0100
Subject: [PATCH 3/3] Use is_oom_exception for all exception checks

---
 comfy/ldm/modules/attention.py               | 4 +++-
 comfy/ldm/modules/diffusionmodules/model.py  | 8 ++++++--
 comfy/ldm/modules/sub_quadratic_attention.py | 4 +++-
 comfy_extras/nodes_upscale_model.py          | 4 +++-
 execution.py                                 | 2 +-
 5 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 35d2270ee98e..7e676f596d58 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -321,7 +321,9 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                 r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
                 del s2
             break
-        except model_management.OOM_EXCEPTION as e:
+        except Exception as e:
+            if not model_management.is_oom_exception(e):
+                raise
             if first_op_done == False:
                 model_management.soft_empty_cache(True)
                 if cleared_cache == False:
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 8162742cf034..e7ffc8b8dc97 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -232,7 +232,9 @@ def slice_attention(q, k, v):
                     r1[:, :, i:end] = torch.bmm(v, s2)
                     del s2
                 break
-            except model_management.OOM_EXCEPTION as e:
+            except Exception as e:
+                if not model_management.is_oom_exception(e):
+                    raise
                 model_management.soft_empty_cache(True)
                 steps *= 2
                 if steps > 128:
@@ -287,7 +289,9 @@ def pytorch_attention(q, k, v):
     try:
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
         out = out.transpose(2, 3).reshape(orig_shape)
-    except model_management.OOM_EXCEPTION:
+    except Exception as ex:
+        if not model_management.is_oom_exception(ex):
+            raise
         logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
     return out
diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py
index fab145f1c208..cc62a8f0b252 100644
--- a/comfy/ldm/modules/sub_quadratic_attention.py
+++ b/comfy/ldm/modules/sub_quadratic_attention.py
@@ -169,7 +169,9 @@ def _get_attention_scores_no_kv_chunking(
     try:
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
-    except model_management.OOM_EXCEPTION:
+    except Exception as ex:
+        if not model_management.is_oom_exception(ex):
+            raise
         logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
         attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
         torch.exp(attn_scores, out=attn_scores)
diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py
index 04c948341296..120a783c1e77 100644
--- a/comfy_extras/nodes_upscale_model.py
+++ b/comfy_extras/nodes_upscale_model.py
@@ -68,7 +68,9 @@ def upscale(self, upscale_model, image):
                 pbar = comfy.utils.ProgressBar(steps)
                 s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
                 oom = False
-            except model_management.OOM_EXCEPTION as e:
+            except Exception as e:
+                if not model_management.is_oom_exception(e):
+                    raise
                 tile //= 2
                 if tile < 128:
                     raise e
diff --git a/execution.py b/execution.py
index f6006fa12374..cfdd0edd2668 100644
--- a/execution.py
+++ b/execution.py
@@ -431,7 +431,7 @@ def pre_execute_cb(call_index):
         logging.error(traceback.format_exc())
 
         tips = ""
-        if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
+        if comfy.model_management.is_oom_exception(ex):
             tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
             logging.error("Got an OOM, unloading all loaded models.")
             comfy.model_management.unload_all_models()
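
Note for reviewers: every hunk in this series applies the same shape — catch broadly,
keep only out-of-memory failures, re-raise everything else. Below is a minimal
self-contained sketch of that shape. run_with_tiled_fallback, full_op, and tiled_op
are hypothetical names for illustration only (not ComfyUI APIs); the OOM_EXCEPTION
resolution mirrors what comfy/model_management.py does on CUDA builds.

    import logging

    # Assumption: on CUDA builds OOM_EXCEPTION resolves to
    # torch.cuda.OutOfMemoryError; fall back to Exception otherwise.
    try:
        import torch
        OOM_EXCEPTION = torch.cuda.OutOfMemoryError
    except Exception:
        OOM_EXCEPTION = Exception

    def is_oom_exception(ex):
        # Typed check first: covers CUDA OOMs directly.
        if isinstance(ex, OOM_EXCEPTION):
            return True
        # HIP/ROCm OOMs surface as plain RuntimeError("HIP error: out of
        # memory"), so fall back to matching on the message text.
        return "out of memory" in str(ex)

    def run_with_tiled_fallback(full_op, tiled_op, *args):
        # Hypothetical helper showing the catch/check/re-raise pattern
        # used at every call site touched by this series.
        try:
            return full_op(*args)
        except Exception as ex:
            if not is_oom_exception(ex):
                raise  # unrelated errors must keep propagating
            logging.warning("Ran out of memory, retrying with tiled fallback.")
            return tiled_op(*args)

The message match is deliberately loose: it may also catch other allocator errors
that merely mention "out of memory", which seems acceptable here because every
call site only falls back to a slower but more memory-frugal path.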