
Commit f225f5c

[nvbugs-5318143] fix: restrict PyTorch memory usage to avoid OOMs (#5964)
Signed-off-by: ixlmar <[email protected]>
1 parent: f277afd · commit: f225f5c


5 files changed: +56 −4 lines changed


docker/Dockerfile.multi

Lines changed: 3 additions & 0 deletions

```diff
@@ -66,6 +66,9 @@ RUN GITHUB_MIRROR=$GITHUB_MIRROR bash ./install_mpi4py.sh && rm install_mpi4py.s
 ARG TORCH_INSTALL_TYPE="skip"
 COPY docker/common/install_pytorch.sh install_pytorch.sh
 RUN bash ./install_pytorch.sh $TORCH_INSTALL_TYPE && rm install_pytorch.sh
+#
+# NB: PyTorch requires this to be < 1.0
+ENV PYTORCH_CUDA_ALLOC_CONF="garbage_collection_threshold:0.99999"
 
 # Install OpenCV with FFMPEG support
 RUN pip3 uninstall -y opencv && rm -rf /usr/local/lib/python3*/dist-packages/cv2/
```
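The `ENV` default above only takes effect because the Python-side change below also sets a per-process memory fraction; on its own, the native caching allocator ignores `garbage_collection_threshold`. A minimal sketch (not part of this commit; the printed values are illustrative) for inspecting the resulting allocator configuration at runtime:

```python
import os

import torch

# The Dockerfile default lands here unless the user overrides it at launch time.
alloc_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
print("PYTORCH_CUDA_ALLOC_CONF:", alloc_conf)  # e.g. "garbage_collection_threshold:0.99999"
print("allocator backend:", torch.cuda.get_allocator_backend())  # "native" unless "backend:cudaMallocAsync" is set

# garbage_collection_threshold is only honored once a memory fraction is set,
# which is what _adjust_torch_mem_fraction (below) does during executor creation.
torch.cuda.set_per_process_memory_fraction(1.0)
```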

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 44 additions & 0 deletions

```diff
@@ -1,3 +1,4 @@
+import os
 import random
 from collections.abc import Iterable
 from typing import Dict, List, Optional
@@ -18,6 +19,7 @@
 
 from ..model_config import ModelConfig
 from ..speculative import get_spec_decoder
+from .config import PyTorchConfig
 from .config_utils import is_mla, is_nemotron_hybrid
 from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
 from .llm_request import ExecutorResponse
@@ -718,3 +720,45 @@ def _try_infer_num_experts(model_config: ModelConfig) -> int:
         return 1
 
     return num_experts
+
+
+def _adjust_torch_mem_fraction(pytorch_backend_config: PyTorchConfig):
+    # FIXME: PyTorch only uses the garbage_collection_threshold setting
+    #        if a memory fraction is set, cf.
+    #        https://github.com/pytorch/pytorch/blob/cd995bfb2aac8891465809be3ce29543bd524287/c10/cuda/CUDACachingAllocator.cpp#L1357
+    logger.debug("Setting PyTorch memory fraction to 1.0")
+    torch.cuda.set_per_process_memory_fraction(1.0)
+
+    # FIXME: As soon as
+    #        torch.cuda._set_allocator_settings (added in PyTorch 2.8.0-rc1)
+    #        or a similar API is available, the warning below should be removed
+    #        and the allocator GC threshold be set via the new API instead.
+    torch_allocator_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
+    torch_mem_threshold_advised = (
+        torch.cuda.get_allocator_backend() == "native"
+        and "expandable_segments:True" not in torch_allocator_config)
+    torch_mem_threshold_set = "garbage_collection_threshold:" in torch_allocator_config
+    if torch_mem_threshold_advised and not torch_mem_threshold_set:
+        logger.warning(
+            "It is recommended to incl. 'garbage_collection_threshold:0.???' or 'backend:cudaMallocAsync'"
+            " or 'expandable_segments:True' in PYTORCH_CUDA_ALLOC_CONF.")
+
+    # NOTE: Even if a memory threshold was not set (cf. warning above), setting a memory
+    #       fraction < 1.0 is beneficial, because
+    #       https://github.com/pytorch/pytorch/blob/5228986c395dc79f90d2a2b991deea1eef188260/c10/cuda/CUDACachingAllocator.cpp#L2719
+    #       and
+    #       https://github.com/pytorch/pytorch/blob/5228986c395dc79f90d2a2b991deea1eef188260/c10/cuda/CUDACachingAllocator.cpp#L1240
+    #       lead PyTorch to release all unused memory before hitting the set fraction. This
+    #       still mitigates OOM, although at a higher performance impact, because it
+    #       effectively resets the allocator cache.
+    if not pytorch_backend_config._limit_torch_cuda_mem_fraction:
+        return
+    mem_reserved = torch.cuda.memory_reserved()
+    mem_free, mem_total = torch.cuda.mem_get_info()
+    safety_margin = 32 * 1024**2
+    mem_torch_max = mem_free + mem_reserved - safety_margin
+    mem_torch_fraction = mem_torch_max / mem_total
+    logger.info(
+        f"Setting PyTorch memory fraction to {mem_torch_fraction} ({mem_torch_max / 1024**3} GiB)"
+    )
+    torch.cuda.set_per_process_memory_fraction(mem_torch_fraction)
```
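To make the arithmetic in `_adjust_torch_mem_fraction` concrete, here is a small worked example with made-up numbers (a hypothetical 80 GiB GPU after the engine has claimed most of the memory); it mirrors the computation above without touching CUDA:

```python
# Illustrative numbers only; on a real system these come from
# torch.cuda.memory_reserved() and torch.cuda.mem_get_info().
GiB = 1024**3

mem_total = 80 * GiB          # total device memory
mem_free = 20 * GiB           # free after KV cache, weights, etc. are allocated
mem_reserved = 6 * GiB        # already held by PyTorch's caching allocator
safety_margin = 32 * 1024**2  # 32 MiB of head-room, as in the code above

mem_torch_max = mem_free + mem_reserved - safety_margin
mem_torch_fraction = mem_torch_max / mem_total
print(f"fraction = {mem_torch_fraction:.4f} ({mem_torch_max / GiB:.2f} GiB)")
# -> fraction = 0.3246 (25.97 GiB): PyTorch may keep what it already reserved
#    plus what is still free, but not memory owned by other allocators.
```

The point of the cap is that PyTorch's caching allocator releases cached blocks (and, with the GC threshold, does so proactively) before crossing this fraction, instead of growing until the driver reports an OOM.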

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -92,6 +92,11 @@ class PyTorchConfig:
 
     force_dynamic_quantization: bool = False
 
+    # If true, adjust PyTorch CUDA memory fraction to correspond to the
+    # total GPU memory minus the statically allocated engine memory.
+    # If false, set the PyTorch CUDA memory fraction to 1.0.
+    _limit_torch_cuda_mem_fraction: bool = True
+
 
 EXETENDED_EXECUTOR_CONFIG_FIELDS = [
     'backend',
```
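A hypothetical usage sketch (not from the commit), assuming `PyTorchConfig` is a plain dataclass as the defaulted fields above suggest: the internal flag can be overridden at construction time to fall back to a fixed memory fraction of 1.0. The import path follows the file location shown above.

```python
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

# Opt out of the dynamic memory-fraction adjustment (underscore-prefixed,
# i.e. internal; intended for debugging/experiments rather than a public API).
cfg = PyTorchConfig(_limit_torch_cuda_mem_fraction=False)
print(cfg._limit_torch_cuda_mem_fraction)  # False
```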

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -20,8 +20,8 @@
 from ..attention_backend.interface import AttentionRuntimeFeatures
 from ..distributed import MPIDist
 from ..speculative import get_spec_drafter, get_spec_resource_manager
-from ._util import (KvCacheCreator, create_py_executor_instance,
-                    instantiate_sampler, is_mla)
+from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
+                    create_py_executor_instance, instantiate_sampler, is_mla)
 from .config import PyTorchConfig
 from .config_utils import is_mla
 from .model_engine import PyTorchModelEngine
@@ -432,5 +432,7 @@ def create_py_executor(
         garbage_collection_gen0_threshold,
     )
 
+    _adjust_torch_mem_fraction(executor_config.pytorch_backend_config)
+
     py_executor.start_worker()
     return py_executor
```

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 2 deletions

```diff
@@ -372,8 +372,6 @@ perf/test_perf.py::test_perf[mamba_130m-bench-float16-input_output_len:128,128]
 perf/test_perf.py::test_perf[bert_large-bench-float16-maxbs:32-input_len:128+512] SKIP (https://nvbugspro.nvidia.com/bug/5295411)
 perf/test_perf.py::test_perf[roberta_base-bench-float16-maxbs:32-input_len:128+512] SKIP (https://nvbugspro.nvidia.com/bug/5295411)
 test_e2e.py::test_openai_multi_chat_example SKIP (https://nvbugs/5236980)
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/5318143)
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=TRTLLM-torch_compile=True] SKIP (https://nvbugs/5318143)
 disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5328160)
 stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test] SKIP (https://nvbugs/5328495)
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=True] SKIP (https://nvbugs/5322354)
```
