
Commit ce1303a

fix: restrict PyTorch memory usage to avoid OOMs
Signed-off-by: ixlmar <[email protected]>
1 parent e831673 commit ce1303a

File tree

5 files changed: 56 additions, 4 deletions


docker/Dockerfile.multi

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ RUN bash ./install_mpi4py.sh && rm install_mpi4py.sh
 ARG TORCH_INSTALL_TYPE="skip"
 COPY docker/common/install_pytorch.sh install_pytorch.sh
 RUN bash ./install_pytorch.sh $TORCH_INSTALL_TYPE && rm install_pytorch.sh
+#
+# NB: PyTorch requires this to be < 1.0
+ENV PYTORCH_CUDA_ALLOC_CONF="garbage_collection_threshold:0.99999"
 
 # Install OpenCV with FFMPEG support
 RUN pip3 uninstall -y opencv && rm -rf /usr/local/lib/python3*/dist-packages/cv2/
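For context (not part of the diff): the garbage_collection_threshold set above is only honored by PyTorch's native caching allocator once a per-process memory fraction has been configured, which is why the Python-side changes below are needed as well. A minimal sketch of the interplay, assuming only standard torch.cuda APIs:

import os

# Mirror the Dockerfile default; must be set before CUDA is initialized.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF",
                      "garbage_collection_threshold:0.99999")

import torch

# Without a memory fraction, the native allocator ignores the threshold above.
# 1.0 keeps the whole device visible; create_py_executor later lowers it to
# whatever is left after the engine's static allocations.
torch.cuda.set_per_process_memory_fraction(1.0)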

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 44 additions & 0 deletions
@@ -1,3 +1,4 @@
+import os
 import random
 from collections.abc import Iterable
 from typing import Dict, List, Optional
@@ -18,6 +19,7 @@
 
 from ..model_config import ModelConfig
 from ..speculative import get_spec_decoder
+from .config import PyTorchConfig
 from .config_utils import is_mla, is_nemotron_hybrid
 from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
 from .llm_request import ExecutorResponse
@@ -676,3 +678,45 @@ def _try_infer_num_experts(model_config: ModelConfig) -> int:
         return 1
 
     return num_experts
+
+
+def _adjust_torch_mem_fraction(pytorch_backend_config: PyTorchConfig):
+    # FIXME: PyTorch only uses the garbage_collection_threshold setting
+    # if a memory fraction is set, cf.
+    # https://github.com/pytorch/pytorch/blob/cd995bfb2aac8891465809be3ce29543bd524287/c10/cuda/CUDACachingAllocator.cpp#L1357
+    logger.debug("Setting PyTorch memory fraction to 1.0")
+    torch.cuda.set_per_process_memory_fraction(1.0)
+
+    # FIXME: As soon as
+    # torch.cuda._set_allocator_settings (added in PyTorch 2.8.0-rc1)
+    # or a similar API is available, the warning below should be removed
+    # and the allocator GC threshold be set via the new API instead.
+    torch_allocator_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
+    torch_mem_threshold_advised = (
+        torch.cuda.get_allocator_backend() == "native"
+        and "expandable_segments:True" not in torch_allocator_config)
+    torch_mem_threshold_set = "garbage_collection_threshold:" in torch_allocator_config
+    if torch_mem_threshold_advised and not torch_mem_threshold_set:
+        logger.warning(
+            "It is recommended to incl. 'garbage_collection_threshold:0.???' or 'backend:cudaMallocAsync'"
+            " or 'expandable_segments:True' in PYTORCH_CUDA_ALLOC_CONF.")
+
+    # NOTE: Even if a memory threshold was not set (cf. warning above), setting a memory
+    # fraction < 1.0 is beneficial, because
+    # https://github.com/pytorch/pytorch/blob/5228986c395dc79f90d2a2b991deea1eef188260/c10/cuda/CUDACachingAllocator.cpp#L2719
+    # and
+    # https://github.com/pytorch/pytorch/blob/5228986c395dc79f90d2a2b991deea1eef188260/c10/cuda/CUDACachingAllocator.cpp#L1240
+    # lead PyTorch to release all unused memory before hitting the set fraction. This
+    # still mitigates OOM, although at a higher performance impact, because it
+    # effectively resets the allocator cache.
+    if not pytorch_backend_config._limit_torch_cuda_mem_fraction:
+        return
+    mem_reserved = torch.cuda.memory_reserved()
+    mem_free, mem_total = torch.cuda.mem_get_info()
+    safety_margin = 32 * 1024**2
+    mem_torch_max = mem_free + mem_reserved - safety_margin
+    mem_torch_fraction = mem_torch_max / mem_total
+    logger.info(
+        f"Setting PyTorch memory fraction to {mem_torch_fraction} ({mem_torch_max / 1024**3} GiB)"
+    )
+    torch.cuda.set_per_process_memory_fraction(mem_torch_fraction)
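A rough worked example of the arithmetic above, with made-up numbers: on an 80 GiB GPU where the engine and KV cache have already claimed most of the memory, leaving 20 GiB free while PyTorch itself holds 2 GiB of reserved (cached) segments, the resulting fraction is about 0.27. Reserved memory is added back because segments already cached by PyTorch still count against its own budget, even though mem_get_info reports them as in use.

GiB = 1024**3
MiB = 1024**2

mem_free = 20 * GiB        # hypothetical torch.cuda.mem_get_info()[0]
mem_total = 80 * GiB       # hypothetical torch.cuda.mem_get_info()[1]
mem_reserved = 2 * GiB     # hypothetical torch.cuda.memory_reserved()

safety_margin = 32 * MiB
mem_torch_max = mem_free + mem_reserved - safety_margin   # ~21.97 GiB
mem_torch_fraction = mem_torch_max / mem_total            # ~0.2746

print(f"fraction={mem_torch_fraction:.4f}, budget={mem_torch_max / GiB:.2f} GiB")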

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 5 additions & 0 deletions
@@ -87,6 +87,11 @@ class PyTorchConfig:
     enable_min_latency: bool = False
     allreduce_strategy: str = "AUTO"
 
+    # If true, adjust PyTorch CUDA memory fraction to correspond to the
+    # total GPU memory minus the statically allocated engine memory.
+    # If false, set the PyTorch CUDA memory fraction to 1.0.
+    _limit_torch_cuda_mem_fraction: bool = True
+
 
 EXETENDED_EXECUTOR_CONFIG_FIELDS = [
     'backend',
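The new field is private (underscore-prefixed) and defaults to True. If the clamping ever needs to be bypassed, e.g. while debugging, it can be overridden when the config is built; a hypothetical sketch, not an officially supported option:

from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

# Keep the PyTorch memory fraction at 1.0 instead of clamping it to the
# memory left after engine allocation (i.e. the pre-commit behaviour).
config = PyTorchConfig(_limit_torch_cuda_mem_fraction=False)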

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 4 additions & 2 deletions
@@ -20,8 +20,8 @@
 from ..attention_backend.interface import AttentionRuntimeFeatures
 from ..distributed import MPIDist
 from ..speculative import NGramConfig, get_spec_resource_manager
-from ._util import (KvCacheCreator, create_py_executor_instance,
-                    instantiate_sampler, is_mla)
+from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
+                    create_py_executor_instance, instantiate_sampler, is_mla)
 from .config import PyTorchConfig
 from .config_utils import is_mla
 from .model_engine import PyTorchModelEngine
@@ -372,5 +372,7 @@ def create_py_executor(
         draft_model_engine, False, sampler, lora_config,
         garbage_collection_gen0_threshold)
 
+    _adjust_torch_mem_fraction(executor_config.pytorch_backend_config)
+
     py_executor.start_worker()
     return py_executor

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 2 deletions
@@ -403,8 +403,6 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype
 test_e2e.py::test_openai_multi_chat_example SKIP (https://nvbugs/5236980)
 test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-70B-FP8-llama-3.1-model/Llama-3.1-70B-Instruct-FP8] SKIP (https://nvbugs/5318059)
 test_e2e.py::test_ptp_quickstart_advanced_ngram[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct] SKIP (https://nvbugspro.nvidia.com/bug/5324239)
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/5318143)
-accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=TRTLLM-torch_compile=True] SKIP (https://nvbugs/5318143)
 test_e2e.py::test_ptp_quickstart_advanced[Nemotron-H-8B-Nemotron-H-8B-Base-8K] SKIP (https://nvbugs/5325284)
 disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5328160)
 stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test] SKIP (https://nvbugs/5328495)
