Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions vllm/utils/mem_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
# A value of 0 would make MAX_SEQ_LEN negative and cause test_attention.py
# to fail.
assert max_shared_mem > 0, "max_shared_mem can not be zero"
assert max_shared_mem > 0, "max_shared_mem cannot be zero"
return int(max_shared_mem)


Expand Down Expand Up @@ -154,12 +154,16 @@ class MemoryProfilingResult:
non_kv_cache_memory: int = 0
torch_peak_increase: int = 0
non_torch_increase: int = 0
weights_memory: float = 0
weights_memory: int = 0
before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
profile_time: float = 0.0

def __post_init__(self) -> None:
    """Recreate the profiling snapshots on the baseline snapshot's device.

    Reads the device from ``before_create`` and replaces ``before_profile``
    and ``after_profile`` with fresh ``MemorySnapshot`` objects built for
    that device with ``auto_measure=False``.
    """
    baseline_device = self.before_create.device_

    # NOTE(review): presumably auto_measure=False defers measurement until
    # ``measure()`` is called explicitly -- confirm against MemorySnapshot.
    for slot in ("before_profile", "after_profile"):
        setattr(
            self,
            slot,
            MemorySnapshot(device=baseline_device, auto_measure=False),
        )

def __repr__(self) -> str:
return (
f"Memory profiling takes {self.profile_time:.2f} seconds. "
Expand All @@ -175,9 +179,12 @@ def __repr__(self) -> str:

@contextlib.contextmanager
def memory_profiling(
baseline_snapshot: MemorySnapshot, weights_memory: int
baseline_snapshot: MemorySnapshot,
weights_memory: int = 0,
) -> Generator[MemoryProfilingResult, None, None]:
"""Memory profiling context manager.
"""
Memory profiling context manager.

baseline_snapshot: the memory snapshot before the current vLLM instance.
weights_memory: memory used by PyTorch when loading the model weights.
Note that, before loading the model weights, we also initialize the device
Expand Down Expand Up @@ -217,21 +224,24 @@ def memory_profiling(
b. 2 GiB reserved for the peak activation tensors (category 2)
c. 1 GiB used by non-torch components (category 3)

The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
The memory used for loading weights (a.) is directly given from the
argument `weights_memory`.

The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]`
during profiling gives (b.).

The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
""" # noqa
The increase of `non_torch_memory` from creating the current vLLM instance
until after profiling gives (c.).
"""
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

result = MemoryProfilingResult()
torch.cuda.reset_peak_memory_stats(baseline_snapshot.device_)

result.before_create = baseline_snapshot
# the part of memory used for holding the model weights
result.weights_memory = weights_memory
result = MemoryProfilingResult(
before_create=baseline_snapshot,
# the part of memory used for holding the model weights
weights_memory=weights_memory,
)

result.before_profile.measure()

Expand All @@ -252,4 +262,4 @@ def memory_profiling(
peak_activation_memory = result.torch_peak_increase
result.non_kv_cache_memory = (
non_torch_memory + peak_activation_memory + result.weights_memory
) # noqa
)