diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index 3de3edd3a9b..3d0175a3c23 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -297,6 +297,49 @@ def get_bindings_model_config(self, num_heads = self.pretrained_config.num_attention_heads // ( self.mapping.tp_size * self.mapping.cp_size) + + # Handle both uniform and per-layer KV heads + num_kv_heads_per_layer = getattr(self.pretrained_config, + 'num_kv_heads_per_layer', None) + if num_kv_heads_per_layer is not None: + # For models with per-layer KV heads, like nemotron-nas + kv_heads_per_layer_raw = num_kv_heads_per_layer + use_per_layer_kv_heads = True + else: + # Check if num_key_value_heads is a list (per-layer) or scalar (uniform) + num_kv_heads_raw = getattr(self.pretrained_config, + 'num_key_value_heads', None) + + if num_kv_heads_raw is not None and isinstance( + num_kv_heads_raw, list): + # num_key_value_heads is a list - treat as per-layer KV heads + kv_heads_per_layer_raw = num_kv_heads_raw + use_per_layer_kv_heads = True + else: + # num_key_value_heads is scalar or None - treat as uniform KV heads + if num_kv_heads_raw is None: + # For uniform models, check: num_key_value_heads (standard) -> num_query_groups (NeMo) -> num_attention_heads + num_kv_heads_raw = getattr( + self.pretrained_config, 'num_query_groups', + self.pretrained_config.num_attention_heads) + + num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size * + self.mapping.cp_size) + use_per_layer_kv_heads = False + + if use_per_layer_kv_heads: + # TRT-LLM LoRA requires uniform KV heads across layers + if self.lora_config is not None and len( + set(kv_heads_per_layer_raw)) > 1: + raise ValueError( + f"TRT-LLM LoRA requires uniform KV heads across layers, " + f"got: {kv_heads_per_layer_raw}") + # Apply TP/CP scaling to each layer + num_kv_heads_per_layer = [ + kv_heads // (self.mapping.tp_size * self.mapping.cp_size) + for kv_heads in kv_heads_per_layer_raw + ] + hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size model_config_cpp = ModelConfigCpp( @@ -317,11 +360,10 @@ def get_bindings_model_config(self, else: model_config_cpp.tokens_per_block = tokens_per_block - # For kv cache size calculation: set num_kv_heads - num_kv_heads = getattr( - self.pretrained_config, "num_key_value_heads", - num_heads) // (self.mapping.tp_size * self.mapping.cp_size) - model_config_cpp.set_num_kv_heads(num_kv_heads) + if use_per_layer_kv_heads: + model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer + else: + model_config_cpp.set_num_kv_heads(num_kv_heads) mlp_hidden_size = None if self.pretrained_config.intermediate_size is not None: @@ -371,8 +413,10 @@ def _infer_nemotron_ffn_mult(self): # Nemotron-NAS has variable ffn_mult for each layer, we need to find the maximum # so that we don't set a too small mlp_hidden_size. This solution leads to a memory # consumption that is higher than required. 
- biggest_ffn_mult = max( - [x.ffn.ffn_mult for x in self.pretrained_config.block_configs]) + biggest_ffn_mult = max([ + (x.ffn.ffn_mult if x.ffn.ffn_mult is not None else 0) + for x in self.pretrained_config.block_configs + ]) from tensorrt_llm._torch.models.modeling_nemotron_nas import \ _ffn_mult_to_intermediate_size diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index aeecff7c3e0..33dddfc784c 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -703,11 +703,13 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]): model_config, 'lora_config') and model_config.lora_config is not None and len( model_config.lora_config.lora_dir) == 1: - lora_loader = HfLoraLoader(model_config.lora_config.lora_dir) - if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None: - vocab_size = lora_loader.vocab_size - weight = lora_loader.embed_tokens - self.has_custom_embed_tokens = True + # Only check for custom vocab in HF LoRA, not NeMo + if model_config.lora_config.lora_ckpt_source == "hf": + lora_loader = HfLoraLoader(model_config.lora_config.lora_dir) + if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None: + vocab_size = lora_loader.vocab_size + weight = lora_loader.embed_tokens + self.has_custom_embed_tokens = True if self.model_config.mapping.enable_attention_dp: self.embed_tokens = Embedding( diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_nas.py b/tensorrt_llm/_torch/models/modeling_nemotron_nas.py index 146d13f16f1..3ab1cdb37ca 100644 --- a/tensorrt_llm/_torch/models/modeling_nemotron_nas.py +++ b/tensorrt_llm/_torch/models/modeling_nemotron_nas.py @@ -192,11 +192,13 @@ def __init__(self, model_config): model_config, 'lora_config') and model_config.lora_config is not None and len( model_config.lora_config.lora_dir) == 1: - lora_loader = HfLoraLoader(model_config.lora_config.lora_dir) - if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None: - vocab_size = lora_loader.vocab_size - weight = lora_loader.embed_tokens - self.has_custom_embed_tokens = True + # Only check for custom vocab in HF LoRA, not NeMo + if model_config.lora_config.lora_ckpt_source == "hf": + lora_loader = HfLoraLoader(model_config.lora_config.lora_dir) + if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None: + vocab_size = lora_loader.vocab_size + weight = lora_loader.embed_tokens + self.has_custom_embed_tokens = True self.embed_tokens = Embedding( vocab_size, diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index c751bdcbb01..5b28d379206 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -364,11 +364,13 @@ def __init__(self, model: TModel, *, config: ModelConfig[TConfig], if (hasattr(config, 'lora_config') and config.lora_config is not None and len(config.lora_config.lora_dir) == 1): - lora_loader = HfLoraLoader(config.lora_config.lora_dir) - if lora_loader.lm_head is not None and lora_loader.vocab_size != 0: - weight = lora_loader.lm_head - self.has_custom_lm_head = True - vocab_size = lora_loader.vocab_size + # Only check for custom lm_head in HF LoRA, not NeMo + if config.lora_config.lora_ckpt_source == "hf": + lora_loader = HfLoraLoader(config.lora_config.lora_dir) + if lora_loader.lm_head is not None and lora_loader.vocab_size != 0: + weight = lora_loader.lm_head + self.has_custom_lm_head = True + vocab_size = 
lora_loader.vocab_size self.lm_head = LMHead( vocab_size, diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index adebecc1633..f5415b56de8 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -14,7 +14,7 @@ from tensorrt_llm.logger import logger from tensorrt_llm.lora_manager import (LoraConfig, get_default_trtllm_modules_to_hf_modules, - load_torch_hf_lora) + load_torch_lora) from tensorrt_llm.mapping import Mapping from ..model_config import ModelConfig @@ -437,7 +437,8 @@ def create_py_executor_instance( from tensorrt_llm.bindings import LoraModule if len(lora_config.lora_dir) == 1: - load_torch_hf_lora(lora_config) + # Route to appropriate loader based on checkpoint source + load_torch_lora(lora_config) else: assert len(lora_config.lora_target_modules ) >= 1, "Expecting at least one lora target module" @@ -450,12 +451,25 @@ def create_py_executor_instance( num_experts = _try_infer_num_experts(model_engine.model.model_config) + num_attn_layers = model_binding_config.num_attention_layers() + per_layer_kv_heads = [ + model_binding_config.num_kv_heads(i) for i in range(num_attn_layers) + ] + num_kv_attention_heads = max(per_layer_kv_heads) + if len(set(per_layer_kv_heads)) > 1: + # NOTE: This code-path is currently untested and not validated. Can fail! + # This support is tracked in TRTLLM-6561 + logger.warning( + f"Non-uniform KV heads per layer detected, using max ({num_kv_attention_heads}) for LoRA. " + "This code-path is currently untested and not validated. May fail!" + ) + lora_modules = LoraModule.create_lora_modules( lora_module_names=lora_config.lora_target_modules, hidden_size=model_binding_config.hidden_size, mlp_hidden_size=model_binding_config.mlp_hidden_size, num_attention_heads=model_binding_config.num_heads, - num_kv_attention_heads=model_binding_config.num_heads, + num_kv_attention_heads=num_kv_attention_heads, attention_head_size=model_binding_config.head_size, tp_size=mapping.tp_size, num_experts=num_experts) diff --git a/tensorrt_llm/executor/request.py b/tensorrt_llm/executor/request.py index 886831d0723..52e3d8773e1 100644 --- a/tensorrt_llm/executor/request.py +++ b/tensorrt_llm/executor/request.py @@ -25,10 +25,15 @@ class LoRARequest: lora_name: str lora_int_id: int lora_path: str = "" + lora_ckpt_source: str = "hf" def __post_init__(self): if self.lora_path is not None and not os.path.exists(self.lora_path): raise ValueError(f"lora_path ({self.lora_path}) does not exist.") + if self.lora_ckpt_source not in ["hf", "nemo"]: + raise ValueError( + f"lora_ckpt_source must be 'hf' or 'nemo', got '{self.lora_ckpt_source}'" + ) @property def adapter_id(self): @@ -42,6 +47,10 @@ def name(self): def path(self): return self.lora_path + @property + def ckpt_source(self): + return self.lora_ckpt_source + @dataclass(slots=True) class PromptAdapterRequest: diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index aa793d30ea6..6ebd7adc03d 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -359,7 +359,8 @@ def _load_lora_adapter(self, lora_request: LoRARequest) -> bool: model_config=self._runtime_model_config if self._runtime_model_config is not None else self._lora_model_config, runtime_mapping=None, - uids=[adapter_id]) + uids=[adapter_id], + ckpt_source=lora_request.ckpt_source) return adapter_id in newly_loaded_uids def _load_prompt_adapter(self, diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 
3f87286024b..9f42fdad20d 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -4,8 +4,9 @@ import tarfile from collections import defaultdict from dataclasses import dataclass, field +from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -22,8 +23,21 @@ from .runtime import ModelConfig -def get_all_nemo_lora_weights(lora_weights): - layer_weights = defaultdict(dict) +def get_all_nemo_lora_weights( + lora_weights: Dict[str, torch.Tensor], +) -> Dict[int, Dict[str, torch.Tensor]]: + """Extract and organize NeMo LoRA weights by layer and direction. + + Args: + lora_weights: Dictionary mapping weight keys to tensors from NeMo checkpoint + + Returns: + Dictionary mapping layer_idx -> {direction -> tensor} where direction is 'in' or 'out' + + Raises: + KeyError: If unsupported keys are found or layer extraction fails + """ + layer_weights: Dict[int, Dict[str, torch.Tensor]] = defaultdict(dict) adapter_key = "self_attention.adapter_layer.lora_kqv_adapter" layer_pattern = re.compile(r".*\.layers\.(\d+)\..*") for key, weights in lora_weights.items(): @@ -52,7 +66,28 @@ def get_all_nemo_lora_weights(lora_weights): ) -def iterate_hf_lora(iter_fn, lora_weights, hf_modules, component=None): +def iterate_hf_lora( + iter_fn, + lora_weights: Dict[str, torch.Tensor], + hf_modules: Set[str], + component: Optional[str] = None, +): + """Iterate over HuggingFace LoRA weights and call iterator function for each weight. + + Args: + iter_fn: Function to call for each weight with signature + (layer_idx, hf_module, expert_idx, inout_or_mag, weights) + lora_weights: Dictionary mapping weight keys to tensors from HF checkpoint + hf_modules: Set of supported HF module names + component: Optional component name to filter by (e.g., 'decoder') + + Returns: + Nested dictionary structure organizing the weights + + Raises: + KeyError: If unsupported keys are found + AssertionError: If HF module is not in supported list + """ all_weights = defaultdict(lambda: defaultdict(dict)) pattern = HF_LORA_PATTERN for key, weights in lora_weights.items(): @@ -96,7 +131,20 @@ def iterate_hf_lora(iter_fn, lora_weights, hf_modules, component=None): return all_weights -def get_all_hf_lora_weights(lora_weights, hf_modules, component=None): +def get_all_hf_lora_weights( + lora_weights: Dict[str, torch.Tensor], hf_modules: Set[str], component: Optional[str] = None +): + """Extract and organize all HuggingFace LoRA weights by layer and module. + + Args: + lora_weights: Dictionary mapping weight keys to tensors from HF checkpoint + hf_modules: Set of supported HF module names + component: Optional component name to filter by (e.g., 'decoder') + + Returns: + Nested dictionary organizing weights by layer, module, and potentially expert + """ + def iter_fn(layer_idx, hf_module, expert_idx, inout, weights): if expert_idx is None: all_weights[layer_idx][hf_module][inout] = weights @@ -118,8 +166,19 @@ def iter_fn(layer_idx, hf_module, expert_idx, inout, weights): return hf_target_modules -def invert_module_mapping(trtllm_modules_to_hf_modules): - hf_modules_to_trtllm_modules = {} +def invert_module_mapping( + trtllm_modules_to_hf_modules: Dict[str, Union[str, List[str]]], +) -> Dict[str, str]: + """Invert module mapping from TensorRT-LLM -> HF to HF -> TensorRT-LLM. 
+ + Args: + trtllm_modules_to_hf_modules: Mapping from TensorRT-LLM module names to HF module names + (values can be strings or lists of strings) + + Returns: + Dictionary mapping HF module names to TensorRT-LLM module names + """ + hf_modules_to_trtllm_modules: Dict[str, str] = {} for k, hf_modules in trtllm_modules_to_hf_modules.items(): if isinstance(hf_modules, list): for hf_module in hf_modules: @@ -218,8 +277,88 @@ def get_target_modules(self, trtllm_modules_to_hf_modules): return list(lora_target_modules) +@lru_cache(maxsize=128) +def _find_nemo_files_single_path(lora_path: str) -> List[str]: + """Find .nemo files from a single path (file or directory). + + This function is cached per individual path to maximize cache efficiency + when the same paths appear in different collections. + + Args: + lora_path: A single path that can be either: + - Direct path to a .nemo file + - Directory containing .nemo files (will auto-detect *.nemo) + + Returns: + List[str]: List of paths to .nemo files found in this single path + + Raises: + ValueError: If path doesn't exist, no .nemo files found, or invalid file type + """ + path = Path(lora_path) + if not path.exists(): + raise ValueError(f"{path} does not exist") + + if path.is_file(): + if path.suffix == ".nemo": + return [str(path)] + else: + raise ValueError(f"{path} is not a .nemo file") + elif path.is_dir(): + nemo_files_in_dir = list(path.glob("*.nemo")) + if not nemo_files_in_dir: + raise ValueError(f"No .nemo files found in directory {path}") + return [str(f) for f in nemo_files_in_dir] + else: + raise ValueError(f"{path} is neither a file nor a directory") + + +def find_nemo_files(lora_dirs: List[str]) -> List[str]: + """Find all .nemo files from a list of directories or file paths. + + This function is optimized for repeated calls at generation time by using an internal LRU cache + on individual paths, which maximizes cache efficiency when the same paths + appear in different collections. + + Args: + lora_dirs: List of paths that can be either: + - Direct paths to .nemo files + - Directories containing .nemo files (will auto-detect *.nemo) + + Returns: + List[str]: List of paths to .nemo files + + Raises: + ValueError: If a path doesn't exist, no .nemo files are found in a directory + path, or a file path is of invalid file type + """ + if len(lora_dirs) == 0: + return [] + + all_nemo_files: List[str] = [] + for lora_path in lora_dirs: + nemo_files_for_path = _find_nemo_files_single_path(lora_path) + all_nemo_files.extend(nemo_files_for_path) + + if not all_nemo_files: + raise ValueError("No .nemo files found in the provided paths") + + return all_nemo_files + + class NemoLoraLoader: def __init__(self, lora_dirs: List[str]): + """Initialize NemoLoraLoader with paths to .nemo files or directories. + + Args: + lora_dirs: List of paths that can be either: + - Direct paths to .nemo files + - Directories containing .nemo files (will auto-detect *.nemo) + + Note: The parameter name 'lora_dirs' is misleading - it can accept both + directories and files. This is a design flaw that should be fixed + in a future version (e.g., rename to 'lora_paths'). 
+ """ self.lora_target_modules = [] self.is_valid = False @@ -230,15 +369,28 @@ def __init__(self, lora_dirs: List[str]): path = Path(lora_file) if not path.exists(): raise ValueError(f"{path} does not exist") - if not path.is_file(): - raise ValueError(f"{path} is not a file") self.is_valid = True # Hardcoded since LoraManager only supports this case now self.lora_target_modules = ["attn_qkv"] + def get_target_modules(self): + """Get target modules for NeMo LoRA. + + Unlike the HF loader, this method does not accept trtllm_modules_to_hf_modules + as an argument since the module mapping is hardcoded for NeMo LoRA support. + + Returns: + List[str]: List of target module names supported by NeMo LoRA + """ + return self.lora_target_modules + def load_nemo_lora(model, lora_config: LoraConfig): lora_loader = NemoLoraLoader(lora_config.lora_dir) + + if not lora_loader.is_valid: + raise ValueError(f"Failed to load NeMo LoRA from {lora_config.lora_dir}") + if len(lora_config.lora_target_modules) == 0: lora_config.lora_target_modules = lora_loader.lora_target_modules @@ -287,6 +439,73 @@ def load_torch_hf_lora(lora_config: LoraConfig): lora_config.lora_target_modules.extend(missing_qkv_modules) +def load_torch_nemo_lora(lora_config: LoraConfig): + """Load NeMo LoRA checkpoint for PyTorch workflow. + + This is a PyTorch-specific loader for NeMo LoRA checkpoints, similar to + load_torch_hf_lora but handling NeMo checkpoint format. NeMo uses a combined + "attn_qkv" module rather than separate Q, K, V modules, so no missing QKV + module handling is needed. + + Note: This function only sets up the configuration. For PyTorch workflow, + the actual weight loading happens later via LoraManager when requests are + made with LoRA UIDs. + + Args: + lora_config: LoRA configuration with lora_ckpt_source="nemo" + + Raises: + ValueError: If NeMo LoRA directory is invalid or unsupported modules are specified + """ + lora_config.trtllm_modules_to_hf_modules = {"attn_qkv": "attn_qkv"} + + assert len(lora_config.lora_dir) == 1, "Expecting only a single lora dir" + lora_loader = NemoLoraLoader(lora_config.lora_dir) + + if not lora_loader.is_valid: + raise ValueError(f"Failed to load NeMo LoRA from {lora_config.lora_dir}") + + if len(lora_config.lora_target_modules) == 0: + lora_config.lora_target_modules = lora_loader.get_target_modules() + + if len(lora_config.lora_target_modules) == 0: + raise ValueError( + "lora_target_modules is empty. " + "Please specify lora_target_modules or provide lora_dir to infer lora_target_modules." + ) + + supported_modules = {"attn_qkv"} + unsupported_modules = set(lora_config.lora_target_modules) - supported_modules + if unsupported_modules: + raise ValueError( + f"NeMo LoRA only supports {supported_modules} modules, " + f"but got unsupported modules: {unsupported_modules}. " + f"NeMo LoRA does not support embedding, lm_head, or MLP adapters." + ) + + +def load_torch_lora(lora_config: LoraConfig): + """Load LoRA checkpoint for PyTorch workflow. + + This function routes to the appropriate loader based on lora_ckpt_source. + + Args: + lora_config: LoRA configuration with lora_ckpt_source set to "hf" or "nemo" + + Raises: + ValueError: If lora_ckpt_source is not supported + """ + if lora_config.lora_ckpt_source == "nemo": + load_torch_nemo_lora(lora_config) + elif lora_config.lora_ckpt_source == "hf": + load_torch_hf_lora(lora_config) + else: + raise ValueError( + f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}. 
" + f"Supported sources: 'hf', 'nemo'" + ) + + def load_hf_lora( model, lora_config: LoraConfig, @@ -388,7 +607,18 @@ def use_lora( raise ValueError(f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}") -def unpack_nemo_weights(nemo_archive_path): +def unpack_nemo_weights(nemo_archive_path: str) -> Tuple[Dict, Dict[str, torch.Tensor]]: + """Unpack model config and weights from a NeMo .nemo archive file. + + Args: + nemo_archive_path: Path to the .nemo archive file + + Returns: + Tuple of (model_config_dict, model_weights_dict) + + Raises: + Exception: If required files cannot be extracted from the archive + """ with tarfile.open(nemo_archive_path) as tar: try: model_weights_file = tar.extractfile("model_weights.ckpt") @@ -539,8 +769,12 @@ def load_from_ckpt( uids=uids, ) elif ckpt_source == "nemo": + # Find all .nemo files from directories or files + nemo_files = find_nemo_files(model_dirs_or_files) + + # Pass the actual .nemo files to the loader return self.load_from_nemo( - model_files=model_dirs_or_files, + model_files=nemo_files, model_config=model_config, runtime_mapping=runtime_mapping, uids=uids, diff --git a/tests/unittest/llmapi/lora_test_utils.py b/tests/unittest/llmapi/lora_test_utils.py index 1b2323804fa..58673aa0699 100644 --- a/tests/unittest/llmapi/lora_test_utils.py +++ b/tests/unittest/llmapi/lora_test_utils.py @@ -1,5 +1,10 @@ +import json +import tarfile +import tempfile +from pathlib import Path from typing import OrderedDict, Type +import torch from utils.llm_data import llm_models_root from utils.util import duplicate_list_to_length, flatten_list, similar @@ -114,3 +119,116 @@ def check_llama_7b_multi_lora_from_request_test_harness( for output, ref, key_word in zip(outputs, references, key_words): assert similar(output.outputs[0].text, ref) or key_word in output.outputs[0].text + + +def create_mock_nemo_lora_checkpoint( + lora_dir: Path, + hidden_size: int = 4096, + num_layers: int = 32, + lora_rank: int = 8, + tp_size: int = 1, + num_attention_heads: int = 32, + num_kv_heads: int = None, # If None, defaults to num_attention_heads + dtype: torch.dtype = torch.float16, + seed: int = None, # For deterministic weight initialization +) -> Path: + """Create a minimal NeMo LoRA checkpoint for testing. + + This creates a .nemo tarfile with the expected structure: + - model_weights.ckpt containing attn_qkv adapter weights + - model_config.yaml with basic configuration + + Args: + lora_dir: Directory to create the checkpoint in + hidden_size: Model hidden size + num_layers: Number of transformer layers + lora_rank: LoRA rank + tp_size: Tensor parallelism size + num_attention_heads: Number of query attention heads + num_kv_heads: Number of key/value heads (for GQA). 
If None, equals num_attention_heads + dtype: Data type for the weights (default: torch.float16) + + Returns: + Path to the created .nemo file + """ + + # Validate parameters + if hidden_size % num_attention_heads != 0: + raise ValueError(f"hidden_size ({hidden_size}) must be divisible by " + f"num_attention_heads ({num_attention_heads})") + + # Default to standard MHA if not specified + if num_kv_heads is None: + num_kv_heads = num_attention_heads + + if num_attention_heads % num_kv_heads != 0: + raise ValueError( + f"num_attention_heads ({num_attention_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads}) for GQA") + + nemo_path = lora_dir / "test_lora.nemo" + + with tempfile.TemporaryDirectory() as temp_dir_str: + temp_dir = Path(temp_dir_str) + + # Set random seed for deterministic weight initialization + if seed is not None: + torch.manual_seed(seed) + + weights_dict = {} + + head_dim = hidden_size // num_attention_heads + kv_hidden_size = head_dim * num_kv_heads + + qkv_output_dim = hidden_size + 2 * kv_hidden_size + + # NOTE: + # for seed=42, and coefficient=0.02, the expected outputs are hardcoded + # in the test `test_llm_pytorch.py::test_gqa_nemo_lora`. + # Therefore changing "WEIGHTS_COEFFICIENT" or the seed will break the test. + WEIGHTS_COEFFICIENT = 0.02 + for layer_idx in range(num_layers): + key_prefix = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter" + + # Create linear_in weights [lora_rank, hidden_size] with small random values + linear_in_key = f"{key_prefix}.linear_in.weight" + weights_dict[linear_in_key] = torch.randn( + lora_rank, hidden_size, dtype=dtype) * WEIGHTS_COEFFICIENT + + # Create linear_out weights [qkv_output_dim, lora_rank] for fused QKV + # This is the key difference for GQA - the output dimension changes + linear_out_key = f"{key_prefix}.linear_out.weight" + weights_dict[linear_out_key] = torch.randn( + qkv_output_dim, lora_rank, dtype=dtype) * WEIGHTS_COEFFICIENT + + ckpt_path = temp_dir / "model_weights.ckpt" + torch.save(weights_dict, ckpt_path) + + config = { + "precision": "fp16" if dtype == torch.float16 else "bf16", + "trainer": { + "num_nodes": 1, + "devices": tp_size, + }, + "model": { + "hidden_size": hidden_size, + "num_layers": num_layers, + "num_attention_heads": num_attention_heads, + "num_query_groups": num_kv_heads, # This is the key for GQA + }, + "lora": { + "rank": lora_rank, + "target_modules": ["attn_qkv"], + } + } + + config_path = temp_dir / "model_config.yaml" + # Using JSON for simplicity since YAML parsing isn't critical for the test + with open(config_path, 'w') as f: + json.dump(config, f) + + with tarfile.open(nemo_path, 'w') as tar: + tar.add(ckpt_path, arcname="model_weights.ckpt") + tar.add(config_path, arcname="model_config.yaml") + + return nemo_path diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 486ceb301f5..7e890693e50 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -5,7 +5,7 @@ from tensorrt_llm.sampling_params import SamplingParams # isort: off -from .lora_test_utils import check_llama_7b_multi_unique_lora_adapters_from_request +from .lora_test_utils import check_llama_7b_multi_unique_lora_adapters_from_request, create_mock_nemo_lora_checkpoint from .test_llm import (get_model_path, global_kvcache_config, llama_model_path, llm_get_stats_async_test_harness, llm_get_stats_test_harness, prompts, @@ -427,3 +427,141 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> 
None: lora_request=lora_requests) assert len(outputs) == 2 + + +@pytest.mark.parametrize( + "lora_rank,max_lora_rank,description", + [ + # (lora_rank, max_lora_rank, description) + (8, 8, "rank_8"), + (16, 16, "rank_16"), + (4, 8, "rank_4_max_8"), + ]) +def test_load_torch_nemo_lora_function(tmp_path, lora_rank, max_lora_rank, + description): + """Test load_torch_nemo_lora function with different LoRA rank configurations.""" + from tensorrt_llm.lora_manager import load_torch_nemo_lora + + nemo_path = create_mock_nemo_lora_checkpoint( + tmp_path, + hidden_size=2048, + num_layers=16, + lora_rank=lora_rank, + ) + + lora_config = LoraConfig( + lora_dir=[str(nemo_path)], + lora_ckpt_source="nemo", + max_lora_rank=max_lora_rank, + ) + + # This should not raise an error + load_torch_nemo_lora(lora_config) + + assert lora_config.lora_target_modules == [ + "attn_qkv" + ], f"Expected attn_qkv modules for {description}" + assert lora_config.trtllm_modules_to_hf_modules == { + "attn_qkv": "attn_qkv" + }, f"Expected correct module mapping for {description}" + + +def test_nemo_lora_unsupported_modules_validation(tmp_path): + """Test validation of unsupported modules in NeMo LoRA.""" + from tensorrt_llm.lora_manager import load_torch_nemo_lora + + nemo_path = create_mock_nemo_lora_checkpoint( + tmp_path, + hidden_size=2048, + num_layers=16, + lora_rank=8, + ) + + # Test validation: should fail with unsupported modules + invalid_config = LoraConfig( + lora_dir=[str(nemo_path)], + lora_ckpt_source="nemo", + lora_target_modules=["attn_qkv", + "mlp_h_to_4h"], # mlp_h_to_4h not supported + max_lora_rank=8, + ) + + with pytest.raises(ValueError, match="NeMo LoRA only supports"): + load_torch_nemo_lora(invalid_config) + + +@force_ampere +def test_gqa_nemo_lora(tmp_path): + """ + Test NeMo-format LoRA checkpoint loading and GQA support in TinyLlama. + + This test verifies two properties: + 1. That a NeMo-format LoRA checkpoint with GQA (grouped query attention) can be loaded and applied to a TinyLlama model, + and that generation with this LoRA produces a deterministic, expected output for a fixed prompt and temperature=0.0. + 2. That the LoRA weights have a significant effect: generating with LoRA produces a different output than generating + without LoRA, confirming that the LoRA adapter is actually being applied. + + The test uses a deterministic dummy LoRA checkpoint (seed=42) and checks both the positive (LoRA applied) and negative + (no LoRA) cases for output text. + """ + # TinyLlama's exact GQA configuration + hidden_size = 2048 + num_layers = 22 + num_q_heads = 32 # Query attention heads + num_kv_heads = 4 # Key/Value heads (GQA) + lora_rank = 8 + + nemo_path = create_mock_nemo_lora_checkpoint( + tmp_path, + hidden_size=hidden_size, + num_layers=num_layers, + lora_rank=lora_rank, + num_attention_heads=num_q_heads, + num_kv_heads=num_kv_heads, + seed=42, # NOTE: the seed=42 is important for the test to pass. + ) + expected_lora_text_output = "Paris. The capital of France is Paris. 
The" + test_prompts = ["The capital of France is"] + sampling_params = SamplingParams(max_tokens=10, temperature=0.0) + + lora_config = LoraConfig( + lora_dir=[str(nemo_path)], + lora_ckpt_source="nemo", + max_lora_rank=lora_rank, + ) + + model_path = get_model_path("llama-models-v2/TinyLlama-1.1B-Chat-v1.0") + + llm = LLM( + model=model_path, + lora_config=lora_config, + kv_cache_config=global_kvcache_config, + ) + + try: + lora_req = LoRARequest("tinyllama-gqa-test", + 0, + str(nemo_path), + lora_ckpt_source="nemo") + + lora_outputs = llm.generate(test_prompts, + sampling_params, + lora_request=[lora_req]) + + # For the above deterministic dummy LoRA checkpoint, + # with temperature=0.0, + # the expected output text should always be the same. + assert lora_outputs[0].outputs[0].text == expected_lora_text_output, \ + f"Expected output text: {expected_lora_text_output}, " \ + f"got: {lora_outputs[0].outputs[0].text}" + assert len(lora_outputs) == 1 + + # Generate without LoRA. + # The LoRA weights are tuned/large enough that + # they differ from a no-LoRA run. + base_outputs = llm.generate(test_prompts, sampling_params) + assert base_outputs[0].outputs[0].text != expected_lora_text_output, \ + f"No-LoRA output should differ from expected output text: {expected_lora_text_output}, " \ + f"got: {base_outputs[0].outputs[0].text}" + finally: + llm.shutdown()