12 changes: 12 additions & 0 deletions examples/llm-api/quickstart_advanced.py
@@ -108,6 +108,9 @@ def add_llm_args(parser):
                         default=False,
                         action='store_true',
                         help='Use piecewise CUDA graph to optimize the model')
+    parser.add_argument('--apply_chat_template',
+                        default=False,
+                        action='store_true')

     # Sampling
     parser.add_argument("--max_tokens", type=int, default=64)
@@ -273,6 +276,15 @@ def main():
     prompts = args.prompt if args.prompt else example_prompts

     llm, sampling_params = setup_llm(args)
+    new_prompts = []
+    if args.apply_chat_template:
+        for prompt in prompts:
+            messages = [{"role": "user", "content": f"{prompt}"}]
+            new_prompts.append(
+                llm.tokenizer.apply_chat_template(messages,
+                                                  tokenize=False,
+                                                  add_generation_prompt=True))
+        prompts = new_prompts
     outputs = llm.generate(prompts, sampling_params)
Comment on lines +279 to 288
💡 Verification agent

🧩 Analysis chain

Guard against tokenizers without a chat template and reduce scope of temporaries

Without a guard, enabling the flag on models lacking a chat template will raise and crash the example. Also, move new_prompts into the if-block to reduce scope.

You can verify this by running with a tokenizer that lacks a chat template; the current code will raise on apply_chat_template.

-    new_prompts = []
-    if args.apply_chat_template:
-        for prompt in prompts:
-            messages = [{"role": "user", "content": f"{prompt}"}]
-            new_prompts.append(
-                llm.tokenizer.apply_chat_template(messages,
-                                                  tokenize=False,
-                                                  add_generation_prompt=True))
-        prompts = new_prompts
+    if args.apply_chat_template:
+        try:
+            # Validate template availability (will raise if unsupported)
+            _ = llm.tokenizer.get_chat_template()
+            new_prompts = []
+            for prompt in prompts:
+                messages = [{"role": "user", "content": prompt}]
+                new_prompts.append(
+                    llm.tokenizer.apply_chat_template(
+                        messages, tokenize=False, add_generation_prompt=True
+                    )
+                )
+            prompts = new_prompts
+        except Exception as e:
+            print(f"[warn] --apply_chat_template ignored: {e}")

If you prefer structured logging instead of print, add this import at the top of the file and replace the print with logger.warning(...):

from tensorrt_llm.logger import logger

🏁 Script executed:

#!/bin/bash
set -e

echo "Searching for apply_chat_template definitions and references:"
rg "apply_chat_template" -n

echo
echo "Searching for get_chat_template definitions and references:"
rg "get_chat_template" -n

Length of output: 10606


🏁 Script executed:

#!/bin/bash
set -e

echo "Inspecting tensorrt_llm/llmapi/tokenizer.py around get/apply_chat_template definitions:"
rg -n "def get_chat_template" -C3 tensorrt_llm/llmapi/tokenizer.py
rg -n "def apply_chat_template" -C3 tensorrt_llm/llmapi/tokenizer.py

echo
echo "Inspecting tensorrt_llm/inputs/utils.py around get/apply_chat_template definitions:"
rg -n "def get_chat_template" -C3 tensorrt_llm/inputs/utils.py
rg -n "def apply_chat_template" -C3 tensorrt_llm/inputs/utils.py

Length of output: 1389


🏁 Script executed:

#!/bin/bash
set -e

echo "Extracting full `apply_chat_template` from `tensorrt_llm/llmapi/tokenizer.py`:"
rg -n "def apply_chat_template" -C10 tensorrt_llm/llmapi/tokenizer.py

echo
echo "Extracting `get_chat_template` and `apply_chat_template` from `tensorrt_llm/inputs/utils.py`:"
rg -n "def get_chat_template" -C10 tensorrt_llm/inputs/utils.py
rg -n "def apply_chat_template" -C10 tensorrt_llm/inputs/utils.py

Length of output: 1986


🏁 Script executed:

#!/bin/bash
set -e

echo "Showing `get_chat_template` definition in tensorrt_llm/inputs/utils.py:"
rg -n 'def get_chat_template' -C5 tensorrt_llm/inputs/utils.py

echo
echo "Showing `apply_chat_template` definition in tensorrt_llm/inputs/utils.py:"
rg -n 'def apply_chat_template' -C5 tensorrt_llm/inputs/utils.py

Length of output: 317


🏁 Script executed:

#!/bin/bash
set -e

echo "Locating TokenizerBase and its chat-template methods in inputs/utils.py"
rg -n "class TokenizerBase" -C5 tensorrt_llm/inputs/utils.py

echo
echo "Definition of get_chat_template in inputs/utils.py:"
rg -n "def get_chat_template" -C10 tensorrt_llm/inputs/utils.py

echo
echo "Definition of apply_chat_template in inputs/utils.py:"
rg -n "def apply_chat_template" -C10 tensorrt_llm/inputs/utils.py

Length of output: 224


#!/bin/bash
set -e

echo "Displaying lines 400-460 from tensorrt_llm/inputs/utils.py:"
sed -n '400,460p' tensorrt_llm/inputs/utils.py

Guard chat-template usage and limit new_prompts scope in quickstart_advanced.py

Move the declaration of new_prompts into the if block and wrap chat-template calls in a try/except to skip unsupported tokenizers instead of crashing:

     if args.apply_chat_template:
-        # new_prompts declared too early
-        new_prompts = []
+        try:
+            # Ensure this tokenizer actually supports chat templates
+            _ = llm.tokenizer.get_chat_template()
+            new_prompts = []
             for prompt in prompts:
                 messages = [{"role": "user", "content": prompt}]
                 new_prompts.append(
                     llm.tokenizer.apply_chat_template(
                         messages,
                         tokenize=False,
                         add_generation_prompt=True,
                     )
                 )
             prompts = new_prompts
-        # no handling if unsupported → crash
+        except Exception as e:
+            # Skip templating on unsupported tokenizers
+            # (or replace print with logger.warning if preferred)
+            print(f"[warn] --apply_chat_template ignored: {e}")
     outputs = llm.generate(prompts, sampling_params)

If you’d rather use structured logging, add at the top:

from tensorrt_llm.logger import logger

and replace print(...) with:

logger.warning(f"--apply_chat_template ignored: {e}")
🤖 Prompt for AI Agents
In examples/llm-api/quickstart_advanced.py around lines 279 to 288, the variable
new_prompts is declared outside the chat-template guard and chat-template calls
can raise for tokenizers that don't support it; move the declaration of
new_prompts inside the if args.apply_chat_template block so its scope is
limited, and wrap the per-prompt llm.tokenizer.apply_chat_template call in a
try/except that catches the tokenizer error and continues (skipping that prompt
or appending the original prompt) instead of letting the script crash;
optionally import the package logger at the top and replace any print warnings
with logger.warning messages.
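A minimal sketch of the per-prompt fallback that prompt describes, assuming only the objects already present in main() (args, prompts, llm) and that apply_chat_template raises when the tokenizer has no chat template; this is illustrative, not the committed fix:

    if args.apply_chat_template:
        templated = []
        for prompt in prompts:
            messages = [{"role": "user", "content": prompt}]
            try:
                # Render the chat template to text; generation happens later as usual.
                templated.append(
                    llm.tokenizer.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=True))
            except Exception as e:
                # No chat template (or templating failed): keep the raw prompt
                # instead of letting the example crash.
                print(f"[warn] chat template skipped: {e}")
                templated.append(prompt)
        prompts = templated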


     for i, output in enumerate(outputs):
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/attention_backend/interface.py
@@ -342,6 +342,7 @@ def __call__(self, position_ids: torch.Tensor, q: torch.Tensor,
 class RopeParams:
     dim: int = 0
     theta: float = 10000.0
+    alpha: float = 1.0
     scale_type: RotaryScalingType = RotaryScalingType.none
     scale: float = 1.0
     low_freq_factor: float = 1.0
@@ -384,6 +385,7 @@ def from_config(config) -> "RopeParams":
         rope_params.scale_type = RotaryScalingType.none
         rope_params.scale = 1.0
         if rope_scaling is not None:
+            rope_params.alpha = rope_scaling.get("alpha", 1.0)
             rotary_scaling_type = rope_scaling.get(
                 "type", None) or rope_scaling.get("rope_type")
             rope_params.scale_type = RotaryScalingType.from_string(
@@ -462,6 +464,7 @@ def create_rope_const_params(self, interleave: bool = True):
             self.scale_type,
             rope_scaling_config={
                 "factor": self.scale,
+                "alpha": self.alpha,
                 "low_freq_factor": self.low_freq_factor,
                 "high_freq_factor": self.high_freq_factor,
                 "original_max_position_embeddings":
5 changes: 3 additions & 2 deletions tensorrt_llm/_torch/model_config.py
@@ -154,8 +154,9 @@ def fuse_pos_embd(self):
     @property
     def enable_flash_mla(self):
         if self.attn_backend == 'TRTLLM':
-            if hasattr(self.pretrained_config, "kv_lora_rank") and hasattr(
-                    self.pretrained_config, "qk_rope_head_dim"):
+            if getattr(self.pretrained_config,
+                       "kv_lora_rank", None) and getattr(
+                           self.pretrained_config, "qk_rope_head_dim", None):
                 head_dim = self.pretrained_config.kv_lora_rank + self.pretrained_config.qk_rope_head_dim
                 if head_dim == 576 and torch.cuda.get_device_capability() == (
                         9, 0):
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/models/__init__.py
@@ -8,6 +8,7 @@
 from .modeling_gemma3 import Gemma3ForCausalLM
 from .modeling_gemma3vl import Gemma3VLM
 from .modeling_gpt_oss import GptOssForCausalLM
+from .modeling_hunyuan_moe import HunYuanMoEV1ForCausalLM
⚠️ Potential issue

Missing import for Gemma3Model while exporting it in `__all__`

`Gemma3Model` is added to `__all__` but is not imported, which will break direct imports like `from ...models import Gemma3Model`. Import it from modeling_gemma3.

Apply this diff:

-from .modeling_gemma3 import Gemma3ForCausalLM
+from .modeling_gemma3 import Gemma3ForCausalLM, Gemma3Model

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/__init__.py around line 11, add an import for
Gemma3Model so the symbol exported in __all__ is actually available;
specifically import Gemma3Model from .modeling_gemma3 (e.g. add the line `from
.modeling_gemma3 import Gemma3Model`) so `from ...models import Gemma3Model`
works as expected.
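For context, a small standalone illustration of why this breaks (hypothetical package named pkg, not code from this repo): __all__ only lists names for `from pkg import *`; it never creates the binding itself, so the name must still be imported in __init__.py.

    # pkg/__init__.py (hypothetical)
    from .modeling_gemma3 import Gemma3ForCausalLM  # Gemma3Model never imported

    __all__ = ["Gemma3ForCausalLM", "Gemma3Model"]

    # In user code:
    #   from pkg import Gemma3Model  -> ImportError: cannot import name 'Gemma3Model'
    #   from pkg import *            -> AttributeError: module 'pkg' has no attribute 'Gemma3Model'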

 from .modeling_hyperclovax import HCXVisionForCausalLM
 from .modeling_llama import LlamaForCausalLM
 from .modeling_llava_next import LlavaNextModel
@@ -38,6 +39,7 @@
     "Gemma3ForCausalLM",
     "Gemma3VLM",
     "HCXVisionForCausalLM",
+    "HunYuanMoEV1ForCausalLM",
     "LlamaForCausalLM",
     "LlavaNextModel",
     "Mistral3VLM",