58 changes: 51 additions & 7 deletions tensorrt_llm/_torch/model_config.py
@@ -297,6 +297,49 @@ def get_bindings_model_config(self,

num_heads = self.pretrained_config.num_attention_heads // (
self.mapping.tp_size * self.mapping.cp_size)

# Handle both uniform and per-layer KV heads
num_kv_heads_per_layer = getattr(self.pretrained_config,
'num_kv_heads_per_layer', None)
if num_kv_heads_per_layer is not None:
# For models with per-layer KV heads, like nemotron-nas
kv_heads_per_layer_raw = num_kv_heads_per_layer
use_per_layer_kv_heads = True
else:
# Check if num_key_value_heads is a list (per-layer) or scalar (uniform)
num_kv_heads_raw = getattr(self.pretrained_config,
'num_key_value_heads', None)

if num_kv_heads_raw is not None and isinstance(
num_kv_heads_raw, list):
# num_key_value_heads is a list - treat as per-layer KV heads
kv_heads_per_layer_raw = num_kv_heads_raw
use_per_layer_kv_heads = True
else:
# num_key_value_heads is scalar or None - treat as uniform KV heads
if num_kv_heads_raw is None:
# For uniform models, check: num_key_value_heads (standard) -> num_query_groups (NeMo) -> num_attention_heads
num_kv_heads_raw = getattr(
self.pretrained_config, 'num_query_groups',
self.pretrained_config.num_attention_heads)

num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size *
self.mapping.cp_size)
use_per_layer_kv_heads = False

if use_per_layer_kv_heads:
# TRT-LLM LoRA requires uniform KV heads across layers
if self.lora_config is not None and len(
set(kv_heads_per_layer_raw)) > 1:
raise ValueError(
f"TRT-LLM LoRA requires uniform KV heads across layers, "
f"got: {kv_heads_per_layer_raw}")
# Apply TP/CP scaling to each layer
num_kv_heads_per_layer = [
kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
for kv_heads in kv_heads_per_layer_raw
]

hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size

model_config_cpp = ModelConfigCpp(
@@ -317,11 +360,10 @@ def get_bindings_model_config(self,
else:
model_config_cpp.tokens_per_block = tokens_per_block

# For kv cache size calculation: set num_kv_heads
num_kv_heads = getattr(
self.pretrained_config, "num_key_value_heads",
num_heads) // (self.mapping.tp_size * self.mapping.cp_size)
model_config_cpp.set_num_kv_heads(num_kv_heads)
if use_per_layer_kv_heads:
model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
else:
model_config_cpp.set_num_kv_heads(num_kv_heads)

mlp_hidden_size = None
if self.pretrained_config.intermediate_size is not None:
@@ -371,8 +413,10 @@ def _infer_nemotron_ffn_mult(self):
        # Nemotron-NAS has a variable ffn_mult for each layer, so we take the maximum
        # to avoid setting an mlp_hidden_size that is too small. This leads to higher
        # memory consumption than strictly required.
biggest_ffn_mult = max(
[x.ffn.ffn_mult for x in self.pretrained_config.block_configs])
biggest_ffn_mult = max([
(x.ffn.ffn_mult if x.ffn.ffn_mult is not None else 0)
for x in self.pretrained_config.block_configs
])

from tensorrt_llm._torch.models.modeling_nemotron_nas import \
_ffn_mult_to_intermediate_size
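
For readers skimming the diff: the KV-head handling added above resolves the head count in this order: num_kv_heads_per_layer (Nemotron-NAS), then a list-valued num_key_value_heads, then a scalar num_key_value_heads, then num_query_groups (NeMo), and finally num_attention_heads; every value is divided by tp_size * cp_size. A minimal standalone sketch of that logic, with a hypothetical helper and SimpleNamespace standing in for the pretrained config:

from types import SimpleNamespace

def resolve_kv_heads(cfg, tp_size, cp_size):
    """Return (per_layer_kv_heads, uniform_kv_heads); exactly one of the two is None."""
    world = tp_size * cp_size
    per_layer = getattr(cfg, "num_kv_heads_per_layer", None)
    if per_layer is None:
        raw = getattr(cfg, "num_key_value_heads", None)
        if isinstance(raw, list):
            per_layer = raw
        else:
            if raw is None:
                raw = getattr(cfg, "num_query_groups", cfg.num_attention_heads)
            return None, raw // world
    return [h // world for h in per_layer], None

# Uniform GQA model, TP=2
print(resolve_kv_heads(SimpleNamespace(num_attention_heads=32, num_key_value_heads=8),
                       tp_size=2, cp_size=1))  # -> (None, 4)

# Nemotron-NAS-style per-layer KV heads, TP=2
print(resolve_kv_heads(
    SimpleNamespace(num_attention_heads=32, num_kv_heads_per_layer=[8, 8, 4, 4]),
    tp_size=2, cp_size=1))  # -> ([4, 4, 2, 2], None)
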
12 changes: 7 additions & 5 deletions tensorrt_llm/_torch/models/modeling_llama.py
@@ -703,11 +703,13 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
model_config,
'lora_config') and model_config.lora_config is not None and len(
model_config.lora_config.lora_dir) == 1:
lora_loader = HfLoraLoader(model_config.lora_config.lora_dir)
if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None:
vocab_size = lora_loader.vocab_size
weight = lora_loader.embed_tokens
self.has_custom_embed_tokens = True
# Only check for custom vocab in HF LoRA, not NeMo
if model_config.lora_config.lora_ckpt_source == "hf":
lora_loader = HfLoraLoader(model_config.lora_config.lora_dir)
if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None:
vocab_size = lora_loader.vocab_size
weight = lora_loader.embed_tokens
self.has_custom_embed_tokens = True

if self.model_config.mapping.enable_attention_dp:
self.embed_tokens = Embedding(
12 changes: 7 additions & 5 deletions tensorrt_llm/_torch/models/modeling_nemotron_nas.py
@@ -192,11 +192,13 @@ def __init__(self, model_config):
model_config,
'lora_config') and model_config.lora_config is not None and len(
model_config.lora_config.lora_dir) == 1:
lora_loader = HfLoraLoader(model_config.lora_config.lora_dir)
if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None:
vocab_size = lora_loader.vocab_size
weight = lora_loader.embed_tokens
self.has_custom_embed_tokens = True
# Only check for custom vocab in HF LoRA, not NeMo
if model_config.lora_config.lora_ckpt_source == "hf":
lora_loader = HfLoraLoader(model_config.lora_config.lora_dir)
if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None:
vocab_size = lora_loader.vocab_size
weight = lora_loader.embed_tokens
self.has_custom_embed_tokens = True

self.embed_tokens = Embedding(
vocab_size,
12 changes: 7 additions & 5 deletions tensorrt_llm/_torch/models/modeling_utils.py
@@ -364,11 +364,13 @@ def __init__(self, model: TModel, *, config: ModelConfig[TConfig],
if (hasattr(config, 'lora_config')
and config.lora_config is not None
and len(config.lora_config.lora_dir) == 1):
lora_loader = HfLoraLoader(config.lora_config.lora_dir)
if lora_loader.lm_head is not None and lora_loader.vocab_size != 0:
weight = lora_loader.lm_head
self.has_custom_lm_head = True
vocab_size = lora_loader.vocab_size
# Only check for custom lm_head in HF LoRA, not NeMo
if config.lora_config.lora_ckpt_source == "hf":
lora_loader = HfLoraLoader(config.lora_config.lora_dir)
if lora_loader.lm_head is not None and lora_loader.vocab_size != 0:
weight = lora_loader.lm_head
self.has_custom_lm_head = True
vocab_size = lora_loader.vocab_size

self.lm_head = LMHead(
vocab_size,
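
The three model files above (modeling_llama.py, modeling_nemotron_nas.py, modeling_utils.py) add the same guard: HfLoraLoader is only consulted when lora_ckpt_source == "hf", so NeMo adapters no longer trigger an attempt to read embed_tokens / lm_head weights from an HF checkpoint layout. A condensed sketch of the shared pattern, with hypothetical helper names rather than the actual class code:

def maybe_custom_vocab_weight(lora_config, load_hf_weight):
    """Only a single-directory HF LoRA checkpoint may override embed_tokens / lm_head."""
    if lora_config is None or len(lora_config.lora_dir) != 1:
        return None
    if lora_config.lora_ckpt_source != "hf":
        # NeMo checkpoints do not use the HF layout, so skip HF-specific loading.
        return None
    return load_hf_weight(lora_config.lora_dir[0])
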
20 changes: 17 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -14,7 +14,7 @@
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import (LoraConfig,
get_default_trtllm_modules_to_hf_modules,
load_torch_hf_lora)
load_torch_lora)
from tensorrt_llm.mapping import Mapping

from ..model_config import ModelConfig
@@ -437,7 +437,8 @@ def create_py_executor_instance(
from tensorrt_llm.bindings import LoraModule

if len(lora_config.lora_dir) == 1:
load_torch_hf_lora(lora_config)
# Route to appropriate loader based on checkpoint source
load_torch_lora(lora_config)
else:
assert len(lora_config.lora_target_modules
) >= 1, "Expecting at least one lora target module"
@@ -450,12 +451,25 @@

num_experts = _try_infer_num_experts(model_engine.model.model_config)

num_attn_layers = model_binding_config.num_attention_layers()
per_layer_kv_heads = [
model_binding_config.num_kv_heads(i) for i in range(num_attn_layers)
]
num_kv_attention_heads = max(per_layer_kv_heads)
if len(set(per_layer_kv_heads)) > 1:
# NOTE: This code-path is currently untested and not validated. Can fail!
# This support is tracked in TRTLLM-6561
logger.warning(
f"Non-uniform KV heads per layer detected, using max ({num_kv_attention_heads}) for LoRA. "
"This code-path is currently untested and not validated. May fail!"
)

lora_modules = LoraModule.create_lora_modules(
lora_module_names=lora_config.lora_target_modules,
hidden_size=model_binding_config.hidden_size,
mlp_hidden_size=model_binding_config.mlp_hidden_size,
num_attention_heads=model_binding_config.num_heads,
num_kv_attention_heads=model_binding_config.num_heads,
num_kv_attention_heads=num_kv_attention_heads,
attention_head_size=model_binding_config.head_size,
tp_size=mapping.tp_size,
num_experts=num_experts)
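
When layers disagree on the KV-head count, the LoRA module setup above falls back to the per-layer maximum and logs a warning (support tracked in TRTLLM-6561). A toy worked example with assumed per-layer values, already divided by tp_size * cp_size:

# Hypothetical per-layer KV heads for a Nemotron-NAS-like model
per_layer_kv_heads = [8, 8, 8, 4, 4, 2]

num_kv_attention_heads = max(per_layer_kv_heads)  # 8 is what gets passed to create_lora_modules
if len(set(per_layer_kv_heads)) > 1:
    print(f"Non-uniform KV heads per layer, using max ({num_kv_attention_heads}) for LoRA")
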
9 changes: 9 additions & 0 deletions tensorrt_llm/executor/request.py
@@ -25,10 +25,15 @@ class LoRARequest:
lora_name: str
lora_int_id: int
lora_path: str = ""
lora_ckpt_source: str = "hf"

def __post_init__(self):
if self.lora_path is not None and not os.path.exists(self.lora_path):
raise ValueError(f"lora_path ({self.lora_path}) does not exist.")
if self.lora_ckpt_source not in ["hf", "nemo"]:
raise ValueError(
f"lora_ckpt_source must be 'hf' or 'nemo', got '{self.lora_ckpt_source}'"
)

@property
def adapter_id(self):
@@ -42,6 +47,10 @@ def name(self):
def path(self):
return self.lora_path

@property
def ckpt_source(self):
return self.lora_ckpt_source


@dataclass(slots=True)
class PromptAdapterRequest:
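
The new lora_ckpt_source field lets a request state its checkpoint format explicitly; anything other than "hf" or "nemo" is rejected in __post_init__, as is a non-existent lora_path. A small usage sketch (the temporary directory stands in for a real adapter checkpoint, and the import path is taken from the file shown in this diff):

import tempfile

from tensorrt_llm.executor.request import LoRARequest

with tempfile.TemporaryDirectory() as adapter_dir:
    # HF adapter (default source)
    hf_req = LoRARequest(lora_name="my-hf-adapter", lora_int_id=1,
                         lora_path=adapter_dir)
    assert hf_req.ckpt_source == "hf"

    # NeMo adapter: opt in explicitly
    nemo_req = LoRARequest(lora_name="my-nemo-adapter", lora_int_id=2,
                           lora_path=adapter_dir, lora_ckpt_source="nemo")
    assert nemo_req.ckpt_source == "nemo"

    # Any other source string raises
    try:
        LoRARequest(lora_name="bad", lora_int_id=3, lora_path=adapter_dir,
                    lora_ckpt_source="pt")
    except ValueError as err:
        print(err)  # lora_ckpt_source must be 'hf' or 'nemo', got 'pt'
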
3 changes: 2 additions & 1 deletion tensorrt_llm/executor/worker.py
@@ -359,7 +359,8 @@ def _load_lora_adapter(self, lora_request: LoRARequest) -> bool:
model_config=self._runtime_model_config if
self._runtime_model_config is not None else self._lora_model_config,
runtime_mapping=None,
uids=[adapter_id])
uids=[adapter_id],
ckpt_source=lora_request.ckpt_source)
return adapter_id in newly_loaded_uids

def _load_prompt_adapter(self,