3 changes: 3 additions & 0 deletions docker/Dockerfile.multi
@@ -66,6 +66,9 @@ RUN GITHUB_MIRROR=$GITHUB_MIRROR bash ./install_mpi4py.sh && rm install_mpi4py.sh
ARG TORCH_INSTALL_TYPE="skip"
COPY docker/common/install_pytorch.sh install_pytorch.sh
RUN bash ./install_pytorch.sh $TORCH_INSTALL_TYPE && rm install_pytorch.sh
#
# NB: PyTorch requires this to be < 1.0
ENV PYTORCH_CUDA_ALLOC_CONF="garbage_collection_threshold:0.99999"

# Install OpenCV with FFMPEG support
RUN pip3 uninstall -y opencv && rm -rf /usr/local/lib/python3*/dist-packages/cv2/
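For context, a minimal Python sketch (not part of the PR; the parsing shown here is a hypothetical check) of how the allocator setting above could be sanity-checked inside the container. The only fact taken from the diff is the Dockerfile note that the garbage_collection_threshold must stay below 1.0.

import os

# Hypothetical check, not part of the PR: parse PYTORCH_CUDA_ALLOC_CONF and
# confirm the garbage_collection_threshold stays below 1.0, as noted in the
# Dockerfile comment above.
conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
for option in conf.split(","):
    if option.startswith("garbage_collection_threshold:"):
        threshold = float(option.split(":", 1)[1])
        assert threshold < 1.0, "expected garbage_collection_threshold < 1.0"
        print(f"garbage_collection_threshold = {threshold}")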
44 changes: 44 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -1,3 +1,4 @@
import os
import random
from collections.abc import Iterable
from typing import Dict, List, Optional
@@ -18,6 +19,7 @@

from ..model_config import ModelConfig
from ..speculative import get_spec_decoder
from .config import PyTorchConfig
from .config_utils import is_mla, is_nemotron_hybrid
from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
from .llm_request import ExecutorResponse
@@ -718,3 +720,45 @@ def _try_infer_num_experts(model_config: ModelConfig) -> int:
return 1

return num_experts


def _adjust_torch_mem_fraction(pytorch_backend_config: PyTorchConfig):
# FIXME: PyTorch only uses the garbage_collection_threshold setting
# if a memory fraction is set, cf.
# https://github.com/pytorch/pytorch/blob/cd995bfb2aac8891465809be3ce29543bd524287/c10/cuda/CUDACachingAllocator.cpp#L1357
logger.debug("Setting PyTorch memory fraction to 1.0")
torch.cuda.set_per_process_memory_fraction(1.0)

# FIXME: As soon as
# torch.cuda._set_allocator_settings (added in PyTorch 2.8.0-rc1)
# or a similar API is available, the warning below should be removed
# and the allocator GC threshold be set via the new API instead.
torch_allocator_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
torch_mem_threshold_advised = (
torch.cuda.get_allocator_backend() == "native"
and "expandable_segments:True" not in torch_allocator_config)
torch_mem_threshold_set = "garbage_collection_threshold:" in torch_allocator_config
if torch_mem_threshold_advised and not torch_mem_threshold_set:
logger.warning(
"It is recommended to incl. 'garbage_collection_threshold:0.???' or 'backend:cudaMallocAsync'"
" or 'expandable_segments:True' in PYTORCH_CUDA_ALLOC_CONF.")

# NOTE: Even if a memory threshold was not set (cf. warning above), setting a memory
# fraction < 1.0 is beneficial, because
# https://github.com/pytorch/pytorch/blob/5228986c395dc79f90d2a2b991deea1eef188260/c10/cuda/CUDACachingAllocator.cpp#L2719
# and
# https://github.com/pytorch/pytorch/blob/5228986c395dc79f90d2a2b991deea1eef188260/c10/cuda/CUDACachingAllocator.cpp#L1240
# lead PyTorch to release all unused memory before hitting the set fraction. This
# still mitigates OOM, although at a higher performance impact, because it
# effectively resets the allocator cache.
if not pytorch_backend_config._limit_torch_cuda_mem_fraction:
return
mem_reserved = torch.cuda.memory_reserved()
mem_free, mem_total = torch.cuda.mem_get_info()
safety_margin = 32 * 1024**2
mem_torch_max = mem_free + mem_reserved - safety_margin
mem_torch_fraction = mem_torch_max / mem_total
logger.info(
f"Setting PyTorch memory fraction to {mem_torch_fraction} ({mem_torch_max / 1024**3} GiB)"
)
torch.cuda.set_per_process_memory_fraction(mem_torch_fraction)
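
As a rough illustration of the fraction computed by _adjust_torch_mem_fraction (a sketch with made-up numbers, not taken from a real device): the new cap is free device memory plus PyTorch's reserved cache, minus the 32 MiB safety margin, expressed as a fraction of total device memory.

# Hypothetical values for an 80 GiB GPU with 30 GiB free and 4 GiB cached by
# PyTorch; on a real device these would come from torch.cuda.mem_get_info()
# and torch.cuda.memory_reserved().
mem_free, mem_total = 30 * 1024**3, 80 * 1024**3
mem_reserved = 4 * 1024**3
safety_margin = 32 * 1024**2
mem_torch_max = mem_free + mem_reserved - safety_margin
mem_torch_fraction = mem_torch_max / mem_total
print(f"fraction={mem_torch_fraction:.4f}, cap={mem_torch_max / 1024**3:.2f} GiB")
# -> fraction=0.4246, cap=33.97 GiB

Per the comments in the function above, capping the fraction this way leads PyTorch to release unused cached memory before it would otherwise grow into the region reserved for the engine.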
5 changes: 5 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/config.py
@@ -92,6 +92,11 @@ class PyTorchConfig:

force_dynamic_quantization: bool = False

# If true, adjust PyTorch CUDA memory fraction to correspond to the
# total GPU memory minus the statically allocated engine memory.
# If false, set the PyTorch CUDA memory fraction to 1.0.
_limit_torch_cuda_mem_fraction: bool = True


EXETENDED_EXECUTOR_CONFIG_FIELDS = [
'backend',
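A short usage sketch, assuming PyTorchConfig is a plain dataclass so the new field can be passed as a keyword argument (the construction path shown here is illustrative, not from the PR):

from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

# Default (True): cap the PyTorch CUDA memory fraction to what remains after the
# engine's static allocations. Setting it to False keeps the fraction at 1.0.
config = PyTorchConfig(_limit_torch_cuda_mem_fraction=False)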
6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -20,8 +20,8 @@
from ..attention_backend.interface import AttentionRuntimeFeatures
from ..distributed import MPIDist
from ..speculative import get_spec_drafter, get_spec_resource_manager
from ._util import (KvCacheCreator, create_py_executor_instance,
instantiate_sampler, is_mla)
from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
create_py_executor_instance, instantiate_sampler, is_mla)
from .config import PyTorchConfig
from .config_utils import is_mla
from .model_engine import PyTorchModelEngine
@@ -432,5 +432,7 @@ def create_py_executor(
garbage_collection_gen0_threshold,
)

_adjust_torch_mem_fraction(executor_config.pytorch_backend_config)

py_executor.start_worker()
return py_executor
2 changes: 0 additions & 2 deletions tests/integration/test_lists/waives.txt
@@ -374,8 +374,6 @@ perf/test_perf.py::test_perf[mamba_130m-bench-float16-input_output_len:128,128]
perf/test_perf.py::test_perf[bert_large-bench-float16-maxbs:32-input_len:128+512] SKIP (https://nvbugspro.nvidia.com/bug/5295411)
perf/test_perf.py::test_perf[roberta_base-bench-float16-maxbs:32-input_len:128+512] SKIP (https://nvbugspro.nvidia.com/bug/5295411)
test_e2e.py::test_openai_multi_chat_example SKIP (https://nvbugs/5236980)
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/5318143)
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=TRTLLM-torch_compile=True] SKIP (https://nvbugs/5318143)
disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5328160)
stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test] SKIP (https://nvbugs/5328495)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5351130)