Merged
58 changes: 51 additions & 7 deletions tensorrt_llm/_torch/model_config.py
@@ -297,6 +297,49 @@ def get_bindings_model_config(self,

num_heads = self.pretrained_config.num_attention_heads // (
self.mapping.tp_size * self.mapping.cp_size)

# Handle both uniform and per-layer KV heads
num_kv_heads_per_layer = getattr(self.pretrained_config,
'num_kv_heads_per_layer', None)
if num_kv_heads_per_layer is not None:
# For models with per-layer KV heads, like nemotron-nas
kv_heads_per_layer_raw = num_kv_heads_per_layer
use_per_layer_kv_heads = True
else:
# Check if num_key_value_heads is a list (per-layer) or scalar (uniform)
num_kv_heads_raw = getattr(self.pretrained_config,
'num_key_value_heads', None)

if num_kv_heads_raw is not None and isinstance(
num_kv_heads_raw, list):
# num_key_value_heads is a list - treat as per-layer KV heads
kv_heads_per_layer_raw = num_kv_heads_raw
use_per_layer_kv_heads = True
else:
# num_key_value_heads is scalar or None - treat as uniform KV heads
if num_kv_heads_raw is None:
# For uniform models, check: num_key_value_heads (standard) -> num_query_groups (NeMo) -> num_attention_heads
num_kv_heads_raw = getattr(
self.pretrained_config, 'num_query_groups',
self.pretrained_config.num_attention_heads)

num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size *
self.mapping.cp_size)
use_per_layer_kv_heads = False

if use_per_layer_kv_heads:
# TRT-LLM LoRA requires uniform KV heads across layers
if self.lora_config is not None and len(
set(kv_heads_per_layer_raw)) > 1:
raise ValueError(
f"TRT-LLM LoRA requires uniform KV heads across layers, "
f"got: {kv_heads_per_layer_raw}")
# Apply TP/CP scaling to each layer
num_kv_heads_per_layer = [
kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
for kv_heads in kv_heads_per_layer_raw
]

hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size

model_config_cpp = ModelConfigCpp(
@@ -317,11 +360,10 @@
else:
model_config_cpp.tokens_per_block = tokens_per_block

# For kv cache size calculation: set num_kv_heads
num_kv_heads = getattr(
self.pretrained_config, "num_key_value_heads",
num_heads) // (self.mapping.tp_size * self.mapping.cp_size)
model_config_cpp.set_num_kv_heads(num_kv_heads)
if use_per_layer_kv_heads:
model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
else:
model_config_cpp.set_num_kv_heads(num_kv_heads)

mlp_hidden_size = None
if self.pretrained_config.intermediate_size is not None:
@@ -371,8 +413,10 @@ def _infer_nemotron_ffn_mult(self):
# Nemotron-NAS has variable ffn_mult for each layer, we need to find the maximum
# so that we don't set a too small mlp_hidden_size. This solution leads to a memory
# consumption that is higher than required.
biggest_ffn_mult = max(
[x.ffn.ffn_mult for x in self.pretrained_config.block_configs])
biggest_ffn_mult = max([
(x.ffn.ffn_mult if x.ffn.ffn_mult is not None else 0)
for x in self.pretrained_config.block_configs
])

from tensorrt_llm._torch.models.modeling_nemotron_nas import \
_ffn_mult_to_intermediate_size
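A condensed, standalone sketch of the KV-head resolution order the new get_bindings_model_config logic follows (the helper name resolve_kv_heads and the SimpleNamespace config below are illustrative, not TensorRT-LLM API): per-layer counts come from num_kv_heads_per_layer or a list-valued num_key_value_heads, the uniform case falls back through num_key_value_heads -> num_query_groups -> num_attention_heads, and TP*CP sharding is applied either way.

from types import SimpleNamespace

def resolve_kv_heads(cfg, tp_size, cp_size, has_lora=False):
    """Return KV heads after TP*CP sharding: a per-layer list or a single int."""
    world = tp_size * cp_size
    per_layer = getattr(cfg, "num_kv_heads_per_layer", None)
    if per_layer is None:
        raw = getattr(cfg, "num_key_value_heads", None)
        if isinstance(raw, list):
            per_layer = raw  # list-valued num_key_value_heads: per-layer KV heads
        else:
            if raw is None:
                raw = getattr(cfg, "num_query_groups", cfg.num_attention_heads)
            return raw // world  # uniform KV heads: a single scalar
    if has_lora and len(set(per_layer)) > 1:
        raise ValueError("TRT-LLM LoRA requires uniform KV heads across layers")
    return [heads // world for heads in per_layer]

cfg = SimpleNamespace(num_attention_heads=32, num_key_value_heads=[8, 8, 4, 4])
assert resolve_kv_heads(cfg, tp_size=2, cp_size=1) == [4, 4, 2, 2]

The _infer_nemotron_ffn_mult change is a guard of the same kind: block_configs entries whose ffn_mult is None (blocks without a regular FFN in Nemotron-NAS) are treated as 0, so max() no longer raises a TypeError when comparing None against floats.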
12 changes: 7 additions & 5 deletions tensorrt_llm/_torch/models/modeling_llama.py
@@ -703,11 +703,13 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
model_config,
'lora_config') and model_config.lora_config is not None and len(
model_config.lora_config.lora_dir) == 1:
lora_loader = HfLoraLoader(model_config.lora_config.lora_dir)
if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None:
vocab_size = lora_loader.vocab_size
weight = lora_loader.embed_tokens
self.has_custom_embed_tokens = True
# Only check for custom vocab in HF LoRA, not NeMo
if model_config.lora_config.lora_ckpt_source == "hf":
lora_loader = HfLoraLoader(model_config.lora_config.lora_dir)
if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None:
vocab_size = lora_loader.vocab_size
weight = lora_loader.embed_tokens
self.has_custom_embed_tokens = True

if self.model_config.mapping.enable_attention_dp:
self.embed_tokens = Embedding(
12 changes: 7 additions & 5 deletions tensorrt_llm/_torch/models/modeling_nemotron_nas.py
@@ -192,11 +192,13 @@ def __init__(self, model_config):
model_config,
'lora_config') and model_config.lora_config is not None and len(
model_config.lora_config.lora_dir) == 1:
lora_loader = HfLoraLoader(model_config.lora_config.lora_dir)
if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None:
vocab_size = lora_loader.vocab_size
weight = lora_loader.embed_tokens
self.has_custom_embed_tokens = True
# Only check for custom vocab in HF LoRA, not NeMo
if model_config.lora_config.lora_ckpt_source == "hf":
lora_loader = HfLoraLoader(model_config.lora_config.lora_dir)
if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None:
vocab_size = lora_loader.vocab_size
weight = lora_loader.embed_tokens
self.has_custom_embed_tokens = True

self.embed_tokens = Embedding(
vocab_size,
12 changes: 7 additions & 5 deletions tensorrt_llm/_torch/models/modeling_utils.py
@@ -364,11 +364,13 @@ def __init__(self, model: TModel, *, config: ModelConfig[TConfig],
if (hasattr(config, 'lora_config')
and config.lora_config is not None
and len(config.lora_config.lora_dir) == 1):
lora_loader = HfLoraLoader(config.lora_config.lora_dir)
if lora_loader.lm_head is not None and lora_loader.vocab_size != 0:
weight = lora_loader.lm_head
self.has_custom_lm_head = True
vocab_size = lora_loader.vocab_size
# Only check for custom lm_head in HF LoRA, not NeMo
if config.lora_config.lora_ckpt_source == "hf":
lora_loader = HfLoraLoader(config.lora_config.lora_dir)
if lora_loader.lm_head is not None and lora_loader.vocab_size != 0:
weight = lora_loader.lm_head
self.has_custom_lm_head = True
vocab_size = lora_loader.vocab_size

self.lm_head = LMHead(
vocab_size,
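The same gate appears in modeling_llama.py, modeling_nemotron_nas.py, and modeling_utils.py above: only a single-directory HF LoRA checkpoint is probed for custom embed_tokens / lm_head weights, while NeMo LoRA checkpoints skip the HfLoraLoader probe entirely. A minimal sketch of the shared check (the helper name and the import location are assumptions for illustration):

from tensorrt_llm.lora_manager import HfLoraLoader  # import path assumed

def custom_vocab_from_lora(lora_config, attr="embed_tokens"):
    """Return (vocab_size, weight) only for a single-dir HF LoRA checkpoint."""
    if lora_config is None or len(lora_config.lora_dir) != 1:
        return None
    if lora_config.lora_ckpt_source != "hf":
        return None  # NeMo adapters carry no custom vocab/embedding weights
    loader = HfLoraLoader(lora_config.lora_dir)
    weight = getattr(loader, attr, None)  # attr is "embed_tokens" or "lm_head"
    if loader.vocab_size != 0 and weight is not None:
        return loader.vocab_size, weight
    return None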
20 changes: 17 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -14,7 +14,7 @@
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import (LoraConfig,
get_default_trtllm_modules_to_hf_modules,
load_torch_hf_lora)
load_torch_lora)
from tensorrt_llm.mapping import Mapping

from ..model_config import ModelConfig
@@ -437,7 +437,8 @@ def create_py_executor_instance(
from tensorrt_llm.bindings import LoraModule

if len(lora_config.lora_dir) == 1:
load_torch_hf_lora(lora_config)
# Route to appropriate loader based on checkpoint source
load_torch_lora(lora_config)
else:
assert len(lora_config.lora_target_modules
) >= 1, "Expecting at least one lora target module"
@@ -450,12 +451,25 @@

num_experts = _try_infer_num_experts(model_engine.model.model_config)

num_attn_layers = model_binding_config.num_attention_layers()
per_layer_kv_heads = [
model_binding_config.num_kv_heads(i) for i in range(num_attn_layers)
]
num_kv_attention_heads = max(per_layer_kv_heads)
if len(set(per_layer_kv_heads)) > 1:
# NOTE: This code-path is currently untested and not validated. Can fail!
# This support is tracked in TRTLLM-6561
logger.warning(
f"Non-uniform KV heads per layer detected, using max ({num_kv_attention_heads}) for LoRA. "
"This code-path is currently untested and not validated. May fail!"
)

lora_modules = LoraModule.create_lora_modules(
lora_module_names=lora_config.lora_target_modules,
hidden_size=model_binding_config.hidden_size,
mlp_hidden_size=model_binding_config.mlp_hidden_size,
num_attention_heads=model_binding_config.num_heads,
num_kv_attention_heads=model_binding_config.num_heads,
num_kv_attention_heads=num_kv_attention_heads,
attention_head_size=model_binding_config.head_size,
tp_size=mapping.tp_size,
num_experts=num_experts)
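As a concrete example of the new sizing logic: if the bindings report per-layer KV heads of [8, 8, 4] after sharding, the code above sizes the LoRA modules with 8 and logs the warning about the untested non-uniform path (tracked in TRTLLM-6561). A condensed, illustrative version of that selection:

def pick_lora_kv_heads(per_layer_kv_heads):
    """Use the maximum per-layer KV head count for LoRA module sizing."""
    num_kv = max(per_layer_kv_heads)
    if len(set(per_layer_kv_heads)) > 1:
        print(f"Non-uniform KV heads per layer, using max ({num_kv}) for LoRA.")
    return num_kv

assert pick_lora_kv_heads([8, 8, 4]) == 8  # non-uniform: warns, uses the max
assert pick_lora_kv_heads([4, 4, 4]) == 4  # uniform: no warning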
9 changes: 9 additions & 0 deletions tensorrt_llm/executor/request.py
@@ -25,10 +25,15 @@ class LoRARequest:
lora_name: str
lora_int_id: int
lora_path: str = ""
lora_ckpt_source: str = "hf"

def __post_init__(self):
if self.lora_path is not None and not os.path.exists(self.lora_path):
raise ValueError(f"lora_path ({self.lora_path}) does not exist.")
if self.lora_ckpt_source not in ["hf", "nemo"]:
raise ValueError(
f"lora_ckpt_source must be 'hf' or 'nemo', got '{self.lora_ckpt_source}'"
)

@property
def adapter_id(self):
@@ -42,6 +47,10 @@ def name(self):
def path(self):
return self.lora_path

@property
def ckpt_source(self):
return self.lora_ckpt_source


@dataclass(slots=True)
class PromptAdapterRequest:
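With the new field, a caller tags the checkpoint format directly on the request; anything other than "hf" or "nemo" is rejected in __post_init__, as is a lora_path that does not exist on disk. A usage sketch (the adapter paths are placeholders and must point to real checkpoints, otherwise construction raises ValueError):

from tensorrt_llm.executor.request import LoRARequest

# HF-format adapter: lora_ckpt_source defaults to "hf"
hf_request = LoRARequest(lora_name="my-hf-adapter", lora_int_id=1,
                         lora_path="/path/to/hf_adapter")

# NeMo-format adapter: select the NeMo loading path explicitly
nemo_request = LoRARequest(lora_name="my-nemo-adapter", lora_int_id=2,
                           lora_path="/path/to/adapter.nemo",
                           lora_ckpt_source="nemo")

# Any other source string is rejected at construction time:
# LoRARequest(lora_name="bad", lora_int_id=3, lora_path="/path/to/x",
#             lora_ckpt_source="gguf")  # ValueError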
3 changes: 2 additions & 1 deletion tensorrt_llm/executor/worker.py
@@ -359,7 +359,8 @@ def _load_lora_adapter(self, lora_request: LoRARequest) -> bool:
model_config=self._runtime_model_config if
self._runtime_model_config is not None else self._lora_model_config,
runtime_mapping=None,
uids=[adapter_id])
uids=[adapter_id],
ckpt_source=lora_request.ckpt_source)
return adapter_id in newly_loaded_uids

def _load_prompt_adapter(self,