Merged
Changes from all commits
24 commits
64428ce
[https://nvbugs/5606268][fix] Fix program exit segment fault triggere…
yunruis Nov 3, 2025
77da54d
[https://nvbugs/5608930][fix] Unwaive test 5608930 (#8831)
sunnyqgg Nov 3, 2025
8519383
[https://nvbugs/5461796][fix] Unwaive test test_llmapi_speculative_de…
sunnyqgg Nov 3, 2025
5d9e845
[https://nvbugs/5521253][fix] Enable Gemma3 12B & 27B on SM100 (#8666)
brb-nv Nov 3, 2025
3ab3080
[https://nvbugs/5606266][test] move qwen3 multi-node test to the qa l…
Superjomn Nov 4, 2025
86c267d
[https://nvbugs/5569754][chore] Adjust max batch size to prevent OOM …
JunyiXu-nv Nov 4, 2025
0224673
[https://nvbugs/5606136][fix] Fix torch.onnx.export with pytorch upgr…
SimengLiu-nv Nov 4, 2025
b08efed
[https://nvbugs/5634220][fix] Add developer guide back and fix some i…
nv-guomingz Nov 5, 2025
69aaa40
[https://nvbugs/5467531][fix] Fix moe test and wide ep fake impl (#8883)
liji-nv Nov 6, 2025
dd316dd
[https://nvbugs/5636946][fix] Update test model (#8993)
crazydemo Nov 7, 2025
3e20488
[None][doc] Replace the relative links with absolute links in README.…
nv-guomingz Nov 7, 2025
7a3dc48
[https://nvbugs/5575920][fix] Fix cublas/cublasLt handle creation mem…
dominicshanshan Nov 7, 2025
a8fd013
[TRTLLM-9073][doc] Add the missing content for model support section …
nv-guomingz Nov 10, 2025
d0e1ae7
[https://nvbugs/5284463][fix] fix ada fp8 group gemm lacks shared mem…
inocsin Nov 11, 2025
f5c51db
[https://nvbugs/5570575][fix] : Use less kv cache memory on SM120 (#9…
peaceh-nv Nov 11, 2025
3761774
[https://nvbugs/5628204][fix] Stop token IDs - fast path optimization…
moraxu Nov 13, 2025
6445b3e
[TRTLLM-7971][doc] Doc update for multimodal in v1.1 (#9015)
chang-l Nov 13, 2025
93b6805
[https://nvbugs/5652552][fix] Log the llm args (#9119)
leslie-fang25 Nov 14, 2025
3d17bbb
[https://nvbugs/5568836][fix] Skip keyword matching for Gemma3 e2e te…
brb-nv Nov 14, 2025
f5a9bb3
[TRTLLM-9159][doc] Add KV Connector docs (#9043)
Shunkangz Nov 17, 2025
4b66994
[https://nvbugs/5649826][fix] Unwaive test test_llm_commandr_plus_4gp…
sunnyqgg Nov 17, 2025
00eb501
[https://nvbugs/5461796][fix] Unwaive and extend time for test_llmapi…
sunnyqgg Nov 18, 2025
2008c75
[TRTLLM-9092][doc] Add a pre-quantized example in quick start guide (…
QiJune Nov 19, 2025
5cd1b63
[https://nvbugs/5648685][fix] Fix openAI server waiting time to avoid…
dominicshanshan Nov 19, 2025
2 changes: 1 addition & 1 deletion README.md
@@ -230,7 +230,7 @@ Serverless TensorRT LLM (LLaMA 3 8B) | Modal Docs [➡️ link](https://modal.co

TensorRT LLM is an open-sourced library for optimizing Large Language Model (LLM) inference. It provides state-of-the-art optimizations, including custom attention kernels, inflight batching, paged KV caching, quantization (FP8, [FP4](https://www.nvidia.com/en-us/data-center/technologies/blackwell-architecture/), INT4 [AWQ](https://arxiv.org/abs/2306.00978), INT8 [SmoothQuant](https://arxiv.org/abs/2211.10438), ...), speculative decoding, and much more, to perform inference efficiently on NVIDIA GPUs.

[Architected on PyTorch](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/torch/arch_overview.md), TensorRT LLM provides a high-level Python [LLM API](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html#llm-api) that supports a wide range of inference setups - from single-GPU to multi-GPU or multi-node deployments. It includes built-in support for various parallelism strategies and advanced features. The LLM API integrates seamlessly with the broader inference ecosystem, including NVIDIA [Dynamo](https://github.com/ai-dynamo/dynamo) and the [Triton Inference Server](https://github.com/triton-inference-server/server).
[Architected on PyTorch](https://github.com/NVIDIA/TensorRT-LLM/blob/release/1.1/docs/source/developer-guide/overview.md), TensorRT LLM provides a high-level Python [LLM API](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html#llm-api) that supports a wide range of inference setups - from single-GPU to multi-GPU or multi-node deployments. It includes built-in support for various parallelism strategies and advanced features. The LLM API integrates seamlessly with the broader inference ecosystem, including NVIDIA [Dynamo](https://github.com/ai-dynamo/dynamo) and the [Triton Inference Server](https://github.com/triton-inference-server/server).

TensorRT LLM is designed to be modular and easy to modify. Its PyTorch-native architecture allows developers to experiment with the runtime or extend functionality. Several popular models are also pre-defined and can be customized using [native PyTorch code](./tensorrt_llm/_torch/models/modeling_deepseekv3.py), making it easy to adapt the system to specific needs.
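As a quick illustration of the LLM API referenced above, a minimal offline-generation sketch might look like the following; the model name and sampling values are placeholders in the style of the quick start guide, not a prescribed configuration — verify against your installed release.

```python
# Minimal sketch of offline generation with the high-level LLM API.
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```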

10 changes: 10 additions & 0 deletions cpp/kernels/fmha_v2/setup.py
@@ -6398,6 +6398,16 @@ def enumerate_kernels():
and kspec.cross_mha == False
and kspec.flash_attention == True
and kspec.input_layout != InputLayout.SEPARATE_Q_K_V)
# Gemma3 VL support.
or (kspec.sm == 100
and kspec.dtype in ['fp16', 'bf16', 'fp16_fp32', 'e4m3', 'e4m3_fp32']
and kspec.head_size == 72
and kspec.head_size_v == 0
and kspec.sage_block_sizes is None
and kspec.version == 2
and kspec.cross_mha == False
and kspec.flash_attention == True
and kspec.input_layout != InputLayout.SEPARATE_Q_K_V)
# Deepseek MLA (generation 576/512 paged)
or (kspec.sm in [90, 100, 120]
and kspec.dtype in ['bf16', 'e4m3_fp32']
106 changes: 94 additions & 12 deletions cpp/tensorrt_llm/common/opUtils.cpp
@@ -179,16 +179,24 @@ class PerCudaCtxPerThreadSingletonCreator
PerCudaCtxPerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter)
: mCreator{std::move(creator)}
, mDeleter{std::move(deleter)}
, mObservers{new std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>>()}
{
}

~PerCudaCtxPerThreadSingletonCreator()
{
std::lock_guard<std::mutex> lk{mMutex};
delete mObservers;
mObservers = nullptr;
}

std::shared_ptr<T> operator()()
{
std::lock_guard<std::mutex> lk{mMutex};
CUcontext ctx{getCurrentCudaCtx()};
std::thread::id thread = std::this_thread::get_id();
auto const key = std::make_tuple(ctx, thread);
std::shared_ptr<T> result = mObservers[key].lock();
std::shared_ptr<T> result = (*mObservers)[key].lock();
if (result == nullptr)
{
TLLM_LOG_TRACE("creating singleton instance for CUDA context %lu and thread %lu", ctx, thread);
@@ -202,6 +210,11 @@ class PerCudaCtxPerThreadSingletonCreator
}
mDeleter(obj);

if (mObservers == nullptr)
{
return;
}

// Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts
// frequently.
std::shared_ptr<T> observedObjHolder; // Delay destroy to avoid dead lock.
@@ -210,17 +223,18 @@ class PerCudaCtxPerThreadSingletonCreator
// thread just before we lock mMutex. We can't infer that the observer is stale from the fact that
// obj is destroyed, because shared_ptr ref-count checking and observer removing are not in one
// atomic operation, and the observer may be changed to observe another instance.
if (mObservers.find(key) == mObservers.end())
auto it = mObservers->find(key);
if (it == mObservers->end())
{
return;
}
observedObjHolder = mObservers.at(key).lock();
observedObjHolder = it->second.lock();
if (observedObjHolder == nullptr)
{
mObservers.erase(key);
mObservers->erase(it);
}
}};
mObservers.at(key) = result;
(*mObservers)[key] = result;
}
else
{
@@ -235,24 +249,78 @@ class PerCudaCtxPerThreadSingletonCreator
mutable std::mutex mMutex;
// CUDA resources are per-context and per-thread.
using CacheKey = std::tuple<CUcontext, std::thread::id>;
std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>> mObservers;
std::unordered_map<CacheKey, std::weak_ptr<T>, hash<CacheKey>>* mObservers;
};

// Structure to hold memory information
struct MemoryInfo
{
size_t free_mb;
size_t total_mb;
float free_percent;
};

// Helper function to get current memory information
MemoryInfo getMemoryInfo()
{
size_t free_mem = 0, total_mem = 0;
TLLM_CUDA_CHECK(cudaMemGetInfo(&free_mem, &total_mem));

size_t const free_mb = free_mem / (1024 * 1024);
size_t const total_mb = total_mem / (1024 * 1024);
float const free_percent = (total_mem > 0) ? (static_cast<float>(free_mem) / total_mem * 100.0f) : 0.0f;

return {free_mb, total_mb, free_percent};
}

// Helper function to log current memory usage
void logMemoryUsage(char const* operation, CUcontext ctx)
{
auto const mem = getMemoryInfo();
TLLM_LOG_DEBUG("%s: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", operation, ctx, mem.free_mb,
mem.free_percent, mem.total_mb);
}

// Helper function to throw
void throwCublasErrorWithMemInfo(char const* operation, CUcontext ctx, cublasStatus_t status)
{
auto const mem = getMemoryInfo();
TLLM_THROW(
"Failed to create %s. "
"Status: %d, Context: %p, Free Memory: %zu MB (%.1f%%), Total: %zu MB. "
"Consider reducing kv_cache_config.free_gpu_memory_fraction.",
operation, status, ctx, mem.free_mb, mem.free_percent, mem.total_mb);
}

} // namespace

std::shared_ptr<cublasHandle_t> getCublasHandle()
{
static PerCudaCtxPerThreadSingletonCreator<cublasHandle_t> creator(
[]() -> auto
{
auto handle = std::unique_ptr<cublasHandle_t>(new cublasHandle_t);
TLLM_CUDA_CHECK(cublasCreate(handle.get()));
CUcontext ctx = getCurrentCudaCtx();
logMemoryUsage("Creating cublas handle", ctx);

auto handle = std::make_unique<cublasHandle_t>();
auto status = cublasCreate(handle.get());

if (status != CUBLAS_STATUS_SUCCESS)
{
throwCublasErrorWithMemInfo("cublas handle", ctx, status);
}

return handle;
},
[](cublasHandle_t* handle)
{
TLLM_CUDA_CHECK(cublasDestroy(*handle));
auto status = cublasDestroy(*handle);
if (status != CUBLAS_STATUS_SUCCESS)
{
TLLM_LOG_WARNING("Failed to destroy cublas handle. Status: %d", status);
}
delete handle;
handle = nullptr;
});
return creator();
}
@@ -262,14 +330,28 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
static PerCudaCtxPerThreadSingletonCreator<cublasLtHandle_t> creator(
[]() -> auto
{
auto handle = std::unique_ptr<cublasLtHandle_t>(new cublasLtHandle_t);
TLLM_CUDA_CHECK(cublasLtCreate(handle.get()));
CUcontext ctx = getCurrentCudaCtx();
logMemoryUsage("Creating cublasLt handle", ctx);

auto handle = std::make_unique<cublasLtHandle_t>();
auto status = cublasLtCreate(handle.get());

if (status != CUBLAS_STATUS_SUCCESS)
{
throwCublasErrorWithMemInfo("cublasLt handle", ctx, status);
}

return handle;
},
[](cublasLtHandle_t* handle)
{
TLLM_CUDA_CHECK(cublasLtDestroy(*handle));
auto status = cublasLtDestroy(*handle);
if (status != CUBLAS_STATUS_SUCCESS)
{
TLLM_LOG_WARNING("Failed to destroy cublasLt handle. Status: %d", status);
}
delete handle;
handle = nullptr;
});
return creator();
}
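The new creation path above logs free-memory statistics and, on failure, suggests reducing `kv_cache_config.free_gpu_memory_fraction`. A hedged sketch of acting on that hint from the Python LLM API follows; it assumes `KvCacheConfig` and its `free_gpu_memory_fraction` field are exposed under `tensorrt_llm.llmapi` as in current releases.

```python
# Sketch: leave more headroom for cuBLAS/cublasLt handle and workspace
# allocations by letting the KV cache claim a smaller share of free GPU memory.
# Assumes tensorrt_llm.llmapi.KvCacheConfig exists with this field name.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.8),
)
```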
@@ -177,13 +177,13 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
{
if (sm == 89 || sm >= 120)
{
return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128,
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64};
CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64,
CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128};
}
else
{
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp
@@ -49,7 +49,7 @@ FmhaDispatcher::FmhaDispatcher(MHARunnerFixedParams fixedParams)
// TRTLLM-GEN only supports power of 2 head sizes.
// The exception will fall back to fmha v2.
// Please update fmha_v2/setup.py if you want to add more supported head sizes.
, mUseTllmGen(tensorrt_llm::common::isSM100Family() && fixedParams.headSize != 80)
, mUseTllmGen(tensorrt_llm::common::isSM100Family() && fixedParams.headSize != 80 && fixedParams.headSize != 72)
{
if (mUseTllmGen)
{
2 changes: 1 addition & 1 deletion docs/source/blogs/H100vsA100.md
@@ -28,7 +28,7 @@ TensorRT LLM evaluated on both Hopper and Ampere shows **H100 FP8 is up to 4.6x

<sub>FP8 H100, FP16 A100, SXM 80GB GPUs, TP1, ISL/OSL's provided, TensorRT LLM v0.5.0., TensorRT 9.1</sub>

The full data behind these charts & tables and including larger models with higher TP values can be found in TensorRT LLM's [Performance Documentation](https://nvidia.github.io/TensorRT-LLM/latest/performance/perf-overview.html)
The full data behind these charts & tables and including larger models with higher TP values can be found in TensorRT LLM's [Performance Documentation](https://nvidia.github.io/TensorRT-LLM/0.21.0/performance/perf-overview.html)

Stay tuned for a highlight on Llama coming soon!

2 changes: 1 addition & 1 deletion docs/source/blogs/H200launch.md
@@ -21,7 +21,7 @@ TensorRT LLM evaluation of the [new H200 GPU](https://nvidianews.nvidia.com/news

<sup>*(1) Largest batch supported on given TP configuration by power of 2.*</sup> <sup>*(2) TP = Tensor Parallelism*</sup>

Additional Performance data is available on the [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference/ai-inference) page, & soon in [TensorRT LLM's Performance Documentation](https://nvidia.github.io/TensorRT-LLM/latest/performance/perf-overview.html).
Additional Performance data is available on the [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference/ai-inference) page, & soon in [TensorRT LLM's Performance Documentation](https://nvidia.github.io/TensorRT-LLM/0.21.0/performance/perf-overview.html).

### H200 vs H100

@@ -124,7 +124,7 @@ In the Dynamo workflow, requests are initially processed by pre- and post-proces

Dynamo also includes built-in support for Kubernetes deployment, monitoring, and metrics collection. The development team is actively working on enabling dynamic instance scaling, further enhancing its suitability for production environments.

For more information on how to use Dynamo with TensorRT LLM, please refer to [this documentation](https://docs.nvidia.com/dynamo/latest/examples/trtllm.html).
For more information on how to use Dynamo with TensorRT LLM, please refer to [this documentation](https://docs.nvidia.com/dynamo/latest/backends/trtllm/README.html).

### Triton Inference Server

2 changes: 1 addition & 1 deletion docs/source/features/disagg-serving.md
@@ -94,7 +94,7 @@ In the Dynamo workflow, requests are initially processed by pre- and post-proces

Dynamo also includes built-in support for Kubernetes deployment, monitoring, and metrics collection. The development team is actively working on enabling dynamic instance scaling, further enhancing its suitability for production environments.

For more information on how to use Dynamo with TensorRT-LLM, please refer to [this documentation](https://docs.nvidia.com/dynamo/latest/examples/trtllm.html).
For more information on how to use Dynamo with TensorRT-LLM, please refer to [this documentation](https://docs.nvidia.com/dynamo/latest/backends/trtllm/README.html).

### trtllm-serve

1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -77,6 +77,7 @@ Welcome to TensorRT LLM's Documentation!
features/ray-orchestrator.md
features/torch_compile_and_piecewise_cuda_graph.md


.. toctree::
:maxdepth: 2
:caption: Developer Guide
@@ -7,7 +7,7 @@
| VILA | Yes | No | No | No |
| LLaVA-NeXT | Yes | Yes | Yes | Yes |
| Llama 4 | Yes | Yes | No | No |
| Mistral-Small-3.1 | Yes | Yes | No | No |
| Phi-4-multimodal | Yes | Yes | No | No |
| Mistral-Small-3.1 | Yes | Yes | Yes | Yes |
| Phi-4-multimodal | Yes | Yes | Yes | Yes |
| Qwen2-VL | Yes | Yes | Yes | Yes |
| Qwen2.5-VL | Yes | Yes | Yes | Yes |
8 changes: 4 additions & 4 deletions docs/source/models/supported-models.md
@@ -50,13 +50,13 @@ Note: Support for other models may vary. Features marked "N/A" are not applicabl
| `Gemma3ForConditionalGeneration` | Yes | Yes | N/A | Yes | Yes | N/A | Yes | No | L + I |
| `HCXVisionForCausalLM` | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I |
| `LlavaLlamaModel (VILA)` | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I + V |
| `LlavaNextForConditionalGeneration` | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
| `LlavaNextForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | L + I |
| `Llama4ForConditionalGeneration` | Yes | Yes | No | Yes | Yes | No | Yes | No | L + I |
| `Mistral3ForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | L + I |
| `NemotronH_Nano_VL_V2` | Yes | Yes | Yes | Yes | Yes | No | Yes | No | L + I + V |
| `NemotronH_Nano_VL_V2` | Yes | Yes | Yes | Yes | Yes | N/A | Yes | No | L + I + V |
| `Phi4MMForCausalLM` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | L + I + A |
| `Qwen2VLForConditionalGeneration` | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |
| `Qwen2_5_VLForConditionalGeneration` | Yes | Yes | No | Yes | Yes | Yes | Yes | No | L + I + V |
| `Qwen2VLForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | L + I + V |
| `Qwen2_5_VLForConditionalGeneration` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | L + I + V |

Note:
- L: Language
2 changes: 1 addition & 1 deletion docs/source/overview.md
@@ -23,7 +23,7 @@ TensorRT LLM delivers breakthrough performance on the latest NVIDIA GPUs:

### 🎯 **Comprehensive Model Support**

TensorRT LLM supports the latest and most popular LLM architectures:
TensorRT LLM supports the latest and most popular LLM [architectures](https://nvidia.github.io/TensorRT-LLM/models/supported-models.html).

- **Language Models**: GPT-OSS, Deepseek-R1/V3, Llama 3/4, Qwen2/3, Gemma 3, Phi 4...
- **Multi-modal Models**: LLaVA-NeXT, Qwen2-VL, VILA, Llama 3.2 Vision...
7 changes: 7 additions & 0 deletions docs/source/quick-start-guide.md
@@ -24,6 +24,13 @@ To start the server, you can run a command like the following example inside a D
trtllm-serve "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
```

You may also deploy pre-quantized models to improve performance.
Ensure your GPU supports FP8 quantization before running the following:

```bash
trtllm-serve "nvidia/Qwen3-8B-FP8"
```
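
Once the server is running, it exposes an OpenAI-compatible endpoint, so a request can be sent with any OpenAI client. The sketch below assumes the default local address and port (`http://localhost:8000/v1`) and uses a placeholder API key.

```python
# Sketch: query a local trtllm-serve instance via its OpenAI-compatible API.
# The base URL, port, and dummy API key are assumptions for a default local run.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
response = client.chat.completions.create(
    model="nvidia/Qwen3-8B-FP8",
    messages=[{"role": "user", "content": "Where is New York?"}],
    max_tokens=64,
)
print(response.choices[0].message.content)
```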

```{note}
If you are running `trtllm-serve` inside a Docker container, you have two options for sending API requests:
1. Expose a port (e.g., 8000) to allow external access to the server from outside the container.
5 changes: 5 additions & 0 deletions examples/llm-api/extra-llm-api-config.yml
@@ -0,0 +1,5 @@
cuda_graph_config:
enable_padding: True
max_batch_size: 16
moe_config:
backend: trtllm
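
This new example config enables CUDA graph padding (with a captured batch size of up to 16) and selects the `trtllm` MoE backend. Such files are typically handed to `trtllm-serve` or `trtllm-bench` through their extra LLM API options flag; the sketch below instead loads the YAML and forwards it to the Python `LLM` constructor, assuming these keys are accepted as keyword arguments in your installed version.

```python
# Sketch: reuse the extra LLM API options from the YAML file in Python.
# Assumes cuda_graph_config and moe_config are valid LLM keyword arguments.
import yaml

from tensorrt_llm import LLM

with open("examples/llm-api/extra-llm-api-config.yml") as f:
    extra_options = yaml.safe_load(f)

llm = LLM(model="nvidia/Qwen3-8B-FP8", **extra_options)
```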