
[Bugfix][ROCm] Fix worker startup OOM on ROCm by skipping unreliable cudagraph memory profiling#36720

Merged
MatthewBonanni merged 2 commits into vllm-project:main from JartX:fix/rocm-cudagraph-memory-profiling-startup-oom
Mar 17, 2026

Conversation

@JartX
Contributor

@JartX JartX commented Mar 10, 2026

Problem

On ROCm/HIP platforms, vLLM fails to start with:

ValueError: Free memory on device cuda:0 (1.33/23.98 GiB) on startup is less
than desired GPU memory utilization (0.95, 22.79 GiB).

This was introduced by PR #30515 ([UX][Startup] Account for CUDA graphs
during memory profiling), which removed profile_cudagraph_memory() and
the explicit recalculation of torch_peak_increase and non_kv_cache_memory.

Root Cause

profile_cudagraph_memory() uses torch.cuda.mem_get_info() and
graph_pool_handle() which behave differently on ROCm/HIP and can produce
incorrect or negative memory estimates. When this estimate is applied to
available_kv_cache_memory_bytes, the result is inflated beyond the actual
free GPU memory, causing request_memory() to fail.

The workaround VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 happened to work
because it forced the (incorrect) estimate to be applied, accidentally
compensating for the missing torch_peak_increase recalculation.
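The inflation described above can be sketched with plain numbers. Everything below is hypothetical and illustrative; none of these names or values come from vLLM's actual code:

```python
GIB = 1024**3

requested_memory = int(22.79 * GIB)   # gpu_memory_utilization * total memory
non_kv_cache_memory = 4 * GIB         # weights, activations, etc. (made up)

# On ROCm/HIP the cudagraph profiling step can yield a negative estimate.
cudagraph_memory_estimate = -2 * GIB

# Subtracting a negative estimate *adds* memory to the KV cache budget,
# inflating it beyond what was actually measured as available.
available_kv_cache = (
    requested_memory - non_kv_cache_memory - cudagraph_memory_estimate
)
baseline = requested_memory - non_kv_cache_memory

print(available_kv_cache > baseline)  # True: the budget grew by 2 GiB
```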

Fix

Disable CUDA graph memory profiling on ROCm.

Testing

  • ROCm: verified the worker starts without the OOM error
  • CUDA: no behavior change; existing tests pass

@JartX JartX requested a review from njhill as a code owner March 10, 2026 23:16
@mergify mergify bot added nvidia rocm Related to AMD ROCm v1 bug Something isn't working labels Mar 10, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Mar 10, 2026
@JartX
Contributor Author

JartX commented Mar 10, 2026

@AndreasKaratzas @tjtanaa

@mergify

mergify bot commented Mar 10, 2026

Hi @JartX, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses an out-of-memory error on ROCm platforms during worker startup by skipping the unreliable CUDA graph memory profiling. The changes in vllm/v1/worker/gpu_worker.py correctly implement this fix by adding a platform check. The logic is sound and directly addresses the root cause described. I've found one minor issue: a stray string in a docstring that appears to be a leftover from development and should be removed.

…ory estimation error

Signed-off-by: JartX <sagformas@epdcenter.es>
@JartX JartX force-pushed the fix/rocm-cudagraph-memory-profiling-startup-oom branch from d2fce77 to 5cf8086 Compare March 10, 2026 23:26
@JartX
Contributor Author

JartX commented Mar 10, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses an Out-Of-Memory (OOM) error on ROCm platforms during worker startup. The fix involves skipping the CUDA graph memory profiling on ROCm, as it has been identified as unreliable on that platform. This is achieved by adding and not current_platform.is_rocm() to the condition for profiling. The change is correct, well-targeted, and should resolve the described issue without side effects on other platforms. The added comments also improve code clarity. I approve these changes.

@xinyu-intel
Contributor

Similar on XPU. Would torch.cuda.memory_reserved be more reasonable than torch.cuda.mem_get_info?

@JartX
Contributor Author

JartX commented Mar 11, 2026

Hi! @xinyu-intel
The error occurred during VRAM estimation, but it was only triggered by specific occupancy levels or specific models.

For example, Qwen3.5 27B Int4 crashed on 48GB VRAM (RDNA3) due to the estimation logic. In contrast, Qwen3.5 35B Int4 functioned correctly on the same 48GB RDNA3 setup; at a 40,960 context length, there was still nearly 16x available, whereas with the 27B model, only 5x remained. I have not been able to reproduce this behavior on NVIDIA hardware. The specific error reported was a HIP Memory error.
You mentioned that this also happens on XPU. Could you please try adding XPU to the skip condition, like ROCm, to see if it starts up? From there, we could start looking into alternatives such as the memory-info APIs.
What do you think?
Thank you very much for getting back to me!

@xinyu-intel
Contributor

(Quoting @JartX's reply above.)

Hi. From my understanding, it is difficult to estimate the behavior consistently across different models. This is because mem_get_info queries the free memory reported by the device driver, whereas memory_reserved reflects the memory reserved by the PyTorch caching allocator. Since these two values come from different layers of the stack, their relationship is not straightforward. In particular, the behavior is not fully predictable because it depends on how the device driver manages and reports memory. I’m not sure whether memory_reserved is a more reasonable metric, but since graph capture is a PyTorch feature, querying memory statistics from PyTorch seems more straightforward.
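The layering mismatch described above can be illustrated with mocked numbers; nothing here queries a real device, and all values are made up:

```python
GIB = 1024**3

# Hypothetical snapshot of one device.
total = 24 * GIB
reserved_by_torch = 20 * GIB       # the caching allocator's pool (memory_reserved)
allocated_outside_torch = 2 * GIB  # e.g. vendor-library or backend workspaces

# Driver-level view, roughly what mem_get_info reports as free:
driver_free = total - reserved_by_torch - allocated_outside_torch

# Allocator-level view, derived from memory_reserved:
allocator_free = total - reserved_by_torch

# The two layers disagree by exactly the non-torch allocations,
# so an estimate built on one cannot be validated against the other.
print((allocator_free - driver_free) // GIB)  # → 2
```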

@JartX
Contributor Author

JartX commented Mar 11, 2026

@xinyu-intel @tjtanaa

RuntimeError: Worker failed with error 'HIP out of memory. Tried to allocate 784.00 MiB. GPU 0 has a total capacity of 23.98 GiB of which 630.00 MiB is free. Of the allocated memory 22.27 GiB is allocated by PyTorch, and 492.89 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)', please check the stack trace above for the root cause

This occurs after changing torch.cuda.mem_get_info to torch.cuda.memory_reserved.

@jikunshang
Collaborator

CUDA graph memory profiling is a feature that will be enabled by default in v0.19. I think your code change will permanently disable this feature on the ROCm platform.
My understanding is: torch.cuda.mem_get_info uses cudaMemGetInfo to query memory, which is not a torch-side standard, so it is hard for hardware vendors to implement exactly the same functionality and behavior as CUDA (e.g., the timing of frees). I agree with @xinyu-intel that we should try to use some torch API, though I don't have thoughts on which one is reliable... maybe we should print a memory debug log for each graph capture step on different platforms.
cc @MatthewBonanni who added this feature in #30515

@JartX
Contributor Author

JartX commented Mar 11, 2026

@jikunshang

The problem seems to occur here:

_init_minimal_kv_cache_for_profiling

in gpu_model_runner.py

If I set a 90% maximum usage safety margin in the free RAM calculation, then it works correctly.
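That safety margin might look like the following sketch; the function name and integer-percent interface are made up for illustration:

```python
def usable_free_memory(reported_free: int, margin_pct: int = 90) -> int:
    # Trust only 90% of the driver-reported free memory, leaving headroom
    # for estimation error on ROCm/HIP. Integer math avoids float rounding.
    return reported_free * margin_pct // 100


GIB = 1024**3
assert usable_free_memory(10 * GIB) == 9 * GIB
```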

We need to try to get the calculation right, as @xinyu-intel correctly pointed out.

@tjtanaa
Copy link
Collaborator

tjtanaa commented Mar 16, 2026

@JartX I have noticed this; there are two issues in PR #30515.

The PR introduced a flag VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 which is supposed to switch on/off the check.

(Critical) However, this line

    if cudagraph_memory_estimate > 0:
        total_mem = self.init_snapshot.total_memory
        current_util = self.cache_config.gpu_memory_utilization
        cg_util_delta = cudagraph_memory_estimate / total_mem
        if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS:
            equiv_util = round(current_util - cg_util_delta, 4)
            suggested_util = min(
                round(current_util + cg_util_delta, 4),
                1.0,
            )
            logger.info(
                "CUDA graph memory profiling is enabled "
                "(VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1). "
                "This will become the default in v0.19. "
                "The current --gpu-memory-utilization=%.4f is equivalent "
                "to --gpu-memory-utilization=%.4f without CUDA graph "
                "memory profiling. To maintain the same effective KV "
                "cache size as before, increase "
                "--gpu-memory-utilization to %.4f.",
                current_util,
                equiv_util,
                suggested_util,
            )
        else:
            suggested_util = min(
                round(current_util + cg_util_delta, 4),
                1.0,
            )
            logger.info(
                "In v0.19, CUDA graph memory profiling will be enabled "
                "by default (VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1), "
                "which more accurately accounts for CUDA graph memory "
                "during KV cache allocation. To try it now, set "
                "VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 and increase "
                "--gpu-memory-utilization from %.4f to %.4f to maintain "
                "the same effective KV cache size.",
                current_util,
                suggested_util,
            )
it is still reusing the cudagraph_memory_estimate rather than cudagraph_memory_estimate_applied.

Based on the logic flow of

    cudagraph_memory_estimate_applied = (
        cudagraph_memory_estimate
        if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS
        else 0
    )

  2. (non-critical; it only affects logging) The property self.cudagraph_memory_estimate should be assigned cudagraph_memory_estimate_applied, since we have the flag VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS:

        self.cudagraph_memory_estimate = cudagraph_memory_estimate

So, a proper fix for now is to make sure that cudagraph_memory_estimate_applied is used instead of cudagraph_memory_estimate in those two places.

@MatthewBonanni Could you check if my understanding is correct?

@MatthewBonanni
Collaborator

@tjtanaa Both of these are intentional and correct as is.

Line 458 is intentionally supposed to be cudagraph_memory_estimate, not cudagraph_memory_estimate_applied.
We perform the CG memory profiling regardless of whether it's actually used to inform the KV cache size. This is so that the log is printed even when VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS is set to 0, so that we can ease the transition into enabling by default.

Line 419 is also intentionally supposed to be cudagraph_memory_estimate, not cudagraph_memory_estimate_applied, because, as you point out, this is only used to log how close the estimate was to the true size.
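The distinction can be sketched as follows. This is simplified; the names mirror the discussion, not the exact vLLM code:

```python
def plan_kv_cache(
    budget: int, cudagraph_memory_estimate: int, flag_enabled: bool
) -> tuple[int, int]:
    # The estimate is always computed so the transition log can report it,
    # but it only shrinks the KV cache budget when the env flag is set.
    applied = cudagraph_memory_estimate if flag_enabled else 0
    logged = cudagraph_memory_estimate  # always the raw estimate
    return budget - applied, logged


assert plan_kv_cache(100, 10, flag_enabled=False) == (100, 10)
assert plan_kv_cache(100, 10, flag_enabled=True) == (90, 10)
```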

@MatthewBonanni
Collaborator

MatthewBonanni commented Mar 16, 2026

@JartX could you clarify the following:

  • What is actually causing the ValueError? Can you add a stack trace?

torch.cuda.mem_get_info() and graph_pool_handle() which behave differently on ROCm/HIP and can produce incorrect or negative memory estimates

  • Could you clarify what you mean by this? In what circumstances are they incorrect or negative? Is there a plan to fix this upstream in PyTorch? Is there a preferred way to get these measurements on ROCm?

Edit: I see @jikunshang's comment now (#36720 (comment)). Could you indicate what method would be more reliable for other platforms?

@JartX
Contributor Author

JartX commented Mar 17, 2026

Now it should work.

@JartX
Contributor Author

JartX commented Mar 17, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request fixes an out-of-memory error on ROCm platforms during worker startup by correcting the CUDA graph memory profiling logic. The change replaces torch.cuda.mem_get_info(), which is unreliable on ROCm, with torch.cuda.memory_allocated() for a more accurate and platform-agnostic measurement of memory usage. The fix is well-targeted and correctly addresses the root cause. The implementation is sound and I have no further comments.

@mergify

mergify bot commented Mar 17, 2026

Hi @JartX, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 17, 2026
@MatthewBonanni
Collaborator

MatthewBonanni commented Mar 17, 2026

@JartX Replacing with torch.accelerator.memory_allocated won't work because we need to account for memory allocations that happen outside of pytorch (for example, some attention backends allocate their own workspaces). I think the better portable replacement for torch.cuda.mem_get_info would be torch.accelerator.get_memory_info, but I'm not sure if that actually fixes the issue
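Why an allocator-only counter is insufficient can be shown with mocked numbers; nothing below touches a real device, and all values are invented:

```python
GIB = 1024**3

# Hypothetical device state.
torch_allocated = 18 * GIB   # what a memory_allocated-style counter sees
backend_workspace = 2 * GIB  # allocated outside PyTorch, e.g. an attention
                             # backend's own workspace

# An allocator-level counter misses the workspace entirely:
seen_by_torch = torch_allocated

# A driver-level query (mem_get_info / get_memory_info style) accounts for
# everything handed out on the device:
actual_used = torch_allocated + backend_workspace

print((actual_used - seen_by_torch) // GIB)  # → 2
```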

@JartX JartX requested a review from WoosukKwon as a code owner March 17, 2026 17:38
@JartX
Contributor Author

JartX commented Mar 17, 2026

(Quoting @MatthewBonanni's comment above.)

The point is that it no longer crashes for me due to the HIP error, but it could still crash in other cases.
There must be a way to perform the calculation correctly; perhaps someone can shed some light on this. The error has stopped occurring for me, but it must come down to the memory calculation. Perhaps a combination of allocated memory and actual free memory?

@MatthewBonanni
Collaborator

@JartX re-upping my question from earlier, what is actually causing the ValueError? Can you add a stack trace?

@JartX
Contributor Author

JartX commented Mar 17, 2026

@JartX JartX force-pushed the fix/rocm-cudagraph-memory-profiling-startup-oom branch from 988b309 to 5cf8086 Compare March 17, 2026 18:37
@gshtras
Collaborator

gshtras commented Mar 17, 2026

This PR also solves (works around) the issue where after #30515 on ROCm MoE model performance would drop by up to 20%

Collaborator

@MatthewBonanni MatthewBonanni left a comment


Thanks for updating the PR, I think disabling on ROCm is the right approach for now. The fix works in my testing too. Can you update the description to reflect the final state of the PR?

In the future, can you also avoid force-pushing? It wipes out the history which makes the discussion in the PR much harder to parse

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 17, 2026
@MatthewBonanni MatthewBonanni merged commit e8f9dbc into vllm-project:main Mar 17, 2026
55 of 56 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 17, 2026
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Mar 17, 2026
@JartX
Contributor Author

JartX commented Mar 17, 2026

Thanks for updating the PR, I think disabling on ROCm is the right approach for now. The fix works in my testing too. Can you update the description to reflect the final state of the PR?

In the future, can you also avoid force-pushing? It wipes out the history, which makes the discussion in the PR much harder to parse.

I've rewritten and rolled back to the original branch because you were right. I didn't mean to cause any trouble; I apologize if it bothered you.🙏

@MatthewBonanni
Collaborator

MatthewBonanni commented Mar 17, 2026

@JartX no problem! It's just good practice for the sake of documentation.

I appreciate the fix!

@JartX JartX deleted the fix/rocm-cudagraph-memory-profiling-startup-oom branch March 17, 2026 22:54
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
…cudagraph memory profiling (vllm-project#36720)

Signed-off-by: JartX <sagformas@epdcenter.es>
@gshtras gshtras added this to the v0.18.0 cherry picks milestone Mar 18, 2026
khluu pushed a commit that referenced this pull request Mar 19, 2026
…cudagraph memory profiling (#36720)

Signed-off-by: JartX <sagformas@epdcenter.es>
(cherry picked from commit e8f9dbc)
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
…cudagraph memory profiling (vllm-project#36720)

Signed-off-by: JartX <sagformas@epdcenter.es>

Labels

bug Something isn't working nvidia ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

7 participants