15 changes: 14 additions & 1 deletion tensorrt_llm/_torch/model_config.py
@@ -442,7 +442,20 @@ def get_bindings_model_config(self,

         mlp_hidden_size = None
         if self.pretrained_config.intermediate_size is not None:
-            mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size
+            if isinstance(self.pretrained_config.intermediate_size,
+                          (list, tuple)):
+                # Per-layer MLP dimensions (e.g., Nemotron-NAS, variable MLP models)
+                mlp_hidden_size_per_layer = [
+                    intermediate_size // self.mapping.tp_size
+                    for intermediate_size in
+                    self.pretrained_config.intermediate_size
+                ]
+                model_config_cpp.mlp_hidden_size_per_layer = mlp_hidden_size_per_layer
+                # For LoRA compatibility, use the maximum MLP dimension
+                mlp_hidden_size = max(mlp_hidden_size_per_layer)
+            else:
+                # Uniform MLP dimensions across all layers
+                mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size
         else:
             # TODO: once tensorrt_llm._torch.AutoConfig is implemented, the following logic
             # should be moved to tensorrt_llm._torch.AutoConfig of the relevant modeling_xxx file
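For illustration only, a minimal standalone sketch of the per-layer sharding added above; the helper name shard_mlp_hidden_size and the example sizes are hypothetical and not part of this change.

from typing import List, Tuple, Union


def shard_mlp_hidden_size(intermediate_size: Union[int, List[int]],
                          tp_size: int) -> Tuple[int, List[int]]:
    """Map a uniform or per-layer intermediate_size to TP-sharded MLP hidden sizes."""
    if isinstance(intermediate_size, (list, tuple)):
        per_layer = [size // tp_size for size in intermediate_size]
        # LoRA sizing uses the largest per-layer dimension.
        return max(per_layer), per_layer
    uniform = intermediate_size // tp_size
    return uniform, [uniform]


# Example: heterogeneous Nemotron-NAS-style layers with tp_size=4.
# shard_mlp_hidden_size([16384, 8192, 16384], 4) -> (4096, [4096, 2048, 4096])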
13 changes: 12 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/_util.py
@@ -467,10 +467,21 @@ def create_py_executor_instance(
             # all layers have the same number of KV heads
             num_kv_attention_heads = num_kv_attention_heads_per_layer[0]

+        mlp_hidden_size_per_layer = model_binding_config.mlp_hidden_size_per_layer
+        if mlp_hidden_size_per_layer and max(mlp_hidden_size_per_layer) != min(
+                mlp_hidden_size_per_layer):
+            logger.warning(
+                "Per-layer MLP hidden sizes are not supported for LoRA; using the maximum MLP hidden size across layers"
+            )
+            mlp_hidden_size = max(mlp_hidden_size_per_layer)
+        else:
+            # all layers have the same MLP hidden size
+            mlp_hidden_size = mlp_hidden_size_per_layer[0]
+
         lora_modules = LoraModule.create_lora_modules(
             lora_module_names=lora_config.lora_target_modules,
             hidden_size=model_binding_config.hidden_size,
-            mlp_hidden_size=model_binding_config.mlp_hidden_size,
+            mlp_hidden_size=mlp_hidden_size,
             num_attention_heads=model_binding_config.num_heads,
             num_kv_attention_heads=num_kv_attention_heads,
             attention_head_size=model_binding_config.head_size,
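For illustration, the executor-side selection added above collapses the per-layer sizes into the single value LoRA module creation expects; a minimal sketch, assuming a hypothetical helper select_lora_mlp_hidden_size.

import logging
from typing import List

logger = logging.getLogger(__name__)


def select_lora_mlp_hidden_size(per_layer_sizes: List[int]) -> int:
    """Collapse per-layer MLP hidden sizes to the single size used for LoRA modules."""
    if per_layer_sizes and max(per_layer_sizes) != min(per_layer_sizes):
        # Heterogeneous layers: size LoRA buffers for the widest MLP.
        logger.warning("Per-layer MLP hidden sizes are not supported for LoRA; "
                       "using the maximum MLP hidden size across layers")
        return max(per_layer_sizes)
    # Uniform layers: every entry is the same value.
    return per_layer_sizes[0]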
74 changes: 74 additions & 0 deletions tests/integration/defs/examples/test_nemotron_nas.py
@@ -1,10 +1,16 @@
from pathlib import Path

import defs.ci_profiler
import pytest
from defs.common import convert_weights, venv_check_call, venv_mpi_check_call
from defs.conftest import get_device_memory, get_sm_version
from defs.trt_test_alternative import check_call

from tensorrt_llm import LLM
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.sampling_params import SamplingParams

# skip trt flow cases on post-Blackwell-Ultra
if get_sm_version() >= 103:
pytest.skip(
@@ -122,3 +128,71 @@ def test_nemotron_nas_summary_2gpu(nemotron_nas_example_root, llm_venv,
]

venv_mpi_check_call(llm_venv, mpi_cmd, summary_cmd)


@pytest.mark.skip_less_device(4)
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize("nemotron_nas_model_root", [
"Llama-3_3-Nemotron-Super-49B-v1",
],
indirect=True)
def test_nemotron_super_49b_real_lora_torch(nemotron_nas_example_root, llm_venv,
nemotron_nas_model_root,
llm_datasets_root, llm_rouge_root,
engine_dir, cmodel_dir):
"""Run Nemotron Super 49B with real LoRA adapters using LLM-API Torch backend."""

print("Testing Nemotron Super 49B with real LoRA adapters...")

lora_adapter_path = "/code/tensorrt_llm/llama-3.3-nemotron-super-49b-v1/llama-3.3-nemotron-super-49b-v1_vlora-1a2cb80-v2"
print(f"Using real LoRA from: {lora_adapter_path}")

defs.ci_profiler.start("test_nemotron_real_lora_torch")

lora_config = LoraConfig(
lora_dir=[lora_adapter_path],
max_lora_rank=32, # From adapter_config.json: "r": 32
max_loras=1,
max_cpu_loras=1,
)

with LLM(model=nemotron_nas_model_root,
lora_config=lora_config,
tensor_parallel_size=4,
dtype="bfloat16",
max_batch_size=2,
max_input_len=512,
max_seq_len=1024,
max_beam_width=1) as llm:

prompts = [
"What is the capital of France?",
"Explain quantum computing in simple terms."
]

sampling_params = SamplingParams(max_tokens=50,
temperature=0.7,
top_p=0.9)

lora_request = [LoRARequest("nemotron-lora", 0, lora_adapter_path)]

print("Running inference with real LoRA adapter...")
outputs = llm.generate(prompts,
sampling_params,
lora_request=lora_request)

for i, output in enumerate(outputs):
print(f"Prompt {i+1}: {prompts[i]}")
print(f"Response {i+1}: {output.outputs[0].text}")
print("-" * 50)

assert len(outputs) == 2
assert len(outputs[0].outputs) > 0
assert len(outputs[1].outputs) > 0
assert len(outputs[0].outputs[0].text) > 0
assert len(outputs[1].outputs[0].text) > 0

defs.ci_profiler.stop("test_nemotron_real_lora_torch")
print(
f"test_nemotron_real_lora_torch: {defs.ci_profiler.elapsed_time_in_sec('test_nemotron_real_lora_torch')} sec"
)
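As a usage note, the test hard-codes max_lora_rank=32 to match the adapter's configuration; a small sketch of deriving it from the PEFT-style adapter_config.json instead (the helper name is hypothetical and not part of this change).

import json
from pathlib import Path


def lora_rank_from_adapter(adapter_dir: str) -> int:
    """Read the LoRA rank ("r") from a PEFT-style adapter_config.json."""
    with open(Path(adapter_dir) / "adapter_config.json") as f:
        return json.load(f)["r"]


# Hypothetical usage inside the test:
# lora_config = LoraConfig(lora_dir=[lora_adapter_path],
#                          max_lora_rank=lora_rank_from_adapter(lora_adapter_path),
#                          max_loras=1,
#                          max_cpu_loras=1)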