
[W.I.P] fragmentation_buffer in profiling#37428

Open
panpan0000 wants to merge 2 commits intovllm-project:mainfrom
panpan0000:fragmentation_buffer_profiling

Conversation

@panpan0000
Contributor

@panpan0000 panpan0000 commented Mar 18, 2026

STILL UNDER TESTING.......

Purpose

As we know, vLLM may suffer from OOM under high stress. The code does its best during warmup and profiling at startup, but gaps remain.

Sometimes users want the highest throughput (e.g. for benchmarking) and thus as large a KV-cache space as possible.
But most of the time we operate under SLO KPIs, and we need vLLM to stay alive as long as possible: not OOM, but alive.

This PR adds a fragmentation-aware safety buffer to KV cache budgeting, instead of a hardcoded 150 MB,
as another fix for #37420

Background

(1) Why OOM?

Profiling runs in a clean allocator state, so memory looks better than it does later during real serving. Fragmentation grows after some mixed batches and worsens at higher-stress batch sizes. A forward step may then need one large contiguous block (for example, ~384 MiB) while only smaller pieces are available (329.87 MiB free), so CUDA OOM happens.

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 384.00 MiB.
GPU has 19.53 GiB total, 329.87 MiB free.
PyTorch allocated: 18.40 GiB
Reserved but unallocated: 402.27 MiB

(2) What is fragmentation?

Fragmentation is reserved - allocated: memory reserved by PyTorch but not used by live tensors. This free space is often split into many small blocks of different sizes in the allocator pool. The bytes are there, but not as one contiguous block, so a large activation allocation can still fail.

fragmentation = reserved - allocated (the allocator's cached free blocks)

reserved pool by PyTorch:

Block:  [1K] [2K] [4K] [4K] [8K] [16K]

after a while, some blocks are in use (marked with "~"):

Block:  [1K] [2K] [4K] [4K] [8K] [16K]
        ~~~~ ~~~~           ~~~~ ~~~~~

When an 8K allocation is requested, 8K of free bytes exist (4K + 4K), but there is no single contiguous block large enough, so OOM happens.
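The pool sketch above can be simulated with a toy allocator model (illustrative only; the real PyTorch caching allocator is far more involved):

```python
# Toy model of allocator fragmentation. An allocation succeeds only if
# a single free block is large enough, regardless of the total free bytes.

def can_allocate(free_blocks: list[int], request: int) -> bool:
    """True if any one contiguous free block can satisfy the request."""
    return any(block >= request for block in free_blocks)

# As in the sketch: only the two 4K blocks remain free.
free_blocks = [4096, 4096]
total_free = sum(free_blocks)            # 8192 bytes "free" in total

print(total_free)                        # 8192
print(can_allocate(free_blocks, 8192))   # False: no contiguous 8K block
```

The bytes exist, but the request fails, which is exactly the shape of the OOM log above.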

Proposed Change

Co-authored with AI

1: determine_available_memory() records the fragmentation

measured_fragmentation = max(0, profile_torch_reserved_peak - profile_torch_peak)

2: multiply it by a factor of 1~2, and subtract it from the KV-cache space.

# Apply buffer to actual KV cache allocation
self.available_kv_cache_memory_bytes = (
    ....
    - self.fragmentation_buffer  # ← here
)

Cost: KV cache shrinks a little bit. Negligible throughput impact.
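The two steps above reduce to simple arithmetic. A minimal sketch, with hypothetical peak values (the `fragmentation_buffer_bytes` helper name is illustrative, not the PR's actual function):

```python
MIB = 1 << 20

def fragmentation_buffer_bytes(reserved_peak: int, allocated_peak: int,
                               factor: float = 2.0) -> int:
    """Fragmentation-aware buffer: measured fragmentation during profiling,
    scaled by a safety factor, floored at the old hardcoded 150 MiB."""
    measured_fragmentation = max(0, reserved_peak - allocated_peak)
    return max(150 * MIB, int(measured_fragmentation * factor))

# Hypothetical profiling peaks with ~402 MiB reserved-but-unallocated,
# roughly matching the OOM log above:
buf = fragmentation_buffer_bytes(reserved_peak=19_000 * MIB,
                                 allocated_peak=18_598 * MIB)
print(buf // MIB)  # 804: twice the measured fragmentation
```

When measured fragmentation is tiny, the 150 MiB floor keeps the old behavior as a lower bound.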

Abandoned Solutions

1) Try-catch torch.cuda.OutOfMemoryError with graceful recovery

Proposal: Wrap execute_model() in a try-catch. On OOM, call empty_cache(), reduce the token budget, and retry.

Why rejected:

  • OOM inside torch.compile/inductor generated code may leave compiled graph internal buffers in an inconsistent state — subsequent calls may produce silent incorrect results, which is worse than crashing
  • Forward pass partially updates KV cache before OOM. Determining exactly which layers/positions were written requires per-layer tracking that doesn't exist. Incomplete invalidation → garbage decode output
  • For multi-GPU (TP > 1), one rank OOMing while others don't causes NCCL collective operations to deadlock. Recovery requires coordinated abort across all ranks
  • OOM during CUDA graph replay is undefined behavior

2) Runtime pre-execution memory check

Proposal: Before each execute_model(), call torch.cuda.mem_get_info() to check free memory. If insufficient, signal the scheduler to reduce the batch.

Why rejected:

  • mem_get_info() returns free memory from CUDA driver's perspective, but PyTorch's caching allocator holds reserved-but-reusable blocks that appear "used" to CUDA. This makes the check overly conservative, rejecting batches that would fit fine
  • mem_get_info() is a CUDA driver API call that may trigger CPU-GPU synchronization, breaking the async scheduling pipeline in vLLM v1 and adding latency to every forward step
  • Activation memory estimation (activation_per_token * num_tokens) is not linear — torch.compile/inductor generates different buffer patterns for different symbolic shape combinations
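The first bullet can be illustrated numerically. A sketch with hypothetical sizes, where `driver_free` stands for what `torch.cuda.mem_get_info()` would report:

```python
GIB = 1 << 30

# Hypothetical GPU state: PyTorch has reserved most of the card,
# but a chunk of that reservation is cached, reusable free blocks.
total     = 20 * GIB
reserved  = 19 * GIB   # held by PyTorch's caching allocator
allocated = 17 * GIB   # actually used by live tensors

driver_free = total - reserved               # what mem_get_info() sees: 1 GiB
reusable    = reserved - allocated           # reusable cached blocks: 2 GiB
available_to_pytorch = driver_free + reusable  # 3 GiB

request = 2 * GIB
print(request <= driver_free)           # False: naive check rejects the batch
print(request <= available_to_pytorch)  # True: it would likely have fit
```

(Fragmentation can still make the optimistic check fail in practice, which is the point of the buffer approach.)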

3) Control the token budget according to real-time memory usage

Proposal: Periodically check GPU memory usage and dynamically scale max_num_scheduled_tokens.

Why rejected:

  • Same mem_get_info() inaccuracy problem as above (CUDA free ≠ PyTorch available)
  • Adds significant complexity to the scheduler with unclear benefits over simply sizing the buffer correctly at startup
  • Monitoring frequency introduces a tradeoff between responsiveness and overhead
  • Over-engineering for a problem that can be solved with a better buffer calculation

4) Percentage buffer as total_gpu_memory * X %

Example: redundancy_buffer_memory = max(150 MiB, int(total_memory * 0.02))

Why rejected:

  • Fragmentation depends on allocation patterns (model architecture, batch composition, tensor size distribution), not GPU total memory
  • An 80 GiB A100 running a 70B model may have 500 MiB fragmentation; a 16 GiB T4 running a 7B model may also have 400 MiB. A fixed percentage of GPU memory has no physical relationship to the actual fragmentation
  • "Why 2%?" has no defensible answer — it's a magic number with no derivation
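The mismatch is easy to show with the numbers from the bullets above (`pct_buffer` is a hypothetical helper for the rejected formula):

```python
MIB = 1 << 20
GIB = 1 << 30

def pct_buffer(total_memory: int, pct: float = 0.02) -> int:
    """The rejected percentage-based formula."""
    return max(150 * MIB, int(total_memory * pct))

# 80 GiB A100 with ~500 MiB actual fragmentation:
print(pct_buffer(80 * GIB) // MIB)  # 1638: ~3x overshoot, KV cache wasted
# 16 GiB T4 with ~400 MiB actual fragmentation:
print(pct_buffer(16 * GIB) // MIB)  # 327: undershoots, OOM risk remains
```

The same percentage over- and under-shoots on different cards because the buffer tracks GPU size, not allocation behavior.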

5) Make max_concurrency a hard admission cap instead of just a log message

Proposal: The logged "Maximum concurrency for 10,000 tokens per request: 14.79x" should be enforced as a hard limit on concurrent requests.

Why rejected:

  • This addresses KV cache capacity, not activation memory. The OOM occurred during the forward pass (activation allocation), not during KV cache allocation
  • KV cache was only 67.8% utilized at OOM time — plenty of KV headroom remained
  • Would unnecessarily limit throughput in the common case where KV is the bottleneck

Fragmentation Observation

| Workload | Fragmentation size |
| --- | --- |
| Profiling | 1219 MiB |
| Profiling repeated 3~4 times | no change |
| --num-prompts 10 (low concurrency) | 767 MiB |
| --num-prompts 20 (high concurrency) | 1825 MiB |

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@panpan0000 panpan0000 requested a review from njhill as a code owner March 18, 2026 12:41
@mergify mergify bot added the v1 label Mar 18, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a dynamic fragmentation buffer to prevent out-of-memory errors during model execution. The buffer is calculated based on memory fragmentation observed during profiling, providing a more robust safety margin than the previous hardcoded value. The changes correctly apply this buffer to both automatic KV cache allocation and the suggested value for manual configuration. My main feedback is to refactor the duplicated magic number for the minimum buffer size into a constant to improve maintainability and ensure consistency between the two code paths.

Comment on lines +433 to +436
self.fragmentation_buffer = max(
    150 * (1 << 20),
    int(measured_fragmentation * 2),
)

high

The value 150 * (1 << 20) is a magic number. To improve readability and maintainability, it should be defined as a constant. This value is also duplicated in compile_or_warm_up_model, making it prone to inconsistencies if updated in only one place. Please define it as a shared constant (e.g., _DEFAULT_FRAGMENTATION_BUFFER_BYTES) and use it in both locations.
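A minimal sketch of the suggested refactor (names, including `_DEFAULT_FRAGMENTATION_BUFFER_BYTES` and the helper functions, are illustrative):

```python
# Shared 150 MiB floor, defined once instead of duplicated in two methods.
_DEFAULT_FRAGMENTATION_BUFFER_BYTES = 150 * (1 << 20)

def compute_fragmentation_buffer(measured_fragmentation: int) -> int:
    """determine_available_memory() path: scale measured fragmentation,
    floored at the shared constant."""
    return max(_DEFAULT_FRAGMENTATION_BUFFER_BYTES,
               int(measured_fragmentation * 2))

def fallback_buffer(worker) -> int:
    """compile_or_warm_up_model() path: fall back to the same constant
    when the worker has no recorded fragmentation_buffer."""
    return getattr(worker, "fragmentation_buffer",
                   _DEFAULT_FRAGMENTATION_BUFFER_BYTES)
```

With one constant, updating the floor in a single place keeps both code paths consistent.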

Comment on lines +666 to +668
redundancy_buffer_memory = getattr(
    self, "fragmentation_buffer", 150 * (1 << 20)
)

high

This line duplicates the magic number 150 * (1 << 20) from determine_available_memory. Using a shared constant for this value would prevent potential inconsistencies between the auto-profiling path and the manual configuration suggestion, which is a key goal of this PR.

Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>
Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>
@mergify

mergify bot commented Mar 20, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @panpan0000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 20, 2026