From e1dacb359582b0b25ff2b06e4c44e1a4ab7cfbbb Mon Sep 17 00:00:00 2001
From: Peter Pan
Date: Wed, 18 Mar 2026 17:03:41 +0800
Subject: [PATCH 1/2] fragmentation_buffer in profiling

Signed-off-by: Peter Pan
---
 vllm/v1/worker/gpu_worker.py | 42 ++++++++++++++++++++++++++++++------
 1 file changed, 35 insertions(+), 7 deletions(-)

diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index d101edc18100..e51743e5202e 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -387,9 +387,9 @@ def determine_available_memory(self) -> int:
         ) as profile_result:
             self.model_runner.profile_run()
 
-        profile_torch_peak = torch.accelerator.memory_stats(self.device).get(
-            "allocated_bytes.all.peak", 0
-        )
+        _mem_stats = current_platform.memory_stats(self.device)
+        profile_torch_peak = _mem_stats.get("allocated_bytes.all.peak", 0)
+        profile_torch_reserved_peak = _mem_stats.get("reserved_bytes.all.peak", 0)
 
         # Profile CUDA graph memory if graphs will be captured.
         # Skip on ROCm/HIP as graph pool handles and mem_get_info behave
@@ -422,6 +422,30 @@ def determine_available_memory(self) -> int:
         )
         self.cudagraph_memory_estimate = cudagraph_memory_estimate
 
+        # Measure caching allocator fragmentation: the gap between what
+        # PyTorch reserved from CUDA vs what it actually allocated.
+        # At runtime, fragmentation grows beyond profiling (freshly initialized
+        # allocator is best-case). Use 2x as safety factor, minimum 150 MiB.
+        measured_fragmentation = max(
+            0,
+            profile_torch_reserved_peak - profile_torch_peak,
+        )
+        self.fragmentation_buffer = max(
+            150 * (1 << 20),
+            int(measured_fragmentation * 2),
+        )
+        fragmentation_buffer = self.fragmentation_buffer
+        logger.info(
+            "Memory profiling: allocated_peak=%.2f MiB, "
+            "reserved_peak=%.2f MiB, "
+            "measured_fragmentation=%.2f MiB, "
+            "fragmentation_buffer=%.2f MiB",
+            profile_torch_peak / (1 << 20),
+            profile_torch_reserved_peak / (1 << 20),
+            measured_fragmentation / (1 << 20),
+            fragmentation_buffer / (1 << 20),
+        )
+
         free_gpu_memory = profile_result.after_profile.free_memory
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
@@ -438,6 +462,7 @@ def determine_available_memory(self) -> int:
             self.requested_memory
             - profile_result.non_kv_cache_memory
             - cudagraph_memory_estimate_applied
+            - fragmentation_buffer
         )
 
         unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
@@ -634,10 +659,13 @@ def compile_or_warm_up_model(self) -> float:
 
         # Users may want fine-grained control to specify kv cache
         # memory size.
-        # empirically observed that the memory profiling may
-        # slightly underestimate the memory consumption.
-        # So leave a small buffer (=150MiB) to avoid OOM.
-        redundancy_buffer_memory = 150 * (1 << 20)
+        # Use the same fragmentation buffer computed during profiling
+        # (in determine_available_memory) for consistent --kv-cache-memory
+        # suggestions. This ensures users who follow the suggestion get
+        # the same safety margin as the auto-profiling path.
+        redundancy_buffer_memory = getattr(
+            self, "fragmentation_buffer", 150 * (1 << 20)
+        )
         non_kv_cache_memory = (
             self.model_runner.model_memory_usage
             + self.peak_activation_memory

From 2b3b88565459a1722438ce923bfbaee2c2abb071 Mon Sep 17 00:00:00 2001
From: Peter Pan
Date: Fri, 20 Mar 2026 09:40:04 +0800
Subject: [PATCH 2/2] refine

Signed-off-by: Peter Pan
---
 vllm/v1/worker/gpu_worker.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index e51743e5202e..2883ea7f9104 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -398,6 +398,14 @@ def determine_available_memory(self) -> int:
         if not self.model_config.enforce_eager and not current_platform.is_rocm():
             cudagraph_memory_estimate = self.model_runner.profile_cudagraph_memory()
 
+        measured_fragmentation = max(
+            0, int(profile_torch_reserved_peak - profile_torch_peak)
+        )
+        logger.info(
+            "Fragmentation profiling: selected=cold, value=%d MiB",
+            measured_fragmentation >> 20,
+        )
+
         # Use the pre-cudagraph torch peak to avoid double-counting.
         profile_result.torch_peak_increase = (
             profile_torch_peak - profile_result.before_profile.torch_peak
@@ -422,27 +430,19 @@ def determine_available_memory(self) -> int:
         )
         self.cudagraph_memory_estimate = cudagraph_memory_estimate
 
-        # Measure caching allocator fragmentation: the gap between what
-        # PyTorch reserved from CUDA vs what it actually allocated.
-        # At runtime, fragmentation grows beyond profiling (freshly initialized
-        # allocator is best-case). Use 2x as safety factor, minimum 150 MiB.
-        measured_fragmentation = max(
-            0,
-            profile_torch_reserved_peak - profile_torch_peak,
-        )
+        # Fragmentation observed during profiling is typically lower than
+        # runtime steady state. Use a safety factor with a minimum floor.
+        fragmentation_multiplier = 2.0
         self.fragmentation_buffer = max(
             150 * (1 << 20),
-            int(measured_fragmentation * 2),
+            int(measured_fragmentation * fragmentation_multiplier),
         )
         fragmentation_buffer = self.fragmentation_buffer
         logger.info(
-            "Memory profiling: allocated_peak=%.2f MiB, "
-            "reserved_peak=%.2f MiB, "
-            "measured_fragmentation=%.2f MiB, "
+            "Memory profiling: fragmentation=%.2f MiB, multiplier=%.2f, "
             "fragmentation_buffer=%.2f MiB",
-            profile_torch_peak / (1 << 20),
-            profile_torch_reserved_peak / (1 << 20),
             measured_fragmentation / (1 << 20),
+            fragmentation_multiplier,
             fragmentation_buffer / (1 << 20),
         )
 
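
For reference, the fragmentation buffer computed above in determine_available_memory can be reproduced outside vLLM as a rough sketch. The snippet below assumes a CUDA device and reads torch.cuda.memory_stats() directly (current_platform.memory_stats() in the patch is a vLLM-internal wrapper around the same allocator statistics); the helper name estimate_fragmentation_buffer is illustrative only, and the 2x multiplier and 150 MiB floor simply mirror the values chosen in the patch.

import torch


def estimate_fragmentation_buffer(device=None,
                                  multiplier: float = 2.0,
                                  floor_bytes: int = 150 * (1 << 20)) -> int:
    # Peak bytes the caching allocator handed out to tensors vs. peak bytes
    # it reserved from the CUDA driver; their gap approximates allocator
    # fragmentation at the profiling peak.
    stats = torch.cuda.memory_stats(device)
    allocated_peak = stats.get("allocated_bytes.all.peak", 0)
    reserved_peak = stats.get("reserved_bytes.all.peak", 0)
    measured_fragmentation = max(0, reserved_peak - allocated_peak)
    # A freshly initialized allocator under-represents steady-state
    # fragmentation, so scale the measurement and keep a fixed floor.
    return max(floor_bytes, int(measured_fragmentation * multiplier))


# Usage sketch: run a representative forward pass first, then size the buffer.
# buffer_bytes = estimate_fragmentation_buffer(torch.device("cuda:0"))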