72 changes: 40 additions & 32 deletions vllm/device_allocator/cumem.py
@@ -128,13 +128,6 @@ def get_instance() -> "CuMemAllocator":
         return CuMemAllocator.instance
 
     def __init__(self):
-        conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
-        assert "expandable_segments:True" not in conf, (
-            "Expandable segments are not compatible with memory pool. "
-            "Please track https://github.com/pytorch/pytorch/issues/147851 "
-            "for the latest updates."
-        )
-
         self.pointer_to_data: dict[int, AllocationData] = {}
         self.current_tag: str = CuMemAllocator.default_tag
         self.allocator_and_pools: dict[str, Any] = {}
@@ -264,34 +257,49 @@ def use_memory_pool(self, tag: str | None = None):
 
         assert isinstance(tag, str)
 
+        # Expandable segments are incompatible with the memory pool used for
+        # sleep mode (see https://github.com/pytorch/pytorch/issues/147851).
+        # If the user has enabled expandable segments via
+        # PYTORCH_CUDA_ALLOC_CONF, temporarily disable them for the duration
+        # of the memory pool context and restore them on exit.
+        conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
+        expandable_was_enabled = "expandable_segments:True" in conf
+        if expandable_was_enabled:
+            torch.cuda.memory._set_allocator_settings("expandable_segments:False")
 
         old_tag = self.current_tag
         self.current_tag = tag
-        with use_memory_pool_with_allocator(
-            self.python_malloc_callback, self.python_free_callback
-        ) as data:
-            # start to hit another PyTorch bug in PyTorch 2.6,
-            # possibly because of gc-related issue w.r.t. the allocator and
-            # the memory pool.
-            # to avoid the issue, we keep a reference of the data.
-            # see https://github.com/pytorch/pytorch/issues/146431 .
-            self.allocator_and_pools[tag] = data
-            yield
-            # PyTorch's bug, calling torch.cuda.empty_cache() will error
-            # when using pluggable allocator, see
-            # https://github.com/pytorch/pytorch/issues/145168 .
-            # if we have some memory allocated and then freed,
-            # the memory will not be released, e.g. in online quantization,
-            # where the model is created in higher precision, and then
-            # quantized in lower precision.
-            # Find all unused allocations and manually release them.
-            # TODO: we should expose `empty_cache` method in the memory pool.
-            # TODO: ask for help from PyTorch team to expose this method.
-            allocations = data[0].snapshot()
-            for allocation in allocations:
-                if allocation["allocated_size"] == 0:
-                    handle = self.python_free_callback(allocation["address"])
-                    unmap_and_release(handle)
+        try:
+            with use_memory_pool_with_allocator(
+                self.python_malloc_callback, self.python_free_callback
+            ) as data:
+                # Starting with PyTorch 2.6 we hit another PyTorch bug,
+                # possibly a GC-related issue between the allocator
+                # and the memory pool.
+                # To avoid the issue, keep a reference to the data.
+                # See https://github.com/pytorch/pytorch/issues/146431 .
+                self.allocator_and_pools[tag] = data
+                yield
+                # Due to a PyTorch bug, calling torch.cuda.empty_cache()
+                # errors when using a pluggable allocator; see
+                # https://github.com/pytorch/pytorch/issues/145168 .
+                # If some memory is allocated and then freed, it is not
+                # released, e.g. in online quantization, where the
+                # model is created in higher precision and then
+                # quantized to lower precision.
+                # Find all unused allocations and manually release them.
+                # TODO: we should expose an `empty_cache` method on the
+                # memory pool.
+                # TODO: ask the PyTorch team for help exposing this method.
+                allocations = data[0].snapshot()
+                for allocation in allocations:
+                    if allocation["allocated_size"] == 0:
+                        handle = self.python_free_callback(allocation["address"])
+                        unmap_and_release(handle)
+        finally:
+            self.current_tag = old_tag
+            if expandable_was_enabled:
+                torch.cuda.memory._set_allocator_settings("expandable_segments:True")
 
     def get_current_usage(self) -> int:
         """
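The toggle-and-restore logic in the second hunk leans on PyTorch's private `torch.cuda.memory._set_allocator_settings` hook, which the diff itself calls. As a minimal sketch, the same idea can be factored into a standalone context manager; the helper name `disabled_expandable_segments` is hypothetical, not part of vLLM or PyTorch:

import os
from contextlib import contextmanager

import torch


@contextmanager
def disabled_expandable_segments():
    # Detect the user's setting the same way the diff does: by inspecting
    # the PYTORCH_CUDA_ALLOC_CONF environment variable.
    conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
    was_enabled = "expandable_segments:True" in conf
    if was_enabled:
        # _set_allocator_settings is a private PyTorch API and may change
        # between releases; it rewrites the caching-allocator settings at
        # runtime.
        torch.cuda.memory._set_allocator_settings("expandable_segments:False")
    try:
        yield
    finally:
        # Restore the user's original setting even if the body raises.
        if was_enabled:
            torch.cuda.memory._set_allocator_settings("expandable_segments:True")

Wrapping `use_memory_pool_with_allocator` in such a context manager would behave like the PR's inlined try/finally; the PR keeps it inline, presumably so the restore shares a single `finally` with the `current_tag` rollback.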
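The snapshot loop at the end of the pool context works around `torch.cuda.empty_cache()` erroring under a pluggable allocator (pytorch/pytorch#145168). Here is a sketch of that pattern in isolation, assuming, as the diff does, that the pool's `snapshot()` yields dicts with `"address"` and `"allocated_size"` keys; the function name and its callback parameters are illustrative stand-ins for vLLM's `python_free_callback` and `unmap_and_release`:

from collections.abc import Callable
from typing import Any


def release_unused_segments(
    pool: Any,
    free_callback: Callable[[int], Any],
    unmap_and_release: Callable[[Any], None],
) -> None:
    # Assumption: pool.snapshot() returns one dict per cached segment,
    # mirroring the MemPool snapshot used in the diff.
    for allocation in pool.snapshot():
        # allocated_size == 0 means every block in the segment was freed
        # by the application, but the driver memory is still mapped.
        if allocation["allocated_size"] == 0:
            # Exchange the virtual address for its allocation handle...
            handle = free_callback(allocation["address"])
            # ...then unmap the range and release the physical memory.
            unmap_and_release(handle)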