Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,9 @@ def determine_available_memory(self) -> int:
) as profile_result:
self.model_runner.profile_run()

profile_torch_peak = torch.accelerator.memory_stats(self.device).get(
"allocated_bytes.all.peak", 0
)
_mem_stats = current_platform.memory_stats(self.device)
profile_torch_peak = _mem_stats.get("allocated_bytes.all.peak", 0)
profile_torch_reserved_peak = _mem_stats.get("reserved_bytes.all.peak", 0)

# Profile CUDA graph memory if graphs will be captured.
# Skip on ROCm/HIP as graph pool handles and mem_get_info behave
Expand All @@ -398,6 +398,14 @@ def determine_available_memory(self) -> int:
if not self.model_config.enforce_eager and not current_platform.is_rocm():
cudagraph_memory_estimate = self.model_runner.profile_cudagraph_memory()

measured_fragmentation = max(
0, int(profile_torch_reserved_peak - profile_torch_peak)
)
logger.info(
"Fragmentation profiling: selected=cold, value=%d MiB",
measured_fragmentation >> 20,
)

# Use the pre-cudagraph torch peak to avoid double-counting.
profile_result.torch_peak_increase = (
profile_torch_peak - profile_result.before_profile.torch_peak
Expand All @@ -422,6 +430,22 @@ def determine_available_memory(self) -> int:
)
self.cudagraph_memory_estimate = cudagraph_memory_estimate

# Fragmentation observed during profiling is typically lower than
# runtime steady state. Use a safety factor with a minimum floor.
fragmentation_multiplier = 2.0
self.fragmentation_buffer = max(
150 * (1 << 20),
int(measured_fragmentation * fragmentation_multiplier),
)
Comment on lines +436 to +439
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The value 150 * (1 << 20) is a magic number. To improve readability and maintainability, it should be defined as a constant. This value is also duplicated in compile_or_warm_up_model, making it prone to inconsistencies if updated in only one place. Please define it as a shared constant (e.g., _DEFAULT_FRAGMENTATION_BUFFER_BYTES) and use it in both locations.

fragmentation_buffer = self.fragmentation_buffer
logger.info(
"Memory profiling: fragmentation=%.2f MiB, multiplier=%.2f, "
"fragmentation_buffer=%.2f MiB",
measured_fragmentation / (1 << 20),
fragmentation_multiplier,
fragmentation_buffer / (1 << 20),
)

free_gpu_memory = profile_result.after_profile.free_memory
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
Expand All @@ -438,6 +462,7 @@ def determine_available_memory(self) -> int:
self.requested_memory
- profile_result.non_kv_cache_memory
- cudagraph_memory_estimate_applied
- fragmentation_buffer
)

unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
Expand Down Expand Up @@ -634,10 +659,13 @@ def compile_or_warm_up_model(self) -> float:
# Users may want fine-grained control to specify kv cache
# memory size.

# empirically observed that the memory profiling may
# slightly underestimate the memory consumption.
# So leave a small buffer (=150MiB) to avoid OOM.
redundancy_buffer_memory = 150 * (1 << 20)
# Use the same fragmentation buffer computed during profiling
# (in determine_available_memory) for consistent --kv-cache-memory
# suggestions. This ensures users who follow the suggestion get
# the same safety margin as the auto-profiling path.
redundancy_buffer_memory = getattr(
self, "fragmentation_buffer", 150 * (1 << 20)
)
Comment on lines +666 to +668
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This line duplicates the magic number 150 * (1 << 20) from determine_available_memory. Using a shared constant for this value would prevent potential inconsistencies between the auto-profiling path and the manual configuration suggestion, which is a key goal of this PR.

non_kv_cache_memory = (
self.model_runner.model_memory_usage
+ self.peak_activation_memory
Expand Down
Loading