From ab714cae1b71c38a6697e0fe84b6310d910702b7 Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Mon, 14 Jul 2025 12:23:34 -0700 Subject: [PATCH 01/16] Prototype the loading, add unit, integration tests Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tensorrt_llm/_torch/models/modeling_llama.py | 12 +- .../_torch/models/modeling_nemotron_nas.py | 12 +- tensorrt_llm/_torch/models/modeling_utils.py | 12 +- tensorrt_llm/_torch/pyexecutor/_util.py | 7 +- tensorrt_llm/lora_manager.py | 52 +++ .../defs/llmapi/test_llm_pytorch_nemo_lora.py | 296 ++++++++++++++++++ .../llmapi/test_llm_pytorch_nemo_lora.py | 238 ++++++++++++++ 7 files changed, 613 insertions(+), 16 deletions(-) create mode 100644 tests/integration/defs/llmapi/test_llm_pytorch_nemo_lora.py create mode 100644 tests/unittest/llmapi/test_llm_pytorch_nemo_lora.py diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index aeecff7c3e0..33dddfc784c 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -703,11 +703,13 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]): model_config, 'lora_config') and model_config.lora_config is not None and len( model_config.lora_config.lora_dir) == 1: - lora_loader = HfLoraLoader(model_config.lora_config.lora_dir) - if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None: - vocab_size = lora_loader.vocab_size - weight = lora_loader.embed_tokens - self.has_custom_embed_tokens = True + # Only check for custom vocab in HF LoRA, not NeMo + if model_config.lora_config.lora_ckpt_source == "hf": + lora_loader = HfLoraLoader(model_config.lora_config.lora_dir) + if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None: + vocab_size = lora_loader.vocab_size + weight = lora_loader.embed_tokens + self.has_custom_embed_tokens = True if self.model_config.mapping.enable_attention_dp: self.embed_tokens = Embedding( diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_nas.py b/tensorrt_llm/_torch/models/modeling_nemotron_nas.py index 146d13f16f1..3ab1cdb37ca 100644 --- a/tensorrt_llm/_torch/models/modeling_nemotron_nas.py +++ b/tensorrt_llm/_torch/models/modeling_nemotron_nas.py @@ -192,11 +192,13 @@ def __init__(self, model_config): model_config, 'lora_config') and model_config.lora_config is not None and len( model_config.lora_config.lora_dir) == 1: - lora_loader = HfLoraLoader(model_config.lora_config.lora_dir) - if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None: - vocab_size = lora_loader.vocab_size - weight = lora_loader.embed_tokens - self.has_custom_embed_tokens = True + # Only check for custom vocab in HF LoRA, not NeMo + if model_config.lora_config.lora_ckpt_source == "hf": + lora_loader = HfLoraLoader(model_config.lora_config.lora_dir) + if lora_loader.vocab_size != 0 and lora_loader.embed_tokens is not None: + vocab_size = lora_loader.vocab_size + weight = lora_loader.embed_tokens + self.has_custom_embed_tokens = True self.embed_tokens = Embedding( vocab_size, diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index c751bdcbb01..5b28d379206 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -364,11 +364,13 @@ def __init__(self, model: TModel, *, config: ModelConfig[TConfig], if (hasattr(config, 'lora_config') and config.lora_config is not None and 
len(config.lora_config.lora_dir) == 1): - lora_loader = HfLoraLoader(config.lora_config.lora_dir) - if lora_loader.lm_head is not None and lora_loader.vocab_size != 0: - weight = lora_loader.lm_head - self.has_custom_lm_head = True - vocab_size = lora_loader.vocab_size + # Only check for custom lm_head in HF LoRA, not NeMo + if config.lora_config.lora_ckpt_source == "hf": + lora_loader = HfLoraLoader(config.lora_config.lora_dir) + if lora_loader.lm_head is not None and lora_loader.vocab_size != 0: + weight = lora_loader.lm_head + self.has_custom_lm_head = True + vocab_size = lora_loader.vocab_size self.lm_head = LMHead( vocab_size, diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index adebecc1633..fc19b8f2520 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -437,7 +437,12 @@ def create_py_executor_instance( from tensorrt_llm.bindings import LoraModule if len(lora_config.lora_dir) == 1: - load_torch_hf_lora(lora_config) + # Route to appropriate loader based on checkpoint source + if lora_config.lora_ckpt_source == "nemo": + from tensorrt_llm.lora_manager import load_torch_nemo_lora + load_torch_nemo_lora(lora_config) + else: + load_torch_hf_lora(lora_config) else: assert len(lora_config.lora_target_modules ) >= 1, "Expecting at least one lora target module" diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 3f87286024b..514098d0200 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -236,6 +236,10 @@ def __init__(self, lora_dirs: List[str]): # Hardcoded since LoraManager only supports this case now self.lora_target_modules = ["attn_qkv"] + def get_target_modules(self, trtllm_modules_to_hf_modules): + """Get target modules for NeMo LoRA.""" + return self.lora_target_modules + def load_nemo_lora(model, lora_config: LoraConfig): lora_loader = NemoLoraLoader(lora_config.lora_dir) @@ -287,6 +291,54 @@ def load_torch_hf_lora(lora_config: LoraConfig): lora_config.lora_target_modules.extend(missing_qkv_modules) +def load_torch_nemo_lora(lora_config: LoraConfig): + """Load NeMo LoRA checkpoint for PyTorch workflow. + + This is a PyTorch-specific loader for NeMo LoRA checkpoints, similar to + load_torch_hf_lora but handling NeMo checkpoint format. + + Args: + lora_config: LoRA configuration with lora_ckpt_source="nemo" + """ + # For NeMo, we need to set up module mappings differently + # NeMo uses "attn_qkv" as a combined module + lora_config.trtllm_modules_to_hf_modules = {"attn_qkv": "attn_qkv"} + + assert len(lora_config.lora_dir) == 1, "Expecting only a single lora dir" + lora_loader = NemoLoraLoader(lora_config.lora_dir) + + if not lora_loader.is_valid: + raise ValueError(f"Failed to load NeMo LoRA from {lora_config.lora_dir}") + + if len(lora_config.lora_target_modules) == 0: + lora_config.lora_target_modules = lora_loader.get_target_modules( + lora_config.trtllm_modules_to_hf_modules + ) + + if len(lora_config.lora_target_modules) == 0: + raise ValueError( + "lora_target_modules is empty. " + "Please specify lora_target_modules or provide lora_dir to infer lora_target_modules." + ) + + # Validate that NeMo LoRA only supports attn_qkv + supported_modules = {"attn_qkv"} + unsupported_modules = set(lora_config.lora_target_modules) - supported_modules + if unsupported_modules: + raise ValueError( + f"NeMo LoRA only supports {supported_modules} modules, " + f"but got unsupported modules: {unsupported_modules}. 
" + f"NeMo LoRA does not support embedding, lm_head, or MLP adapters." + ) + + # NeMo only supports attn_qkv currently, no need for missing QKV module handling + # as it's already combined + + # Note: For PyTorch workflow, the actual weight loading happens later + # via LoraManager when requests are made with LoRA UIDs. This function + # just sets up the configuration. + + def load_hf_lora( model, lora_config: LoraConfig, diff --git a/tests/integration/defs/llmapi/test_llm_pytorch_nemo_lora.py b/tests/integration/defs/llmapi/test_llm_pytorch_nemo_lora.py new file mode 100644 index 00000000000..98beebd7c69 --- /dev/null +++ b/tests/integration/defs/llmapi/test_llm_pytorch_nemo_lora.py @@ -0,0 +1,296 @@ +"""Integration tests for NeMo LoRA checkpoint loading in PyTorch workflow.""" + +import json +import tarfile +import tempfile +from pathlib import Path + +import pytest +import torch +from defs.conftest import llm_models_root + +from tensorrt_llm import LLM +from tensorrt_llm.lora_manager import LoraConfig +from tensorrt_llm.sampling_params import SamplingParams + +# needed since we reuse the mpi executor pool, first test running will leak a thread +pytestmark = pytest.mark.threadleak(enabled=False) + + +def create_mock_nemo_lora_checkpoint( + lora_dir: Path, + hidden_size: int = 4096, + num_layers: int = 32, + lora_rank: int = 8, + tp_size: int = 1, +) -> Path: + """Create a minimal NeMo LoRA checkpoint for testing. + + This creates a .nemo tarfile with the expected structure: + - model_weights.ckpt containing attn_qkv adapter weights + - model_config.yaml with basic configuration + + Args: + lora_dir: Directory to create the checkpoint in + hidden_size: Model hidden size + num_layers: Number of transformer layers + lora_rank: LoRA rank + tp_size: Tensor parallelism size + + Returns: + Path to the created .nemo file + """ + # Create temporary directory for checkpoint contents + temp_dir = lora_dir / "temp_nemo" + temp_dir.mkdir(exist_ok=True) + + # Create LoRA weights dict + weights_dict = {} + + for layer_idx in range(num_layers): + # NeMo uses this key format for QKV adapters + key_prefix = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter" + + # Create linear_in weights [lora_rank, hidden_size] + linear_in_key = f"{key_prefix}.linear_in.weight" + weights_dict[linear_in_key] = torch.zeros(lora_rank, + hidden_size, + dtype=torch.float16) + + # Create linear_out weights [3 * hidden_size, lora_rank] for QKV combined + linear_out_key = f"{key_prefix}.linear_out.weight" + weights_dict[linear_out_key] = torch.zeros(3 * hidden_size, + lora_rank, + dtype=torch.float16) + + # Save checkpoint + ckpt_path = temp_dir / "model_weights.ckpt" + torch.save(weights_dict, ckpt_path) + + # Create minimal config + config = { + "precision": "fp16", + "trainer": { + "num_nodes": 1, + "devices": tp_size, + }, + "model": { + "hidden_size": hidden_size, + "num_layers": num_layers, + }, + "lora": { + "rank": lora_rank, + "target_modules": ["attn_qkv"], + } + } + + config_path = temp_dir / "model_config.yaml" + # Using JSON for simplicity since YAML parsing isn't critical for the test + with open(config_path, 'w') as f: + json.dump(config, f) + + # Create .nemo tarfile + nemo_path = lora_dir / "test_lora.nemo" + with tarfile.open(nemo_path, 'w') as tar: + tar.add(ckpt_path, arcname="model_weights.ckpt") + tar.add(config_path, arcname="model_config.yaml") + + # Cleanup temp dir + import shutil + shutil.rmtree(temp_dir) + + return nemo_path + + +# Test data for parametrized tests 
+LORA_RANK_CONFIGS = [ + # (lora_rank, max_lora_rank, description) + (8, 8, "rank_8"), + (16, 16, "rank_16"), + (4, 8, "rank_4_max_8"), +] + + +class TestNemoLoraIntegration: + """Integration tests for NeMo LoRA with full model initialization.""" + + @pytest.mark.parametrize("lora_rank,max_lora_rank,description", + LORA_RANK_CONFIGS) + def test_llama_nemo_lora_inference(self, lora_rank, max_lora_rank, + description): + """Test NeMo LoRA inference with Llama model using different LoRA ranks.""" + model_dir = f"{llm_models_root()}/llama-3.2-models/Llama-3.2-1B/" + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create a mock NeMo LoRA checkpoint + nemo_lora_path = create_mock_nemo_lora_checkpoint( + temp_path, + hidden_size=2048, # Llama 3.2 1B hidden size + num_layers=16, # Llama 3.2 1B layer count + lora_rank=lora_rank, + ) + + # Configure LoRA with nemo source + lora_config = LoraConfig( + lora_dir=[str(nemo_lora_path)], + lora_ckpt_source="nemo", + lora_target_modules=["attn_qkv"], + max_lora_rank=max_lora_rank, + ) + + # Create LLM instance with LoRA + llm = LLM( + model=model_dir, + lora_config=lora_config, + backend="pytorch", + ) + + try: + # Test inference with LoRA + prompts = ["Hello, how are you?"] + sampling_params = SamplingParams(max_tokens=10) + + outputs = llm.generate(prompts, sampling_params) + + # Basic validation - should generate something + assert len(outputs) == 1, f"Expected 1 output for {description}" + assert len(outputs[0].outputs[0].text + ) > 0, f"Expected non-empty text for {description}" + + print( + f"[{description}] Generated text: {outputs[0].outputs[0].text}" + ) + finally: + # Ensure proper cleanup + del llm + import gc + gc.collect() + + @pytest.mark.parametrize("prompt,max_tokens,description", [ + ("Hello, how are you?", 10, "greeting_short"), + ("The weather today is", 20, "weather_medium"), + ("Tell me about", 15, "question_medium"), + ]) + def test_llama_nemo_lora_different_prompts(self, prompt, max_tokens, + description): + """Test NeMo LoRA with different prompts and generation lengths.""" + model_dir = f"{llm_models_root()}/llama-3.2-models/Llama-3.2-1B/" + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create a mock NeMo LoRA checkpoint + nemo_lora_path = create_mock_nemo_lora_checkpoint( + temp_path, + hidden_size=2048, + num_layers=16, + lora_rank=8, + ) + + # Configure LoRA + lora_config = LoraConfig( + lora_dir=[str(nemo_lora_path)], + lora_ckpt_source="nemo", + lora_target_modules=["attn_qkv"], + max_lora_rank=8, + ) + + # Create LLM instance with LoRA + llm = LLM( + model=model_dir, + lora_config=lora_config, + backend="pytorch", + ) + + try: + # Test inference with different prompts + prompts = [prompt] + sampling_params = SamplingParams(max_tokens=max_tokens) + + outputs = llm.generate(prompts, sampling_params) + + # Validation + assert len(outputs) == 1, f"Expected 1 output for {description}" + generated_text = outputs[0].outputs[0].text + assert len(generated_text + ) > 0, f"Expected non-empty text for {description}" + + # Basic sanity check - generated text should have reasonable length + assert len( + generated_text.split() + ) <= max_tokens + 5, f"Generated text too long for {description}" + + print( + f"[{description}] Prompt: '{prompt}' -> Generated: '{generated_text}'" + ) + finally: + # Ensure proper cleanup + del llm + import gc + gc.collect() + + +class TestNemoLoraTensorParallel: + """Tests for NeMo LoRA with tensor parallelism.""" + + 
@pytest.mark.parametrize("tp_size,description", [ + (2, "tp_2"), + (4, "tp_4"), + ]) + def test_llama_nemo_lora_tensor_parallel(self, tp_size, description): + """Test NeMo LoRA loading with different tensor parallelism sizes.""" + import torch + if torch.cuda.device_count() < tp_size: + pytest.skip(f"Test requires at least {tp_size} GPUs") + + model_dir = f"{llm_models_root()}/llama-models-v3/llama-3.2-1b-hf" + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create a mock NeMo LoRA checkpoint with specified TP size + nemo_lora_path = create_mock_nemo_lora_checkpoint( + temp_path, + hidden_size=2048, + num_layers=16, + lora_rank=8, + tp_size=tp_size, + ) + + # Configure LoRA + lora_config = LoraConfig( + lora_dir=[str(nemo_lora_path)], + lora_ckpt_source="nemo", + lora_target_modules=["attn_qkv"], + max_lora_rank=8, + ) + + # Create LLM instance with tensor parallelism + llm = LLM( + model=model_dir, + lora_config=lora_config, + backend="pytorch", + tensor_parallel_size=tp_size, + ) + + try: + # Test inference + prompts = ["The weather today is"] + sampling_params = SamplingParams(max_tokens=20) + + outputs = llm.generate(prompts, sampling_params) + + assert len(outputs) == 1, f"Expected 1 output for {description}" + assert len(outputs[0].outputs[0].text + ) > 0, f"Expected non-empty text for {description}" + + print( + f"[{description}] Generated text: {outputs[0].outputs[0].text}" + ) + finally: + # Ensure proper cleanup + del llm + import gc + gc.collect() diff --git a/tests/unittest/llmapi/test_llm_pytorch_nemo_lora.py b/tests/unittest/llmapi/test_llm_pytorch_nemo_lora.py new file mode 100644 index 00000000000..3ff1466faa9 --- /dev/null +++ b/tests/unittest/llmapi/test_llm_pytorch_nemo_lora.py @@ -0,0 +1,238 @@ +"""Tests for NeMo LoRA checkpoint loading in PyTorch workflow. + +This file contains fast unit tests that do not require full model initialization. +For integration tests that require full model loading and GPU inference, see: + tests/integration/defs/llmapi/test_llm_pytorch_nemo_lora.py + +Unit tests here should run in seconds, not minutes. +""" + +import json +import tarfile +import tempfile +from pathlib import Path + +import pytest +import torch + +from tensorrt_llm.lora_manager import LoraConfig + + +def create_mock_nemo_lora_checkpoint( + lora_dir: Path, + hidden_size: int = 4096, + num_layers: int = 32, + lora_rank: int = 8, + tp_size: int = 1, +) -> Path: + """Create a minimal NeMo LoRA checkpoint for testing. 
+ + This creates a .nemo tarfile with the expected structure: + - model_weights.ckpt containing attn_qkv adapter weights + - model_config.yaml with basic configuration + + Args: + lora_dir: Directory to create the checkpoint in + hidden_size: Model hidden size + num_layers: Number of transformer layers + lora_rank: LoRA rank + tp_size: Tensor parallelism size + + Returns: + Path to the created .nemo file + """ + # Create temporary directory for checkpoint contents + temp_dir = lora_dir / "temp_nemo" + temp_dir.mkdir(exist_ok=True) + + # Create LoRA weights dict + weights_dict = {} + + for layer_idx in range(num_layers): + # NeMo uses this key format for QKV adapters + key_prefix = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter" + + # Create linear_in weights [lora_rank, hidden_size] + linear_in_key = f"{key_prefix}.linear_in.weight" + weights_dict[linear_in_key] = torch.zeros(lora_rank, + hidden_size, + dtype=torch.float16) + + # Create linear_out weights [3 * hidden_size, lora_rank] for QKV combined + linear_out_key = f"{key_prefix}.linear_out.weight" + weights_dict[linear_out_key] = torch.zeros(3 * hidden_size, + lora_rank, + dtype=torch.float16) + + # Save checkpoint + ckpt_path = temp_dir / "model_weights.ckpt" + torch.save(weights_dict, ckpt_path) + + # Create minimal config + config = { + "precision": "fp16", + "trainer": { + "num_nodes": 1, + "devices": tp_size, + }, + "model": { + "hidden_size": hidden_size, + "num_layers": num_layers, + }, + "lora": { + "rank": lora_rank, + "target_modules": ["attn_qkv"], + } + } + + config_path = temp_dir / "model_config.yaml" + # Using JSON for simplicity since YAML parsing isn't critical for the test + with open(config_path, 'w') as f: + json.dump(config, f) + + # Create .nemo tarfile + nemo_path = lora_dir / "test_lora.nemo" + with tarfile.open(nemo_path, 'w') as tar: + tar.add(ckpt_path, arcname="model_weights.ckpt") + tar.add(config_path, arcname="model_config.yaml") + + # Cleanup temp dir + import shutil + shutil.rmtree(temp_dir) + + return nemo_path + + +# Test data for parametrized tests +NEMO_LORA_UNIT_TEST_PARAMS = [ + # (hidden_size, num_layers, lora_rank, description) + (2048, 16, 8, "small_model_rank_8"), + (4096, 32, 16, "large_model_rank_16"), + (1024, 12, 4, "tiny_model_rank_4"), +] + +LORA_RANK_CONFIGS = [ + # (lora_rank, max_lora_rank, description) + (8, 8, "rank_8"), + (16, 16, "rank_16"), + (4, 8, "rank_4_max_8"), +] + + +class TestNemoLoraUnit: + """Unit tests for NeMo LoRA loading without full model initialization.""" + + @pytest.mark.parametrize("hidden_size,num_layers,lora_rank,description", + NEMO_LORA_UNIT_TEST_PARAMS) + def test_nemo_lora_loader_creation(self, hidden_size, num_layers, lora_rank, + description): + """Test NemoLoraLoader creation with different model configurations.""" + from tensorrt_llm.lora_manager import NemoLoraLoader + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create a mock NeMo checkpoint + nemo_path = create_mock_nemo_lora_checkpoint( + temp_path, + hidden_size=hidden_size, + num_layers=num_layers, + lora_rank=lora_rank, + ) + + # Test NemoLoraLoader directly + loader = NemoLoraLoader([str(nemo_path)]) + assert loader.is_valid, f"NemoLoraLoader failed to validate {nemo_path} for {description}" + assert loader.lora_target_modules == [ + "attn_qkv" + ], f"Expected attn_qkv modules for {description}" + + @pytest.mark.parametrize("lora_rank,max_lora_rank,description", + LORA_RANK_CONFIGS) + def 
test_load_torch_nemo_lora_function(self, lora_rank, max_lora_rank, + description): + """Test load_torch_nemo_lora function with different LoRA rank configurations.""" + from tensorrt_llm.lora_manager import load_torch_nemo_lora + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create a mock NeMo checkpoint + nemo_path = create_mock_nemo_lora_checkpoint( + temp_path, + hidden_size=2048, + num_layers=16, + lora_rank=lora_rank, + ) + + # Test load_torch_nemo_lora + lora_config = LoraConfig( + lora_dir=[str(nemo_path)], + lora_ckpt_source="nemo", + max_lora_rank=max_lora_rank, + ) + + # This should not raise an error + load_torch_nemo_lora(lora_config) + + # Verify configuration was set correctly + assert lora_config.lora_target_modules == [ + "attn_qkv" + ], f"Expected attn_qkv modules for {description}" + assert lora_config.trtllm_modules_to_hf_modules == { + "attn_qkv": "attn_qkv" + }, f"Expected correct module mapping for {description}" + + def test_nemo_lora_unsupported_modules_validation(self): + """Test validation of unsupported modules in NeMo LoRA.""" + from tensorrt_llm.lora_manager import load_torch_nemo_lora + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create a mock NeMo checkpoint + nemo_path = create_mock_nemo_lora_checkpoint( + temp_path, + hidden_size=2048, + num_layers=16, + lora_rank=8, + ) + + # Test validation: should fail with unsupported modules + invalid_config = LoraConfig( + lora_dir=[str(nemo_path)], + lora_ckpt_source="nemo", + lora_target_modules=["attn_qkv", "mlp_h_to_4h" + ], # mlp_h_to_4h not supported + max_lora_rank=8, + ) + + with pytest.raises(ValueError, match="NeMo LoRA only supports"): + load_torch_nemo_lora(invalid_config) + + def test_nemo_lora_empty_target_modules(self): + """Test NeMo LoRA with empty target modules list.""" + from tensorrt_llm.lora_manager import load_torch_nemo_lora + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create a mock NeMo checkpoint + nemo_path = create_mock_nemo_lora_checkpoint( + temp_path, + hidden_size=2048, + num_layers=16, + lora_rank=8, + ) + + # Test with empty target modules - should auto-detect + lora_config = LoraConfig( + lora_dir=[str(nemo_path)], + lora_ckpt_source="nemo", + max_lora_rank=8, + ) + + load_torch_nemo_lora(lora_config) + + # Should auto-detect and set attn_qkv + assert lora_config.lora_target_modules == ["attn_qkv"] From 34b5239e8ee1b99523ca2a22572bb5379b8c8c53 Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Mon, 14 Jul 2025 12:37:47 -0700 Subject: [PATCH 02/16] modify test_lists Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tests/integration/test_lists/test-db/l0_h100.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 3d115bc05b8..9bb305fc990 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -20,6 +20,7 @@ l0_h100: - unittest/_torch/modeling -k "modeling_mixtral" - unittest/_torch/modeling -k "modeling_nemotron" - unittest/_torch/modeling -k "modeling_gemma3" + - unittest/llmapi/test_llm_pytorch_nemo_lora.py - unittest/disaggregated/test_disagg_utils.py - unittest/disaggregated/test_router.py - unittest/disaggregated/test_remoteDictionary.py From 9c881a5b585a20e49e784a05b4eab13f52c6c30f Mon Sep 17 00:00:00 2001 From: Venky 
Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Mon, 14 Jul 2025 17:27:58 -0700 Subject: [PATCH 03/16] fix bug, now accepts nemo lora_dir Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tensorrt_llm/lora_manager.py | 74 ++++++++++++++++++++++++++++++++++-- 1 file changed, 71 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 514098d0200..5c0065d57c2 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -4,6 +4,7 @@ import tarfile from collections import defaultdict from dataclasses import dataclass, field +from functools import lru_cache from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Optional, Union @@ -218,8 +219,69 @@ def get_target_modules(self, trtllm_modules_to_hf_modules): return list(lora_target_modules) +@lru_cache(maxsize=32) +def _find_nemo_files_cached(lora_dirs_tuple): + # Helper for caching: lora_dirs must be a tuple of strings + nemo_files = [] + + for lora_path in lora_dirs_tuple: + path = Path(lora_path) + if not path.exists(): + raise ValueError(f"{path} does not exist") + + if path.is_file(): + if path.suffix == ".nemo": + nemo_files.append(str(path)) + else: + raise ValueError(f"{path} is not a .nemo file") + elif path.is_dir(): + nemo_files_in_dir = list(path.glob("*.nemo")) + if not nemo_files_in_dir: + raise ValueError(f"No .nemo files found in directory {path}") + nemo_files.extend([str(f) for f in nemo_files_in_dir]) + else: + raise ValueError(f"{path} is neither a file nor a directory") + + if not nemo_files: + raise ValueError("No .nemo files found in the provided paths") + + return nemo_files + + +def find_nemo_files(lora_dirs: List[str]) -> List[str]: + """Find all .nemo files from a list of directories or file paths. + + This function is optimized for repeated calls by using an internal LRU cache. + + Args: + lora_dirs: List of paths that can be either: + - Direct paths to .nemo files + - Directories containing .nemo files (will auto-detect *.nemo) + + Returns: + List[str]: List of paths to .nemo files + + Raises: + ValueError: If path doesn't exist, no .nemo files found, or invalid file type + """ + if len(lora_dirs) == 0: + return [] + return _find_nemo_files_cached(tuple(lora_dirs)) + + class NemoLoraLoader: def __init__(self, lora_dirs: List[str]): + """Initialize NemoLoraLoader with paths to .nemo files or directories. + + Args: + lora_dirs: List of paths that can be either: + - Direct paths to .nemo files + - Directories containing .nemo files (will auto-detect *.nemo) + + Note: The parameter name 'lora_dirs' is misleading - it can accept both + directories and files. This is a design flaw that should be fixed + in a future version (e.g., rename to 'lora_paths'). 
+ """ self.lora_target_modules = [] self.is_valid = False @@ -230,8 +292,6 @@ def __init__(self, lora_dirs: List[str]): path = Path(lora_file) if not path.exists(): raise ValueError(f"{path} does not exist") - if not path.is_file(): - raise ValueError(f"{path} is not a file") self.is_valid = True # Hardcoded since LoraManager only supports this case now self.lora_target_modules = ["attn_qkv"] @@ -243,6 +303,10 @@ def get_target_modules(self, trtllm_modules_to_hf_modules): def load_nemo_lora(model, lora_config: LoraConfig): lora_loader = NemoLoraLoader(lora_config.lora_dir) + + if not lora_loader.is_valid: + raise ValueError(f"Failed to load NeMo LoRA from {lora_config.lora_dir}") + if len(lora_config.lora_target_modules) == 0: lora_config.lora_target_modules = lora_loader.lora_target_modules @@ -591,8 +655,12 @@ def load_from_ckpt( uids=uids, ) elif ckpt_source == "nemo": + # Find all .nemo files from directories or files + nemo_files = find_nemo_files(model_dirs_or_files) + + # Pass the actual .nemo files to the loader return self.load_from_nemo( - model_files=model_dirs_or_files, + model_files=nemo_files, model_config=model_config, runtime_mapping=runtime_mapping, uids=uids, From 1a54d4b7067fe5c57afd85f1269a50edbd687553 Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Mon, 14 Jul 2025 18:16:15 -0700 Subject: [PATCH 04/16] rename unittest to resolve pytest conflict Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tests/integration/test_lists/test-db/l0_h100.yml | 2 +- ...{test_llm_pytorch_nemo_lora.py => test_pytorch_nemo_lora.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename tests/unittest/llmapi/{test_llm_pytorch_nemo_lora.py => test_pytorch_nemo_lora.py} (100%) diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 9bb305fc990..41aee6197a2 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -20,7 +20,7 @@ l0_h100: - unittest/_torch/modeling -k "modeling_mixtral" - unittest/_torch/modeling -k "modeling_nemotron" - unittest/_torch/modeling -k "modeling_gemma3" - - unittest/llmapi/test_llm_pytorch_nemo_lora.py + - unittest/llmapi/test_pytorch_nemo_lora.py - unittest/disaggregated/test_disagg_utils.py - unittest/disaggregated/test_router.py - unittest/disaggregated/test_remoteDictionary.py diff --git a/tests/unittest/llmapi/test_llm_pytorch_nemo_lora.py b/tests/unittest/llmapi/test_pytorch_nemo_lora.py similarity index 100% rename from tests/unittest/llmapi/test_llm_pytorch_nemo_lora.py rename to tests/unittest/llmapi/test_pytorch_nemo_lora.py From 119c3a7d2c085011a587f96ca0b76921fd66c9e9 Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Tue, 15 Jul 2025 09:16:11 -0700 Subject: [PATCH 05/16] add ckpt_source flag to loraRequest, simplify tests Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tensorrt_llm/executor/request.py | 9 + tensorrt_llm/executor/worker.py | 3 +- .../defs/llmapi/test_llm_pytorch_nemo_lora.py | 296 ------------------ .../test_lists/test-db/l0_h100.yml | 1 - tests/unittest/llmapi/test_llm_pytorch.py | 255 +++++++++++++++ .../unittest/llmapi/test_pytorch_nemo_lora.py | 238 -------------- 6 files changed, 266 insertions(+), 536 deletions(-) delete mode 100644 tests/integration/defs/llmapi/test_llm_pytorch_nemo_lora.py delete mode 100644 
tests/unittest/llmapi/test_pytorch_nemo_lora.py diff --git a/tensorrt_llm/executor/request.py b/tensorrt_llm/executor/request.py index 886831d0723..52e3d8773e1 100644 --- a/tensorrt_llm/executor/request.py +++ b/tensorrt_llm/executor/request.py @@ -25,10 +25,15 @@ class LoRARequest: lora_name: str lora_int_id: int lora_path: str = "" + lora_ckpt_source: str = "hf" def __post_init__(self): if self.lora_path is not None and not os.path.exists(self.lora_path): raise ValueError(f"lora_path ({self.lora_path}) does not exist.") + if self.lora_ckpt_source not in ["hf", "nemo"]: + raise ValueError( + f"lora_ckpt_source must be 'hf' or 'nemo', got '{self.lora_ckpt_source}'" + ) @property def adapter_id(self): @@ -42,6 +47,10 @@ def name(self): def path(self): return self.lora_path + @property + def ckpt_source(self): + return self.lora_ckpt_source + @dataclass(slots=True) class PromptAdapterRequest: diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index aa793d30ea6..6ebd7adc03d 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -359,7 +359,8 @@ def _load_lora_adapter(self, lora_request: LoRARequest) -> bool: model_config=self._runtime_model_config if self._runtime_model_config is not None else self._lora_model_config, runtime_mapping=None, - uids=[adapter_id]) + uids=[adapter_id], + ckpt_source=lora_request.ckpt_source) return adapter_id in newly_loaded_uids def _load_prompt_adapter(self, diff --git a/tests/integration/defs/llmapi/test_llm_pytorch_nemo_lora.py b/tests/integration/defs/llmapi/test_llm_pytorch_nemo_lora.py deleted file mode 100644 index 98beebd7c69..00000000000 --- a/tests/integration/defs/llmapi/test_llm_pytorch_nemo_lora.py +++ /dev/null @@ -1,296 +0,0 @@ -"""Integration tests for NeMo LoRA checkpoint loading in PyTorch workflow.""" - -import json -import tarfile -import tempfile -from pathlib import Path - -import pytest -import torch -from defs.conftest import llm_models_root - -from tensorrt_llm import LLM -from tensorrt_llm.lora_manager import LoraConfig -from tensorrt_llm.sampling_params import SamplingParams - -# needed since we reuse the mpi executor pool, first test running will leak a thread -pytestmark = pytest.mark.threadleak(enabled=False) - - -def create_mock_nemo_lora_checkpoint( - lora_dir: Path, - hidden_size: int = 4096, - num_layers: int = 32, - lora_rank: int = 8, - tp_size: int = 1, -) -> Path: - """Create a minimal NeMo LoRA checkpoint for testing. 
- - This creates a .nemo tarfile with the expected structure: - - model_weights.ckpt containing attn_qkv adapter weights - - model_config.yaml with basic configuration - - Args: - lora_dir: Directory to create the checkpoint in - hidden_size: Model hidden size - num_layers: Number of transformer layers - lora_rank: LoRA rank - tp_size: Tensor parallelism size - - Returns: - Path to the created .nemo file - """ - # Create temporary directory for checkpoint contents - temp_dir = lora_dir / "temp_nemo" - temp_dir.mkdir(exist_ok=True) - - # Create LoRA weights dict - weights_dict = {} - - for layer_idx in range(num_layers): - # NeMo uses this key format for QKV adapters - key_prefix = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter" - - # Create linear_in weights [lora_rank, hidden_size] - linear_in_key = f"{key_prefix}.linear_in.weight" - weights_dict[linear_in_key] = torch.zeros(lora_rank, - hidden_size, - dtype=torch.float16) - - # Create linear_out weights [3 * hidden_size, lora_rank] for QKV combined - linear_out_key = f"{key_prefix}.linear_out.weight" - weights_dict[linear_out_key] = torch.zeros(3 * hidden_size, - lora_rank, - dtype=torch.float16) - - # Save checkpoint - ckpt_path = temp_dir / "model_weights.ckpt" - torch.save(weights_dict, ckpt_path) - - # Create minimal config - config = { - "precision": "fp16", - "trainer": { - "num_nodes": 1, - "devices": tp_size, - }, - "model": { - "hidden_size": hidden_size, - "num_layers": num_layers, - }, - "lora": { - "rank": lora_rank, - "target_modules": ["attn_qkv"], - } - } - - config_path = temp_dir / "model_config.yaml" - # Using JSON for simplicity since YAML parsing isn't critical for the test - with open(config_path, 'w') as f: - json.dump(config, f) - - # Create .nemo tarfile - nemo_path = lora_dir / "test_lora.nemo" - with tarfile.open(nemo_path, 'w') as tar: - tar.add(ckpt_path, arcname="model_weights.ckpt") - tar.add(config_path, arcname="model_config.yaml") - - # Cleanup temp dir - import shutil - shutil.rmtree(temp_dir) - - return nemo_path - - -# Test data for parametrized tests -LORA_RANK_CONFIGS = [ - # (lora_rank, max_lora_rank, description) - (8, 8, "rank_8"), - (16, 16, "rank_16"), - (4, 8, "rank_4_max_8"), -] - - -class TestNemoLoraIntegration: - """Integration tests for NeMo LoRA with full model initialization.""" - - @pytest.mark.parametrize("lora_rank,max_lora_rank,description", - LORA_RANK_CONFIGS) - def test_llama_nemo_lora_inference(self, lora_rank, max_lora_rank, - description): - """Test NeMo LoRA inference with Llama model using different LoRA ranks.""" - model_dir = f"{llm_models_root()}/llama-3.2-models/Llama-3.2-1B/" - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create a mock NeMo LoRA checkpoint - nemo_lora_path = create_mock_nemo_lora_checkpoint( - temp_path, - hidden_size=2048, # Llama 3.2 1B hidden size - num_layers=16, # Llama 3.2 1B layer count - lora_rank=lora_rank, - ) - - # Configure LoRA with nemo source - lora_config = LoraConfig( - lora_dir=[str(nemo_lora_path)], - lora_ckpt_source="nemo", - lora_target_modules=["attn_qkv"], - max_lora_rank=max_lora_rank, - ) - - # Create LLM instance with LoRA - llm = LLM( - model=model_dir, - lora_config=lora_config, - backend="pytorch", - ) - - try: - # Test inference with LoRA - prompts = ["Hello, how are you?"] - sampling_params = SamplingParams(max_tokens=10) - - outputs = llm.generate(prompts, sampling_params) - - # Basic validation - should generate something - assert len(outputs) == 1, 
f"Expected 1 output for {description}" - assert len(outputs[0].outputs[0].text - ) > 0, f"Expected non-empty text for {description}" - - print( - f"[{description}] Generated text: {outputs[0].outputs[0].text}" - ) - finally: - # Ensure proper cleanup - del llm - import gc - gc.collect() - - @pytest.mark.parametrize("prompt,max_tokens,description", [ - ("Hello, how are you?", 10, "greeting_short"), - ("The weather today is", 20, "weather_medium"), - ("Tell me about", 15, "question_medium"), - ]) - def test_llama_nemo_lora_different_prompts(self, prompt, max_tokens, - description): - """Test NeMo LoRA with different prompts and generation lengths.""" - model_dir = f"{llm_models_root()}/llama-3.2-models/Llama-3.2-1B/" - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create a mock NeMo LoRA checkpoint - nemo_lora_path = create_mock_nemo_lora_checkpoint( - temp_path, - hidden_size=2048, - num_layers=16, - lora_rank=8, - ) - - # Configure LoRA - lora_config = LoraConfig( - lora_dir=[str(nemo_lora_path)], - lora_ckpt_source="nemo", - lora_target_modules=["attn_qkv"], - max_lora_rank=8, - ) - - # Create LLM instance with LoRA - llm = LLM( - model=model_dir, - lora_config=lora_config, - backend="pytorch", - ) - - try: - # Test inference with different prompts - prompts = [prompt] - sampling_params = SamplingParams(max_tokens=max_tokens) - - outputs = llm.generate(prompts, sampling_params) - - # Validation - assert len(outputs) == 1, f"Expected 1 output for {description}" - generated_text = outputs[0].outputs[0].text - assert len(generated_text - ) > 0, f"Expected non-empty text for {description}" - - # Basic sanity check - generated text should have reasonable length - assert len( - generated_text.split() - ) <= max_tokens + 5, f"Generated text too long for {description}" - - print( - f"[{description}] Prompt: '{prompt}' -> Generated: '{generated_text}'" - ) - finally: - # Ensure proper cleanup - del llm - import gc - gc.collect() - - -class TestNemoLoraTensorParallel: - """Tests for NeMo LoRA with tensor parallelism.""" - - @pytest.mark.parametrize("tp_size,description", [ - (2, "tp_2"), - (4, "tp_4"), - ]) - def test_llama_nemo_lora_tensor_parallel(self, tp_size, description): - """Test NeMo LoRA loading with different tensor parallelism sizes.""" - import torch - if torch.cuda.device_count() < tp_size: - pytest.skip(f"Test requires at least {tp_size} GPUs") - - model_dir = f"{llm_models_root()}/llama-models-v3/llama-3.2-1b-hf" - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create a mock NeMo LoRA checkpoint with specified TP size - nemo_lora_path = create_mock_nemo_lora_checkpoint( - temp_path, - hidden_size=2048, - num_layers=16, - lora_rank=8, - tp_size=tp_size, - ) - - # Configure LoRA - lora_config = LoraConfig( - lora_dir=[str(nemo_lora_path)], - lora_ckpt_source="nemo", - lora_target_modules=["attn_qkv"], - max_lora_rank=8, - ) - - # Create LLM instance with tensor parallelism - llm = LLM( - model=model_dir, - lora_config=lora_config, - backend="pytorch", - tensor_parallel_size=tp_size, - ) - - try: - # Test inference - prompts = ["The weather today is"] - sampling_params = SamplingParams(max_tokens=20) - - outputs = llm.generate(prompts, sampling_params) - - assert len(outputs) == 1, f"Expected 1 output for {description}" - assert len(outputs[0].outputs[0].text - ) > 0, f"Expected non-empty text for {description}" - - print( - f"[{description}] Generated text: {outputs[0].outputs[0].text}" - ) - finally: - # 
Ensure proper cleanup - del llm - import gc - gc.collect() diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 41aee6197a2..3d115bc05b8 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -20,7 +20,6 @@ l0_h100: - unittest/_torch/modeling -k "modeling_mixtral" - unittest/_torch/modeling -k "modeling_nemotron" - unittest/_torch/modeling -k "modeling_gemma3" - - unittest/llmapi/test_pytorch_nemo_lora.py - unittest/disaggregated/test_disagg_utils.py - unittest/disaggregated/test_router.py - unittest/disaggregated/test_remoteDictionary.py diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 486ceb301f5..1c41ca383ed 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -29,8 +29,103 @@ from peft import get_peft_model from transformers import AutoModelForCausalLM +import json +import tarfile +from pathlib import Path + # isort: on +# NeMo LoRA test data +LORA_RANK_CONFIGS = [ + # (lora_rank, max_lora_rank, description) + (8, 8, "rank_8"), + (16, 16, "rank_16"), + (4, 8, "rank_4_max_8"), +] + + +def create_mock_nemo_lora_checkpoint( + lora_dir: Path, + hidden_size: int = 4096, + num_layers: int = 32, + lora_rank: int = 8, + tp_size: int = 1, +) -> Path: + """Create a minimal NeMo LoRA checkpoint for testing. + + This creates a .nemo tarfile with the expected structure: + - model_weights.ckpt containing attn_qkv adapter weights + - model_config.yaml with basic configuration + + Args: + lora_dir: Directory to create the checkpoint in + hidden_size: Model hidden size + num_layers: Number of transformer layers + lora_rank: LoRA rank + tp_size: Tensor parallelism size + + Returns: + Path to the created .nemo file + """ + # Create temporary directory for checkpoint contents + temp_dir = lora_dir / "temp_nemo" + temp_dir.mkdir(exist_ok=True) + + # Create LoRA weights dict + weights_dict = {} + + for layer_idx in range(num_layers): + # NeMo uses this key format for QKV adapters + key_prefix = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter" + + # Create linear_in weights [lora_rank, hidden_size] with small random values + linear_in_key = f"{key_prefix}.linear_in.weight" + weights_dict[linear_in_key] = torch.randn( + lora_rank, hidden_size, dtype=torch.float16) * 0.01 + + # Create linear_out weights [3 * hidden_size, lora_rank] for QKV combined + linear_out_key = f"{key_prefix}.linear_out.weight" + weights_dict[linear_out_key] = torch.randn( + 3 * hidden_size, lora_rank, dtype=torch.float16) * 0.01 + + # Save checkpoint + ckpt_path = temp_dir / "model_weights.ckpt" + torch.save(weights_dict, ckpt_path) + + # Create minimal config + config = { + "precision": "fp16", + "trainer": { + "num_nodes": 1, + "devices": tp_size, + }, + "model": { + "hidden_size": hidden_size, + "num_layers": num_layers, + }, + "lora": { + "rank": lora_rank, + "target_modules": ["attn_qkv"], + } + } + + config_path = temp_dir / "model_config.yaml" + # Using JSON for simplicity since YAML parsing isn't critical for the test + with open(config_path, 'w') as f: + json.dump(config, f) + + # Create .nemo tarfile + nemo_path = lora_dir / "test_lora.nemo" + with tarfile.open(nemo_path, 'w') as tar: + tar.add(ckpt_path, arcname="model_weights.ckpt") + tar.add(config_path, arcname="model_config.yaml") + + # Cleanup temp dir + import shutil + shutil.rmtree(temp_dir) + + return 
nemo_path + @force_ampere def test_tinyllama_logits_processor(): @@ -427,3 +522,163 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None: lora_request=lora_requests) assert len(outputs) == 2 + + +# NeMo LoRA tests +@pytest.mark.parametrize("lora_rank,max_lora_rank,description", + LORA_RANK_CONFIGS) +def test_load_torch_nemo_lora_function(lora_rank, max_lora_rank, description): + """Test load_torch_nemo_lora function with different LoRA rank configurations.""" + from tensorrt_llm.lora_manager import load_torch_nemo_lora + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create a mock NeMo checkpoint + nemo_path = create_mock_nemo_lora_checkpoint( + temp_path, + hidden_size=2048, + num_layers=16, + lora_rank=lora_rank, + ) + + # Test load_torch_nemo_lora + lora_config = LoraConfig( + lora_dir=[str(nemo_path)], + lora_ckpt_source="nemo", + max_lora_rank=max_lora_rank, + ) + + # This should not raise an error + load_torch_nemo_lora(lora_config) + + # Verify configuration was set correctly + assert lora_config.lora_target_modules == [ + "attn_qkv" + ], f"Expected attn_qkv modules for {description}" + assert lora_config.trtllm_modules_to_hf_modules == { + "attn_qkv": "attn_qkv" + }, f"Expected correct module mapping for {description}" + + +def test_nemo_lora_unsupported_modules_validation(): + """Test validation of unsupported modules in NeMo LoRA.""" + from tensorrt_llm.lora_manager import load_torch_nemo_lora + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create a mock NeMo checkpoint + nemo_path = create_mock_nemo_lora_checkpoint( + temp_path, + hidden_size=2048, + num_layers=16, + lora_rank=8, + ) + + # Test validation: should fail with unsupported modules + invalid_config = LoraConfig( + lora_dir=[str(nemo_path)], + lora_ckpt_source="nemo", + lora_target_modules=["attn_qkv", + "mlp_h_to_4h"], # mlp_h_to_4h not supported + max_lora_rank=8, + ) + + with pytest.raises(ValueError, match="NeMo LoRA only supports"): + load_torch_nemo_lora(invalid_config) + + +@force_ampere +def test_tinyllama_nemo_lora(): + """Test end-to-end generation with NeMo LoRA checkpoint.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + # Create a mock NeMo checkpoint for TinyLlama + # TinyLlama has hidden_size=2048, num_layers=22 + nemo_path = create_mock_nemo_lora_checkpoint( + temp_path, + hidden_size=2048, + num_layers=22, + lora_rank=8, + ) + + # Create LoRA config for NeMo checkpoint + lora_config = LoraConfig( + lora_dir=[str(nemo_path)], + lora_ckpt_source="nemo", + max_lora_rank=8, + ) + + # Verify LoRA config is set up correctly + assert lora_config.lora_ckpt_source == "nemo" + assert len(lora_config.lora_dir) == 1 + print(f"✓ Created NeMo LoRA config: {nemo_path}") + + # Use TinyLlama for fast testing + model_path = get_model_path("llama-models-v2/TinyLlama-1.1B-Chat-v1.0") + + # Create LLM with NeMo LoRA + llm = LLM( + model=model_path, + lora_config=lora_config, + kv_cache_config=global_kvcache_config, + ) + + try: + # Test prompts + test_prompts = [ + "Hello, how are you?", + "What is the capital of France?", + ] + + # Create LoRA request for the NeMo checkpoint + lora_req = LoRARequest("nemo-task", + 0, + str(nemo_path), + lora_ckpt_source="nemo") + + # Verify LoRA request is configured correctly + assert lora_req.ckpt_source == "nemo" + assert lora_req.path == str(nemo_path) + + # Test with and without LoRA + sampling_params = SamplingParams(max_tokens=20, temperature=0.0) + + # Generate 
with LoRA + outputs_with_lora = llm.generate(test_prompts, + sampling_params, + lora_request=[lora_req, lora_req]) + + # Generate without LoRA + outputs_without_lora = llm.generate(test_prompts, + sampling_params, + lora_request=[None, None]) + + # Basic validation - outputs should be generated without errors + assert len(outputs_with_lora) == 2 + assert len(outputs_without_lora) == 2 + + # Verify that generation completed successfully (may have minimal output with mock weights) + for i in range(2): + # Check that we got valid completion outputs + assert outputs_with_lora[i].outputs[0] is not None + assert outputs_without_lora[i].outputs[0] is not None + # Check that token_ids are present (even if just EOS token) + assert len(outputs_with_lora[i].outputs[0].token_ids) > 0 + assert len(outputs_without_lora[i].outputs[0].token_ids) > 0 + + print(f"✓ NeMo LoRA generation completed successfully") + print( + f"✓ LoRA output tokens: {[len(out.outputs[0].token_ids) for out in outputs_with_lora]}" + ) + print( + f"✓ Base output tokens: {[len(out.outputs[0].token_ids) for out in outputs_without_lora]}" + ) + + # Test passes if generation completes without errors + # Note: With mock LoRA weights, outputs may be minimal but that's expected + + finally: + llm.shutdown() diff --git a/tests/unittest/llmapi/test_pytorch_nemo_lora.py b/tests/unittest/llmapi/test_pytorch_nemo_lora.py deleted file mode 100644 index 3ff1466faa9..00000000000 --- a/tests/unittest/llmapi/test_pytorch_nemo_lora.py +++ /dev/null @@ -1,238 +0,0 @@ -"""Tests for NeMo LoRA checkpoint loading in PyTorch workflow. - -This file contains fast unit tests that do not require full model initialization. -For integration tests that require full model loading and GPU inference, see: - tests/integration/defs/llmapi/test_llm_pytorch_nemo_lora.py - -Unit tests here should run in seconds, not minutes. -""" - -import json -import tarfile -import tempfile -from pathlib import Path - -import pytest -import torch - -from tensorrt_llm.lora_manager import LoraConfig - - -def create_mock_nemo_lora_checkpoint( - lora_dir: Path, - hidden_size: int = 4096, - num_layers: int = 32, - lora_rank: int = 8, - tp_size: int = 1, -) -> Path: - """Create a minimal NeMo LoRA checkpoint for testing. 
- - This creates a .nemo tarfile with the expected structure: - - model_weights.ckpt containing attn_qkv adapter weights - - model_config.yaml with basic configuration - - Args: - lora_dir: Directory to create the checkpoint in - hidden_size: Model hidden size - num_layers: Number of transformer layers - lora_rank: LoRA rank - tp_size: Tensor parallelism size - - Returns: - Path to the created .nemo file - """ - # Create temporary directory for checkpoint contents - temp_dir = lora_dir / "temp_nemo" - temp_dir.mkdir(exist_ok=True) - - # Create LoRA weights dict - weights_dict = {} - - for layer_idx in range(num_layers): - # NeMo uses this key format for QKV adapters - key_prefix = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter" - - # Create linear_in weights [lora_rank, hidden_size] - linear_in_key = f"{key_prefix}.linear_in.weight" - weights_dict[linear_in_key] = torch.zeros(lora_rank, - hidden_size, - dtype=torch.float16) - - # Create linear_out weights [3 * hidden_size, lora_rank] for QKV combined - linear_out_key = f"{key_prefix}.linear_out.weight" - weights_dict[linear_out_key] = torch.zeros(3 * hidden_size, - lora_rank, - dtype=torch.float16) - - # Save checkpoint - ckpt_path = temp_dir / "model_weights.ckpt" - torch.save(weights_dict, ckpt_path) - - # Create minimal config - config = { - "precision": "fp16", - "trainer": { - "num_nodes": 1, - "devices": tp_size, - }, - "model": { - "hidden_size": hidden_size, - "num_layers": num_layers, - }, - "lora": { - "rank": lora_rank, - "target_modules": ["attn_qkv"], - } - } - - config_path = temp_dir / "model_config.yaml" - # Using JSON for simplicity since YAML parsing isn't critical for the test - with open(config_path, 'w') as f: - json.dump(config, f) - - # Create .nemo tarfile - nemo_path = lora_dir / "test_lora.nemo" - with tarfile.open(nemo_path, 'w') as tar: - tar.add(ckpt_path, arcname="model_weights.ckpt") - tar.add(config_path, arcname="model_config.yaml") - - # Cleanup temp dir - import shutil - shutil.rmtree(temp_dir) - - return nemo_path - - -# Test data for parametrized tests -NEMO_LORA_UNIT_TEST_PARAMS = [ - # (hidden_size, num_layers, lora_rank, description) - (2048, 16, 8, "small_model_rank_8"), - (4096, 32, 16, "large_model_rank_16"), - (1024, 12, 4, "tiny_model_rank_4"), -] - -LORA_RANK_CONFIGS = [ - # (lora_rank, max_lora_rank, description) - (8, 8, "rank_8"), - (16, 16, "rank_16"), - (4, 8, "rank_4_max_8"), -] - - -class TestNemoLoraUnit: - """Unit tests for NeMo LoRA loading without full model initialization.""" - - @pytest.mark.parametrize("hidden_size,num_layers,lora_rank,description", - NEMO_LORA_UNIT_TEST_PARAMS) - def test_nemo_lora_loader_creation(self, hidden_size, num_layers, lora_rank, - description): - """Test NemoLoraLoader creation with different model configurations.""" - from tensorrt_llm.lora_manager import NemoLoraLoader - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create a mock NeMo checkpoint - nemo_path = create_mock_nemo_lora_checkpoint( - temp_path, - hidden_size=hidden_size, - num_layers=num_layers, - lora_rank=lora_rank, - ) - - # Test NemoLoraLoader directly - loader = NemoLoraLoader([str(nemo_path)]) - assert loader.is_valid, f"NemoLoraLoader failed to validate {nemo_path} for {description}" - assert loader.lora_target_modules == [ - "attn_qkv" - ], f"Expected attn_qkv modules for {description}" - - @pytest.mark.parametrize("lora_rank,max_lora_rank,description", - LORA_RANK_CONFIGS) - def 
test_load_torch_nemo_lora_function(self, lora_rank, max_lora_rank, - description): - """Test load_torch_nemo_lora function with different LoRA rank configurations.""" - from tensorrt_llm.lora_manager import load_torch_nemo_lora - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create a mock NeMo checkpoint - nemo_path = create_mock_nemo_lora_checkpoint( - temp_path, - hidden_size=2048, - num_layers=16, - lora_rank=lora_rank, - ) - - # Test load_torch_nemo_lora - lora_config = LoraConfig( - lora_dir=[str(nemo_path)], - lora_ckpt_source="nemo", - max_lora_rank=max_lora_rank, - ) - - # This should not raise an error - load_torch_nemo_lora(lora_config) - - # Verify configuration was set correctly - assert lora_config.lora_target_modules == [ - "attn_qkv" - ], f"Expected attn_qkv modules for {description}" - assert lora_config.trtllm_modules_to_hf_modules == { - "attn_qkv": "attn_qkv" - }, f"Expected correct module mapping for {description}" - - def test_nemo_lora_unsupported_modules_validation(self): - """Test validation of unsupported modules in NeMo LoRA.""" - from tensorrt_llm.lora_manager import load_torch_nemo_lora - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create a mock NeMo checkpoint - nemo_path = create_mock_nemo_lora_checkpoint( - temp_path, - hidden_size=2048, - num_layers=16, - lora_rank=8, - ) - - # Test validation: should fail with unsupported modules - invalid_config = LoraConfig( - lora_dir=[str(nemo_path)], - lora_ckpt_source="nemo", - lora_target_modules=["attn_qkv", "mlp_h_to_4h" - ], # mlp_h_to_4h not supported - max_lora_rank=8, - ) - - with pytest.raises(ValueError, match="NeMo LoRA only supports"): - load_torch_nemo_lora(invalid_config) - - def test_nemo_lora_empty_target_modules(self): - """Test NeMo LoRA with empty target modules list.""" - from tensorrt_llm.lora_manager import load_torch_nemo_lora - - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create a mock NeMo checkpoint - nemo_path = create_mock_nemo_lora_checkpoint( - temp_path, - hidden_size=2048, - num_layers=16, - lora_rank=8, - ) - - # Test with empty target modules - should auto-detect - lora_config = LoraConfig( - lora_dir=[str(nemo_path)], - lora_ckpt_source="nemo", - max_lora_rank=8, - ) - - load_torch_nemo_lora(lora_config) - - # Should auto-detect and set attn_qkv - assert lora_config.lora_target_modules == ["attn_qkv"] From d08c6ccc2c7fe3c5070ac9cd086e355ddb5703fc Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Tue, 15 Jul 2025 10:35:13 -0700 Subject: [PATCH 06/16] add review suggestions Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/_util.py | 8 +- tensorrt_llm/lora_manager.py | 191 +++++++++++++++++++----- 2 files changed, 154 insertions(+), 45 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index fc19b8f2520..63e737c4920 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -14,7 +14,7 @@ from tensorrt_llm.logger import logger from tensorrt_llm.lora_manager import (LoraConfig, get_default_trtllm_modules_to_hf_modules, - load_torch_hf_lora) + load_torch_lora) from tensorrt_llm.mapping import Mapping from ..model_config import ModelConfig @@ -438,11 +438,7 @@ def create_py_executor_instance( if len(lora_config.lora_dir) == 1: # Route to appropriate loader based on 
checkpoint source - if lora_config.lora_ckpt_source == "nemo": - from tensorrt_llm.lora_manager import load_torch_nemo_lora - load_torch_nemo_lora(lora_config) - else: - load_torch_hf_lora(lora_config) + load_torch_lora(lora_config) else: assert len(lora_config.lora_target_modules ) >= 1, "Expecting at least one lora target module" diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 5c0065d57c2..5d46ebe4534 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -23,8 +23,21 @@ from .runtime import ModelConfig -def get_all_nemo_lora_weights(lora_weights): - layer_weights = defaultdict(dict) +def get_all_nemo_lora_weights( + lora_weights: Dict[str, torch.Tensor], +) -> Dict[int, Dict[str, torch.Tensor]]: + """Extract and organize NeMo LoRA weights by layer and direction. + + Args: + lora_weights: Dictionary mapping weight keys to tensors from NeMo checkpoint + + Returns: + Dictionary mapping layer_idx -> {direction -> tensor} where direction is 'in' or 'out' + + Raises: + KeyError: If unsupported keys are found or layer extraction fails + """ + layer_weights: Dict[int, Dict[str, torch.Tensor]] = defaultdict(dict) adapter_key = "self_attention.adapter_layer.lora_kqv_adapter" layer_pattern = re.compile(r".*\.layers\.(\d+)\..*") for key, weights in lora_weights.items(): @@ -53,7 +66,25 @@ def get_all_nemo_lora_weights(lora_weights): ) -def iterate_hf_lora(iter_fn, lora_weights, hf_modules, component=None): +def iterate_hf_lora( + iter_fn, lora_weights: Dict[str, torch.Tensor], hf_modules: set, component: Optional[str] = None +): + """Iterate over HuggingFace LoRA weights and call iterator function for each weight. + + Args: + iter_fn: Function to call for each weight with signature + (layer_idx, hf_module, expert_idx, inout_or_mag, weights) + lora_weights: Dictionary mapping weight keys to tensors from HF checkpoint + hf_modules: Set of supported HF module names + component: Optional component name to filter by (e.g., 'decoder') + + Returns: + Nested dictionary structure organizing the weights + + Raises: + KeyError: If unsupported keys are found + AssertionError: If HF module is not in supported list + """ all_weights = defaultdict(lambda: defaultdict(dict)) pattern = HF_LORA_PATTERN for key, weights in lora_weights.items(): @@ -97,7 +128,20 @@ def iterate_hf_lora(iter_fn, lora_weights, hf_modules, component=None): return all_weights -def get_all_hf_lora_weights(lora_weights, hf_modules, component=None): +def get_all_hf_lora_weights( + lora_weights: Dict[str, torch.Tensor], hf_modules: set, component: Optional[str] = None +): + """Extract and organize all HuggingFace LoRA weights by layer and module. 
+ + Args: + lora_weights: Dictionary mapping weight keys to tensors from HF checkpoint + hf_modules: Set of supported HF module names + component: Optional component name to filter by (e.g., 'decoder') + + Returns: + Nested dictionary organizing weights by layer, module, and potentially expert + """ + def iter_fn(layer_idx, hf_module, expert_idx, inout, weights): if expert_idx is None: all_weights[layer_idx][hf_module][inout] = weights @@ -119,8 +163,19 @@ def iter_fn(layer_idx, hf_module, expert_idx, inout, weights): return hf_target_modules -def invert_module_mapping(trtllm_modules_to_hf_modules): - hf_modules_to_trtllm_modules = {} +def invert_module_mapping( + trtllm_modules_to_hf_modules: Dict[str, Union[str, List[str]]], +) -> Dict[str, str]: + """Invert module mapping from TensorRT-LLM -> HF to HF -> TensorRT-LLM. + + Args: + trtllm_modules_to_hf_modules: Mapping from TensorRT-LLM module names to HF module names + (values can be strings or lists of strings) + + Returns: + Dictionary mapping HF module names to TensorRT-LLM module names + """ + hf_modules_to_trtllm_modules: Dict[str, str] = {} for k, hf_modules in trtllm_modules_to_hf_modules.items(): if isinstance(hf_modules, list): for hf_module in hf_modules: @@ -219,39 +274,48 @@ def get_target_modules(self, trtllm_modules_to_hf_modules): return list(lora_target_modules) -@lru_cache(maxsize=32) -def _find_nemo_files_cached(lora_dirs_tuple): - # Helper for caching: lora_dirs must be a tuple of strings - nemo_files = [] +@lru_cache(maxsize=128) +def _find_nemo_files_single_path(lora_path: str) -> List[str]: + """Find .nemo files from a single path (file or directory). - for lora_path in lora_dirs_tuple: - path = Path(lora_path) - if not path.exists(): - raise ValueError(f"{path} does not exist") + This function is cached per individual path to maximize cache efficiency + when the same paths appear in different collections. - if path.is_file(): - if path.suffix == ".nemo": - nemo_files.append(str(path)) - else: - raise ValueError(f"{path} is not a .nemo file") - elif path.is_dir(): - nemo_files_in_dir = list(path.glob("*.nemo")) - if not nemo_files_in_dir: - raise ValueError(f"No .nemo files found in directory {path}") - nemo_files.extend([str(f) for f in nemo_files_in_dir]) - else: - raise ValueError(f"{path} is neither a file nor a directory") + Args: + lora_path: A single path that can be either: + - Direct path to a .nemo file + - Directory containing .nemo files (will auto-detect *.nemo) - if not nemo_files: - raise ValueError("No .nemo files found in the provided paths") + Returns: + List[str]: List of paths to .nemo files found in this single path + + Raises: + ValueError: If path doesn't exist, no .nemo files found, or invalid file type + """ + path = Path(lora_path) + if not path.exists(): + raise ValueError(f"{path} does not exist") - return nemo_files + if path.is_file(): + if path.suffix == ".nemo": + return [str(path)] + else: + raise ValueError(f"{path} is not a .nemo file") + elif path.is_dir(): + nemo_files_in_dir = list(path.glob("*.nemo")) + if not nemo_files_in_dir: + raise ValueError(f"No .nemo files found in directory {path}") + return [str(f) for f in nemo_files_in_dir] + else: + raise ValueError(f"{path} is neither a file nor a directory") def find_nemo_files(lora_dirs: List[str]) -> List[str]: """Find all .nemo files from a list of directories or file paths. - This function is optimized for repeated calls by using an internal LRU cache. 
+ This function is optimized for repeated calls at generation time by using an internal LRU cache + on individual paths, which maximizes cache efficiency when the same paths + appear in different collections. Args: lora_dirs: List of paths that can be either: @@ -262,11 +326,21 @@ def find_nemo_files(lora_dirs: List[str]) -> List[str]: List[str]: List of paths to .nemo files Raises: - ValueError: If path doesn't exist, no .nemo files found, or invalid file type + ValueError: If a path doesn't exist, no .nemo files are found in a directory + path, or a file path is of invalid file type """ if len(lora_dirs) == 0: return [] - return _find_nemo_files_cached(tuple(lora_dirs)) + + all_nemo_files: List[str] = [] + for lora_path in lora_dirs: + nemo_files_for_path = _find_nemo_files_single_path(lora_path) + all_nemo_files.extend(nemo_files_for_path) + + if not all_nemo_files: + raise ValueError("No .nemo files found in the provided paths") + + return all_nemo_files class NemoLoraLoader: @@ -296,8 +370,15 @@ def __init__(self, lora_dirs: List[str]): # Hardcoded since LoraManager only supports this case now self.lora_target_modules = ["attn_qkv"] - def get_target_modules(self, trtllm_modules_to_hf_modules): - """Get target modules for NeMo LoRA.""" + def get_target_modules(self): + """Get target modules for NeMo LoRA. + + Unlike the HF loader, this method does not accept trtllm_modules_to_hf_modules + as an argument since the module mapping is hardcoded for NeMo LoRA support. + + Returns: + List[str]: List of target module names supported by NeMo LoRA + """ return self.lora_target_modules @@ -375,9 +456,7 @@ def load_torch_nemo_lora(lora_config: LoraConfig): raise ValueError(f"Failed to load NeMo LoRA from {lora_config.lora_dir}") if len(lora_config.lora_target_modules) == 0: - lora_config.lora_target_modules = lora_loader.get_target_modules( - lora_config.trtllm_modules_to_hf_modules - ) + lora_config.lora_target_modules = lora_loader.get_target_modules() if len(lora_config.lora_target_modules) == 0: raise ValueError( @@ -403,6 +482,29 @@ def load_torch_nemo_lora(lora_config: LoraConfig): # just sets up the configuration. +def load_torch_lora(lora_config: LoraConfig): + """Load LoRA checkpoint for PyTorch workflow. + + This function routes to the appropriate loader based on lora_ckpt_source. + It centralizes the routing logic that was previously scattered in _util.py. + + Args: + lora_config: LoRA configuration with lora_ckpt_source set to "hf" or "nemo" + + Raises: + ValueError: If lora_ckpt_source is not supported + """ + if lora_config.lora_ckpt_source == "nemo": + load_torch_nemo_lora(lora_config) + elif lora_config.lora_ckpt_source == "hf": + load_torch_hf_lora(lora_config) + else: + raise ValueError( + f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}. " + f"Supported sources: 'hf', 'nemo'" + ) + + def load_hf_lora( model, lora_config: LoraConfig, @@ -504,7 +606,18 @@ def use_lora( raise ValueError(f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}") -def unpack_nemo_weights(nemo_archive_path): +def unpack_nemo_weights(nemo_archive_path: str) -> Tuple[Dict, Dict[str, torch.Tensor]]: + """Unpack model config and weights from a NeMo .nemo archive file. 
+ + Args: + nemo_archive_path: Path to the .nemo archive file + + Returns: + Tuple of (model_config_dict, model_weights_dict) + + Raises: + Exception: If required files cannot be extracted from the archive + """ with tarfile.open(nemo_archive_path) as tar: try: model_weights_file = tar.extractfile("model_weights.ckpt") From cf176b63dcc1b078d7415f61ef28e0d43243be88 Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Tue, 15 Jul 2025 11:36:21 -0700 Subject: [PATCH 07/16] add review suggestions Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tensorrt_llm/lora_manager.py | 21 +++-- tests/unittest/llmapi/test_llm_pytorch.py | 102 +++++++++++----------- 2 files changed, 59 insertions(+), 64 deletions(-) diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 5d46ebe4534..9bad7c1386c 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -440,13 +440,20 @@ def load_torch_nemo_lora(lora_config: LoraConfig): """Load NeMo LoRA checkpoint for PyTorch workflow. This is a PyTorch-specific loader for NeMo LoRA checkpoints, similar to - load_torch_hf_lora but handling NeMo checkpoint format. + load_torch_hf_lora but handling NeMo checkpoint format. NeMo uses a combined + "attn_qkv" module rather than separate Q, K, V modules, so no missing QKV + module handling is needed. + + Note: This function only sets up the configuration. For PyTorch workflow, + the actual weight loading happens later via LoraManager when requests are + made with LoRA UIDs. Args: lora_config: LoRA configuration with lora_ckpt_source="nemo" + + Raises: + ValueError: If NeMo LoRA directory is invalid or unsupported modules are specified """ - # For NeMo, we need to set up module mappings differently - # NeMo uses "attn_qkv" as a combined module lora_config.trtllm_modules_to_hf_modules = {"attn_qkv": "attn_qkv"} assert len(lora_config.lora_dir) == 1, "Expecting only a single lora dir" @@ -464,7 +471,6 @@ def load_torch_nemo_lora(lora_config: LoraConfig): "Please specify lora_target_modules or provide lora_dir to infer lora_target_modules." ) - # Validate that NeMo LoRA only supports attn_qkv supported_modules = {"attn_qkv"} unsupported_modules = set(lora_config.lora_target_modules) - supported_modules if unsupported_modules: @@ -474,13 +480,6 @@ def load_torch_nemo_lora(lora_config: LoraConfig): f"NeMo LoRA does not support embedding, lm_head, or MLP adapters." ) - # NeMo only supports attn_qkv currently, no need for missing QKV module handling - # as it's already combined - - # Note: For PyTorch workflow, the actual weight loading happens later - # via LoraManager when requests are made with LoRA UIDs. This function - # just sets up the configuration. - def load_torch_lora(lora_config: LoraConfig): """Load LoRA checkpoint for PyTorch workflow. 
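For reference, a minimal usage sketch of the load_torch_lora routing introduced above; the checkpoint paths and the rank value below are illustrative placeholders, not values taken from this change.

from tensorrt_llm.lora_manager import LoraConfig, load_torch_lora

# NeMo-format adapter: dispatched to load_torch_nemo_lora, which pins the
# module mapping to the combined "attn_qkv" adapter.
nemo_cfg = LoraConfig(
    lora_dir=["/path/to/adapter.nemo"],  # placeholder path
    lora_ckpt_source="nemo",
    max_lora_rank=8,
)
load_torch_lora(nemo_cfg)

# HF-format adapter: dispatched to load_torch_hf_lora.
hf_cfg = LoraConfig(
    lora_dir=["/path/to/hf_lora_dir"],  # placeholder path
    lora_ckpt_source="hf",
)
load_torch_lora(hf_cfg)

# Any other lora_ckpt_source value raises ValueError.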
diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 1c41ca383ed..5a58a376b9e 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -67,62 +67,58 @@ def create_mock_nemo_lora_checkpoint( Returns: Path to the created .nemo file """ - # Create temporary directory for checkpoint contents - temp_dir = lora_dir / "temp_nemo" - temp_dir.mkdir(exist_ok=True) - - # Create LoRA weights dict - weights_dict = {} - - for layer_idx in range(num_layers): - # NeMo uses this key format for QKV adapters - key_prefix = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter" - - # Create linear_in weights [lora_rank, hidden_size] with small random values - linear_in_key = f"{key_prefix}.linear_in.weight" - weights_dict[linear_in_key] = torch.randn( - lora_rank, hidden_size, dtype=torch.float16) * 0.01 - - # Create linear_out weights [3 * hidden_size, lora_rank] for QKV combined - linear_out_key = f"{key_prefix}.linear_out.weight" - weights_dict[linear_out_key] = torch.randn( - 3 * hidden_size, lora_rank, dtype=torch.float16) * 0.01 - - # Save checkpoint - ckpt_path = temp_dir / "model_weights.ckpt" - torch.save(weights_dict, ckpt_path) - - # Create minimal config - config = { - "precision": "fp16", - "trainer": { - "num_nodes": 1, - "devices": tp_size, - }, - "model": { - "hidden_size": hidden_size, - "num_layers": num_layers, - }, - "lora": { - "rank": lora_rank, - "target_modules": ["attn_qkv"], - } - } + nemo_path = lora_dir / "test_lora.nemo" - config_path = temp_dir / "model_config.yaml" - # Using JSON for simplicity since YAML parsing isn't critical for the test - with open(config_path, 'w') as f: - json.dump(config, f) + # Use temporary directory context manager for safe cleanup + with tempfile.TemporaryDirectory() as temp_dir_str: + temp_dir = Path(temp_dir_str) - # Create .nemo tarfile - nemo_path = lora_dir / "test_lora.nemo" - with tarfile.open(nemo_path, 'w') as tar: - tar.add(ckpt_path, arcname="model_weights.ckpt") - tar.add(config_path, arcname="model_config.yaml") + # Create LoRA weights dict + weights_dict = {} + + for layer_idx in range(num_layers): + # NeMo uses this key format for QKV adapters + key_prefix = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter" + + # Create linear_in weights [lora_rank, hidden_size] with small random values + linear_in_key = f"{key_prefix}.linear_in.weight" + weights_dict[linear_in_key] = torch.randn( + lora_rank, hidden_size, dtype=torch.float16) * 0.01 + + # Create linear_out weights [3 * hidden_size, lora_rank] for QKV combined + linear_out_key = f"{key_prefix}.linear_out.weight" + weights_dict[linear_out_key] = torch.randn( + 3 * hidden_size, lora_rank, dtype=torch.float16) * 0.01 + + ckpt_path = temp_dir / "model_weights.ckpt" + torch.save(weights_dict, ckpt_path) + + # Create minimal config + config = { + "precision": "fp16", + "trainer": { + "num_nodes": 1, + "devices": tp_size, + }, + "model": { + "hidden_size": hidden_size, + "num_layers": num_layers, + }, + "lora": { + "rank": lora_rank, + "target_modules": ["attn_qkv"], + } + } + + config_path = temp_dir / "model_config.yaml" + # Using JSON for simplicity since YAML parsing isn't critical for the test + with open(config_path, 'w') as f: + json.dump(config, f) - # Cleanup temp dir - import shutil - shutil.rmtree(temp_dir) + # Create .nemo tarfile + with tarfile.open(nemo_path, 'w') as tar: + tar.add(ckpt_path, arcname="model_weights.ckpt") + 
tar.add(config_path, arcname="model_config.yaml") return nemo_path From 7f71d4d74f974278380bdf117c610cd2b3297467 Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Wed, 16 Jul 2025 12:10:11 -0700 Subject: [PATCH 08/16] enable gqa Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tensorrt_llm/_torch/model_config.py | 25 +++- tensorrt_llm/_torch/pyexecutor/_util.py | 2 +- tests/unittest/llmapi/test_llm_pytorch.py | 157 ++++++++++++---------- 3 files changed, 105 insertions(+), 79 deletions(-) diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index 3de3edd3a9b..8d63520c584 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -297,6 +297,27 @@ def get_bindings_model_config(self, num_heads = self.pretrained_config.num_attention_heads // ( self.mapping.tp_size * self.mapping.cp_size) + + # Handle both uniform and per-layer KV heads + if hasattr(self.pretrained_config, 'num_kv_heads_per_layer'): + # For models with per-layer KV heads, use the first layer's value + num_kv_heads_raw = self.pretrained_config.num_kv_heads_per_layer[0] + # TRT-LLM LoRA requires uniform KV heads across layers + if self.lora_config is not None and not all( + kv == num_kv_heads_raw + for kv in self.pretrained_config.num_kv_heads_per_layer): + raise ValueError( + f"TRT-LLM LoRA requires uniform KV heads across layers, got: {self.pretrained_config.num_kv_heads_per_layer}" + ) + else: + # For uniform models, check: num_key_value_heads (standard) -> num_query_groups (NeMo) -> num_attention_heads + num_kv_heads_raw = getattr( + self.pretrained_config, 'num_key_value_heads', + getattr(self.pretrained_config, 'num_query_groups', + self.pretrained_config.num_attention_heads)) + + num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size * + self.mapping.cp_size) hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size model_config_cpp = ModelConfigCpp( @@ -317,10 +338,6 @@ def get_bindings_model_config(self, else: model_config_cpp.tokens_per_block = tokens_per_block - # For kv cache size calculation: set num_kv_heads - num_kv_heads = getattr( - self.pretrained_config, "num_key_value_heads", - num_heads) // (self.mapping.tp_size * self.mapping.cp_size) model_config_cpp.set_num_kv_heads(num_kv_heads) mlp_hidden_size = None diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 63e737c4920..54d2cdcf979 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -456,7 +456,7 @@ def create_py_executor_instance( hidden_size=model_binding_config.hidden_size, mlp_hidden_size=model_binding_config.mlp_hidden_size, num_attention_heads=model_binding_config.num_heads, - num_kv_attention_heads=model_binding_config.num_heads, + num_kv_attention_heads=model_binding_config.num_kv_heads(0), attention_head_size=model_binding_config.head_size, tp_size=mapping.tp_size, num_experts=num_experts) diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 5a58a376b9e..3175ad5e612 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -45,11 +45,13 @@ def create_mock_nemo_lora_checkpoint( - lora_dir: Path, - hidden_size: int = 4096, - num_layers: int = 32, - lora_rank: int = 8, - tp_size: int = 1, + lora_dir: Path, + hidden_size: int = 4096, + num_layers: int = 32, + lora_rank: int = 8, + tp_size: int = 1, 
+ num_attention_heads: int = 32, + num_kv_heads: int = None, # If None, defaults to num_attention_heads ) -> Path: """Create a minimal NeMo LoRA checkpoint for testing. @@ -63,10 +65,16 @@ def create_mock_nemo_lora_checkpoint( num_layers: Number of transformer layers lora_rank: LoRA rank tp_size: Tensor parallelism size + num_attention_heads: Number of query attention heads + num_kv_heads: Number of key/value heads (for GQA). If None, equals num_attention_heads Returns: Path to the created .nemo file """ + # Default to standard MHA if not specified + if num_kv_heads is None: + num_kv_heads = num_attention_heads + nemo_path = lora_dir / "test_lora.nemo" # Use temporary directory context manager for safe cleanup @@ -76,6 +84,14 @@ def create_mock_nemo_lora_checkpoint( # Create LoRA weights dict weights_dict = {} + # Calculate head dimensions + head_dim = hidden_size // num_attention_heads + kv_hidden_size = head_dim * num_kv_heads + + # Calculate QKV output dimensions for NeMo's fused format + # NeMo fuses QKV: Q(hidden_size) + K(kv_hidden_size) + V(kv_hidden_size) + qkv_output_dim = hidden_size + 2 * kv_hidden_size + for layer_idx in range(num_layers): # NeMo uses this key format for QKV adapters key_prefix = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter" @@ -85,15 +101,16 @@ def create_mock_nemo_lora_checkpoint( weights_dict[linear_in_key] = torch.randn( lora_rank, hidden_size, dtype=torch.float16) * 0.01 - # Create linear_out weights [3 * hidden_size, lora_rank] for QKV combined + # Create linear_out weights [qkv_output_dim, lora_rank] for fused QKV + # This is the key difference for GQA - the output dimension changes linear_out_key = f"{key_prefix}.linear_out.weight" weights_dict[linear_out_key] = torch.randn( - 3 * hidden_size, lora_rank, dtype=torch.float16) * 0.01 + qkv_output_dim, lora_rank, dtype=torch.float16) * 0.01 ckpt_path = temp_dir / "model_weights.ckpt" torch.save(weights_dict, ckpt_path) - # Create minimal config + # Create minimal config with GQA support config = { "precision": "fp16", "trainer": { @@ -103,6 +120,8 @@ def create_mock_nemo_lora_checkpoint( "model": { "hidden_size": hidden_size, "num_layers": num_layers, + "num_attention_heads": num_attention_heads, + "num_query_groups": num_kv_heads, # This is the key for GQA }, "lora": { "rank": lora_rank, @@ -586,95 +605,85 @@ def test_nemo_lora_unsupported_modules_validation(): @force_ampere -def test_tinyllama_nemo_lora(): - """Test end-to-end generation with NeMo LoRA checkpoint.""" +def test_gqa_nemo_lora(): + """Test NeMo LoRA with GQA using TinyLlama's exact dimensions. 
+ + TinyLlama-1.1B-Chat-v1.0 specs (verified from config.json): + - hidden_size: 2048 + - num_hidden_layers: 22 + - num_attention_heads: 32 (Query heads) + - num_key_value_heads: 4 (Key/Value heads) + - This gives 32/4 = 8 query heads per KV group (GQA) + """ with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) - # Create a mock NeMo checkpoint for TinyLlama - # TinyLlama has hidden_size=2048, num_layers=22 + # TinyLlama's exact GQA configuration + hidden_size = 2048 + num_layers = 22 + num_q_heads = 32 # Query attention heads + num_kv_heads = 4 # Key/Value heads (GQA) + lora_rank = 8 + + print( + f"\n✓ Testing TinyLlama GQA config: Q_heads={num_q_heads}, KV_heads={num_kv_heads}, rank={lora_rank}" + ) + + # Create a mock NeMo checkpoint with TinyLlama's exact GQA configuration nemo_path = create_mock_nemo_lora_checkpoint( temp_path, - hidden_size=2048, - num_layers=22, - lora_rank=8, + hidden_size=hidden_size, + num_layers=num_layers, + lora_rank=lora_rank, + num_attention_heads=num_q_heads, + num_kv_heads=num_kv_heads, ) # Create LoRA config for NeMo checkpoint lora_config = LoraConfig( lora_dir=[str(nemo_path)], lora_ckpt_source="nemo", - max_lora_rank=8, + max_lora_rank=lora_rank, ) - # Verify LoRA config is set up correctly - assert lora_config.lora_ckpt_source == "nemo" - assert len(lora_config.lora_dir) == 1 - print(f"✓ Created NeMo LoRA config: {nemo_path}") - - # Use TinyLlama for fast testing + # Use TinyLlama model - dimensions now match exactly model_path = get_model_path("llama-models-v2/TinyLlama-1.1B-Chat-v1.0") - # Create LLM with NeMo LoRA - llm = LLM( - model=model_path, - lora_config=lora_config, - kv_cache_config=global_kvcache_config, - ) - try: - # Test prompts - test_prompts = [ - "Hello, how are you?", - "What is the capital of France?", - ] + # Create LLM with NeMo LoRA + llm = LLM( + model=model_path, + lora_config=lora_config, + kv_cache_config=global_kvcache_config, + ) + + # Test prompt + test_prompts = ["Test TinyLlama GQA with NeMo LoRA"] # Create LoRA request for the NeMo checkpoint - lora_req = LoRARequest("nemo-task", + lora_req = LoRARequest("tinyllama-gqa-test", 0, str(nemo_path), lora_ckpt_source="nemo") - # Verify LoRA request is configured correctly - assert lora_req.ckpt_source == "nemo" - assert lora_req.path == str(nemo_path) - - # Test with and without LoRA - sampling_params = SamplingParams(max_tokens=20, temperature=0.0) - # Generate with LoRA - outputs_with_lora = llm.generate(test_prompts, - sampling_params, - lora_request=[lora_req, lora_req]) - - # Generate without LoRA - outputs_without_lora = llm.generate(test_prompts, - sampling_params, - lora_request=[None, None]) - - # Basic validation - outputs should be generated without errors - assert len(outputs_with_lora) == 2 - assert len(outputs_without_lora) == 2 - - # Verify that generation completed successfully (may have minimal output with mock weights) - for i in range(2): - # Check that we got valid completion outputs - assert outputs_with_lora[i].outputs[0] is not None - assert outputs_without_lora[i].outputs[0] is not None - # Check that token_ids are present (even if just EOS token) - assert len(outputs_with_lora[i].outputs[0].token_ids) > 0 - assert len(outputs_without_lora[i].outputs[0].token_ids) > 0 - - print(f"✓ NeMo LoRA generation completed successfully") - print( - f"✓ LoRA output tokens: {[len(out.outputs[0].token_ids) for out in outputs_with_lora]}" - ) - print( - f"✓ Base output tokens: {[len(out.outputs[0].token_ids) for out in outputs_without_lora]}" - 
) + sampling_params = SamplingParams(max_tokens=10, temperature=0.0) + outputs = llm.generate(test_prompts, + sampling_params, + lora_request=[lora_req]) - # Test passes if generation completes without errors - # Note: With mock LoRA weights, outputs may be minimal but that's expected + # Basic validation + assert len(outputs) == 1 + assert outputs[0].outputs[0] is not None + assert len(outputs[0].outputs[0].token_ids) > 0 + print(f" ✓ TinyLlama GQA with NeMo LoRA passed successfully!") + + except Exception as e: + # Any error now indicates a real problem since dimensions match + pytest.fail(f"TinyLlama GQA test failed: {e}") finally: - llm.shutdown() + if 'llm' in locals(): + llm.shutdown() + + print("✓ TinyLlama GQA NeMo LoRA test completed successfully") From 352ae268b4cc327453672f3a939b72b14eb01629 Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Wed, 16 Jul 2025 17:04:32 -0700 Subject: [PATCH 09/16] support gqa and initial plubminb for vgqa Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tensorrt_llm/_torch/model_config.py | 63 +++++++++++++++-------- tensorrt_llm/_torch/pyexecutor/_util.py | 15 +++++- tests/unittest/llmapi/test_llm_pytorch.py | 23 +-------- 3 files changed, 58 insertions(+), 43 deletions(-) diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index 8d63520c584..d47cc9f831b 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -299,25 +299,41 @@ def get_bindings_model_config(self, self.mapping.tp_size * self.mapping.cp_size) # Handle both uniform and per-layer KV heads - if hasattr(self.pretrained_config, 'num_kv_heads_per_layer'): - # For models with per-layer KV heads, use the first layer's value - num_kv_heads_raw = self.pretrained_config.num_kv_heads_per_layer[0] - # TRT-LLM LoRA requires uniform KV heads across layers - if self.lora_config is not None and not all( - kv == num_kv_heads_raw - for kv in self.pretrained_config.num_kv_heads_per_layer): - raise ValueError( - f"TRT-LLM LoRA requires uniform KV heads across layers, got: {self.pretrained_config.num_kv_heads_per_layer}" - ) + if hasattr( + self.pretrained_config, 'num_kv_heads_per_layer' + ) and self.pretrained_config.num_kv_heads_per_layer is not None: + # For models with per-layer KV heads, like nemotron-nas + num_kv_heads_per_layer_raw = self.pretrained_config.num_kv_heads_per_layer + # Apply TP/CP scaling to each layer + num_kv_heads_per_layer = [ + kv_heads // (self.mapping.tp_size * self.mapping.cp_size) + for kv_heads in num_kv_heads_per_layer_raw + ] + _use_per_layer_kv_heads = True else: - # For uniform models, check: num_key_value_heads (standard) -> num_query_groups (NeMo) -> num_attention_heads - num_kv_heads_raw = getattr( - self.pretrained_config, 'num_key_value_heads', - getattr(self.pretrained_config, 'num_query_groups', - self.pretrained_config.num_attention_heads)) - - num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size * - self.mapping.cp_size) + # Check if num_key_value_heads is a list (per-layer) or scalar (uniform) + num_kv_heads_raw = getattr(self.pretrained_config, + 'num_key_value_heads', None) + + if num_kv_heads_raw is not None and isinstance( + num_kv_heads_raw, list): + # num_key_value_heads is a list - treat as per-layer KV heads + num_kv_heads_per_layer_raw = num_kv_heads_raw + # Apply TP/CP scaling to each layer + num_kv_heads_per_layer = [ + kv_heads // (self.mapping.tp_size * self.mapping.cp_size) + for kv_heads in 
num_kv_heads_per_layer_raw + ] + _use_per_layer_kv_heads = True + else: + # num_key_value_heads is scalar or None - treat as uniform KV heads + if num_kv_heads_raw is None: + num_kv_heads_raw = self.pretrained_config.num_attention_heads + + num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size * + self.mapping.cp_size) + _use_per_layer_kv_heads = False + hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size model_config_cpp = ModelConfigCpp( @@ -338,7 +354,10 @@ def get_bindings_model_config(self, else: model_config_cpp.tokens_per_block = tokens_per_block - model_config_cpp.set_num_kv_heads(num_kv_heads) + if _use_per_layer_kv_heads: + model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer + else: + model_config_cpp.set_num_kv_heads(num_kv_heads) mlp_hidden_size = None if self.pretrained_config.intermediate_size is not None: @@ -388,8 +407,10 @@ def _infer_nemotron_ffn_mult(self): # Nemotron-NAS has variable ffn_mult for each layer, we need to find the maximum # so that we don't set a too small mlp_hidden_size. This solution leads to a memory # consumption that is higher than required. - biggest_ffn_mult = max( - [x.ffn.ffn_mult for x in self.pretrained_config.block_configs]) + biggest_ffn_mult = max([ + (x.ffn.ffn_mult if x.ffn.ffn_mult is not None else 0) + for x in self.pretrained_config.block_configs + ]) from tensorrt_llm._torch.models.modeling_nemotron_nas import \ _ffn_mult_to_intermediate_size diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 54d2cdcf979..7a772182345 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -451,12 +451,25 @@ def create_py_executor_instance( num_experts = _try_infer_num_experts(model_engine.model.model_config) + num_attn_layers = model_binding_config.num_attention_layers() + per_layer_kv_heads = [ + model_binding_config.num_kv_heads(i) for i in range(num_attn_layers) + ] + num_kv_attention_heads = per_layer_kv_heads[0] if len( + set(per_layer_kv_heads)) == 1 else max(per_layer_kv_heads) + if len(set(per_layer_kv_heads)) > 1: + # NOTE: This code-path is currently untested and not validated. Can fail! + # This support is tracked in TRTLLM-6561 + logger.warning( + f"Non-uniform KV heads per layer detected, using max ({num_kv_attention_heads}) for LoRA. This code-path is currently untested and not validated. May fail!" + ) + lora_modules = LoraModule.create_lora_modules( lora_module_names=lora_config.lora_target_modules, hidden_size=model_binding_config.hidden_size, mlp_hidden_size=model_binding_config.mlp_hidden_size, num_attention_heads=model_binding_config.num_heads, - num_kv_attention_heads=model_binding_config.num_kv_heads(0), + num_kv_attention_heads=num_kv_attention_heads, attention_head_size=model_binding_config.head_size, tp_size=mapping.tp_size, num_experts=num_experts) diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 3175ad5e612..a62fcdd95b4 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -606,14 +606,8 @@ def test_nemo_lora_unsupported_modules_validation(): @force_ampere def test_gqa_nemo_lora(): - """Test NeMo LoRA with GQA using TinyLlama's exact dimensions. 
- - TinyLlama-1.1B-Chat-v1.0 specs (verified from config.json): - - hidden_size: 2048 - - num_hidden_layers: 22 - - num_attention_heads: 32 (Query heads) - - num_key_value_heads: 4 (Key/Value heads) - - This gives 32/4 = 8 query heads per KV group (GQA) + """Test NeMo LoRA with GQA using TinyLlama. + """ with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -625,11 +619,6 @@ def test_gqa_nemo_lora(): num_kv_heads = 4 # Key/Value heads (GQA) lora_rank = 8 - print( - f"\n✓ Testing TinyLlama GQA config: Q_heads={num_q_heads}, KV_heads={num_kv_heads}, rank={lora_rank}" - ) - - # Create a mock NeMo checkpoint with TinyLlama's exact GQA configuration nemo_path = create_mock_nemo_lora_checkpoint( temp_path, hidden_size=hidden_size, @@ -639,34 +628,28 @@ def test_gqa_nemo_lora(): num_kv_heads=num_kv_heads, ) - # Create LoRA config for NeMo checkpoint lora_config = LoraConfig( lora_dir=[str(nemo_path)], lora_ckpt_source="nemo", max_lora_rank=lora_rank, ) - # Use TinyLlama model - dimensions now match exactly model_path = get_model_path("llama-models-v2/TinyLlama-1.1B-Chat-v1.0") try: - # Create LLM with NeMo LoRA llm = LLM( model=model_path, lora_config=lora_config, kv_cache_config=global_kvcache_config, ) - # Test prompt test_prompts = ["Test TinyLlama GQA with NeMo LoRA"] - # Create LoRA request for the NeMo checkpoint lora_req = LoRARequest("tinyllama-gqa-test", 0, str(nemo_path), lora_ckpt_source="nemo") - # Generate with LoRA sampling_params = SamplingParams(max_tokens=10, temperature=0.0) outputs = llm.generate(test_prompts, sampling_params, @@ -685,5 +668,3 @@ def test_gqa_nemo_lora(): finally: if 'llm' in locals(): llm.shutdown() - - print("✓ TinyLlama GQA NeMo LoRA test completed successfully") From 3e3c3745a32a8c290a21b9b34cb59c8929b74661 Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Wed, 16 Jul 2025 17:24:33 -0700 Subject: [PATCH 10/16] cosmetic Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tensorrt_llm/lora_manager.py | 1 - tests/unittest/llmapi/test_llm_pytorch.py | 203 ++++++++++------------ 2 files changed, 94 insertions(+), 110 deletions(-) diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 9bad7c1386c..a73e6394c59 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -485,7 +485,6 @@ def load_torch_lora(lora_config: LoraConfig): """Load LoRA checkpoint for PyTorch workflow. This function routes to the appropriate loader based on lora_ckpt_source. - It centralizes the routing logic that was previously scattered in _util.py. 
Args: lora_config: LoRA configuration with lora_ckpt_source set to "hf" or "nemo" diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index a62fcdd95b4..18fea69977f 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -77,23 +77,17 @@ def create_mock_nemo_lora_checkpoint( nemo_path = lora_dir / "test_lora.nemo" - # Use temporary directory context manager for safe cleanup with tempfile.TemporaryDirectory() as temp_dir_str: temp_dir = Path(temp_dir_str) - # Create LoRA weights dict weights_dict = {} - # Calculate head dimensions head_dim = hidden_size // num_attention_heads kv_hidden_size = head_dim * num_kv_heads - # Calculate QKV output dimensions for NeMo's fused format - # NeMo fuses QKV: Q(hidden_size) + K(kv_hidden_size) + V(kv_hidden_size) qkv_output_dim = hidden_size + 2 * kv_hidden_size for layer_idx in range(num_layers): - # NeMo uses this key format for QKV adapters key_prefix = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter" # Create linear_in weights [lora_rank, hidden_size] with small random values @@ -110,7 +104,6 @@ def create_mock_nemo_lora_checkpoint( ckpt_path = temp_dir / "model_weights.ckpt" torch.save(weights_dict, ckpt_path) - # Create minimal config with GQA support config = { "precision": "fp16", "trainer": { @@ -134,7 +127,6 @@ def create_mock_nemo_lora_checkpoint( with open(config_path, 'w') as f: json.dump(config, f) - # Create .nemo tarfile with tarfile.open(nemo_path, 'w') as tar: tar.add(ckpt_path, arcname="model_weights.ckpt") tar.add(config_path, arcname="model_config.yaml") @@ -539,132 +531,125 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None: assert len(outputs) == 2 -# NeMo LoRA tests -@pytest.mark.parametrize("lora_rank,max_lora_rank,description", - LORA_RANK_CONFIGS) -def test_load_torch_nemo_lora_function(lora_rank, max_lora_rank, description): +@pytest.mark.parametrize( + "lora_rank,max_lora_rank,description", + [ + # (lora_rank, max_lora_rank, description) + (8, 8, "rank_8"), + (16, 16, "rank_16"), + (4, 8, "rank_4_max_8"), + ]) +def test_load_torch_nemo_lora_function(tmp_path, lora_rank, max_lora_rank, + description): """Test load_torch_nemo_lora function with different LoRA rank configurations.""" from tensorrt_llm.lora_manager import load_torch_nemo_lora - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create a mock NeMo checkpoint - nemo_path = create_mock_nemo_lora_checkpoint( - temp_path, - hidden_size=2048, - num_layers=16, - lora_rank=lora_rank, - ) + nemo_path = create_mock_nemo_lora_checkpoint( + tmp_path, + hidden_size=2048, + num_layers=16, + lora_rank=lora_rank, + ) - # Test load_torch_nemo_lora - lora_config = LoraConfig( - lora_dir=[str(nemo_path)], - lora_ckpt_source="nemo", - max_lora_rank=max_lora_rank, - ) + lora_config = LoraConfig( + lora_dir=[str(nemo_path)], + lora_ckpt_source="nemo", + max_lora_rank=max_lora_rank, + ) - # This should not raise an error - load_torch_nemo_lora(lora_config) + # This should not raise an error + load_torch_nemo_lora(lora_config) - # Verify configuration was set correctly - assert lora_config.lora_target_modules == [ - "attn_qkv" - ], f"Expected attn_qkv modules for {description}" - assert lora_config.trtllm_modules_to_hf_modules == { - "attn_qkv": "attn_qkv" - }, f"Expected correct module mapping for {description}" + assert lora_config.lora_target_modules == [ + "attn_qkv" + ], f"Expected attn_qkv modules for {description}" + 
assert lora_config.trtllm_modules_to_hf_modules == { + "attn_qkv": "attn_qkv" + }, f"Expected correct module mapping for {description}" -def test_nemo_lora_unsupported_modules_validation(): +def test_nemo_lora_unsupported_modules_validation(tmp_path): """Test validation of unsupported modules in NeMo LoRA.""" from tensorrt_llm.lora_manager import load_torch_nemo_lora - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Create a mock NeMo checkpoint - nemo_path = create_mock_nemo_lora_checkpoint( - temp_path, - hidden_size=2048, - num_layers=16, - lora_rank=8, - ) + nemo_path = create_mock_nemo_lora_checkpoint( + tmp_path, + hidden_size=2048, + num_layers=16, + lora_rank=8, + ) - # Test validation: should fail with unsupported modules - invalid_config = LoraConfig( - lora_dir=[str(nemo_path)], - lora_ckpt_source="nemo", - lora_target_modules=["attn_qkv", - "mlp_h_to_4h"], # mlp_h_to_4h not supported - max_lora_rank=8, - ) + # Test validation: should fail with unsupported modules + invalid_config = LoraConfig( + lora_dir=[str(nemo_path)], + lora_ckpt_source="nemo", + lora_target_modules=["attn_qkv", + "mlp_h_to_4h"], # mlp_h_to_4h not supported + max_lora_rank=8, + ) - with pytest.raises(ValueError, match="NeMo LoRA only supports"): - load_torch_nemo_lora(invalid_config) + with pytest.raises(ValueError, match="NeMo LoRA only supports"): + load_torch_nemo_lora(invalid_config) @force_ampere -def test_gqa_nemo_lora(): +def test_gqa_nemo_lora(tmp_path): """Test NeMo LoRA with GQA using TinyLlama. """ - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # TinyLlama's exact GQA configuration - hidden_size = 2048 - num_layers = 22 - num_q_heads = 32 # Query attention heads - num_kv_heads = 4 # Key/Value heads (GQA) - lora_rank = 8 - - nemo_path = create_mock_nemo_lora_checkpoint( - temp_path, - hidden_size=hidden_size, - num_layers=num_layers, - lora_rank=lora_rank, - num_attention_heads=num_q_heads, - num_kv_heads=num_kv_heads, - ) + # TinyLlama's exact GQA configuration + hidden_size = 2048 + num_layers = 22 + num_q_heads = 32 # Query attention heads + num_kv_heads = 4 # Key/Value heads (GQA) + lora_rank = 8 + + nemo_path = create_mock_nemo_lora_checkpoint( + tmp_path, + hidden_size=hidden_size, + num_layers=num_layers, + lora_rank=lora_rank, + num_attention_heads=num_q_heads, + num_kv_heads=num_kv_heads, + ) - lora_config = LoraConfig( - lora_dir=[str(nemo_path)], - lora_ckpt_source="nemo", - max_lora_rank=lora_rank, - ) + lora_config = LoraConfig( + lora_dir=[str(nemo_path)], + lora_ckpt_source="nemo", + max_lora_rank=lora_rank, + ) - model_path = get_model_path("llama-models-v2/TinyLlama-1.1B-Chat-v1.0") + model_path = get_model_path("llama-models-v2/TinyLlama-1.1B-Chat-v1.0") - try: - llm = LLM( - model=model_path, - lora_config=lora_config, - kv_cache_config=global_kvcache_config, - ) + try: + llm = LLM( + model=model_path, + lora_config=lora_config, + kv_cache_config=global_kvcache_config, + ) - test_prompts = ["Test TinyLlama GQA with NeMo LoRA"] + test_prompts = ["Test TinyLlama GQA with NeMo LoRA"] - lora_req = LoRARequest("tinyllama-gqa-test", - 0, - str(nemo_path), - lora_ckpt_source="nemo") + lora_req = LoRARequest("tinyllama-gqa-test", + 0, + str(nemo_path), + lora_ckpt_source="nemo") - sampling_params = SamplingParams(max_tokens=10, temperature=0.0) - outputs = llm.generate(test_prompts, - sampling_params, - lora_request=[lora_req]) + sampling_params = SamplingParams(max_tokens=10, temperature=0.0) + outputs = 
llm.generate(test_prompts, + sampling_params, + lora_request=[lora_req]) - # Basic validation - assert len(outputs) == 1 - assert outputs[0].outputs[0] is not None - assert len(outputs[0].outputs[0].token_ids) > 0 + # Basic validation + assert len(outputs) == 1 + assert outputs[0].outputs[0] is not None + assert len(outputs[0].outputs[0].token_ids) > 0 - print(f" ✓ TinyLlama GQA with NeMo LoRA passed successfully!") + print(f" ✓ TinyLlama GQA with NeMo LoRA passed successfully!") - except Exception as e: - # Any error now indicates a real problem since dimensions match - pytest.fail(f"TinyLlama GQA test failed: {e}") - finally: - if 'llm' in locals(): - llm.shutdown() + except Exception as e: + # Any error now indicates a real problem since dimensions match + pytest.fail(f"TinyLlama GQA test failed: {e}") + finally: + if 'llm' in locals(): + llm.shutdown() From 1507c2c7a3130e3fefd6198d2da6d9a48857650b Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Mon, 21 Jul 2025 18:55:46 -0700 Subject: [PATCH 11/16] validate outputs in e2e test, cosmetics Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tensorrt_llm/_torch/model_config.py | 44 +++--- tensorrt_llm/_torch/pyexecutor/_util.py | 6 +- tensorrt_llm/lora_manager.py | 9 +- tests/unittest/llmapi/lora_test_utils.py | 98 ++++++++++++++ tests/unittest/llmapi/test_llm_pytorch.py | 158 +++++++--------------- 5 files changed, 179 insertions(+), 136 deletions(-) diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index d47cc9f831b..d3faea366a7 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -299,17 +299,12 @@ def get_bindings_model_config(self, self.mapping.tp_size * self.mapping.cp_size) # Handle both uniform and per-layer KV heads - if hasattr( - self.pretrained_config, 'num_kv_heads_per_layer' - ) and self.pretrained_config.num_kv_heads_per_layer is not None: + num_kv_heads_per_layer = getattr(self.pretrained_config, + 'num_kv_heads_per_layer', None) + if num_kv_heads_per_layer is not None: # For models with per-layer KV heads, like nemotron-nas num_kv_heads_per_layer_raw = self.pretrained_config.num_kv_heads_per_layer - # Apply TP/CP scaling to each layer - num_kv_heads_per_layer = [ - kv_heads // (self.mapping.tp_size * self.mapping.cp_size) - for kv_heads in num_kv_heads_per_layer_raw - ] - _use_per_layer_kv_heads = True + use_per_layer_kv_heads = True else: # Check if num_key_value_heads is a list (per-layer) or scalar (uniform) num_kv_heads_raw = getattr(self.pretrained_config, @@ -319,20 +314,33 @@ def get_bindings_model_config(self, num_kv_heads_raw, list): # num_key_value_heads is a list - treat as per-layer KV heads num_kv_heads_per_layer_raw = num_kv_heads_raw - # Apply TP/CP scaling to each layer - num_kv_heads_per_layer = [ - kv_heads // (self.mapping.tp_size * self.mapping.cp_size) - for kv_heads in num_kv_heads_per_layer_raw - ] - _use_per_layer_kv_heads = True + use_per_layer_kv_heads = True else: # num_key_value_heads is scalar or None - treat as uniform KV heads if num_kv_heads_raw is None: - num_kv_heads_raw = self.pretrained_config.num_attention_heads + # For uniform models, check: num_key_value_heads (standard) -> num_query_groups (NeMo) -> num_attention_heads + num_kv_heads_raw = getattr( + self.pretrained_config, 'num_query_groups', + self.pretrained_config.num_attention_heads) num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size * self.mapping.cp_size) - 
_use_per_layer_kv_heads = False + use_per_layer_kv_heads = False + + if use_per_layer_kv_heads: + # TRT-LLM LoRA requires uniform KV heads across layers + if self.lora_config is not None and not all( + kv == num_kv_heads_per_layer_raw[0] + for kv in num_kv_heads_per_layer_raw): + kv_heads_list = self.pretrained_config.num_kv_heads_per_layer + raise ValueError( + f"TRT-LLM LoRA requires uniform KV heads across layers, " + f"got: {kv_heads_list}") + # Apply TP/CP scaling to each layer + num_kv_heads_per_layer = [ + kv_heads // (self.mapping.tp_size * self.mapping.cp_size) + for kv_heads in num_kv_heads_per_layer_raw + ] hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size @@ -354,7 +362,7 @@ def get_bindings_model_config(self, else: model_config_cpp.tokens_per_block = tokens_per_block - if _use_per_layer_kv_heads: + if use_per_layer_kv_heads: model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer else: model_config_cpp.set_num_kv_heads(num_kv_heads) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 7a772182345..f5415b56de8 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -455,13 +455,13 @@ def create_py_executor_instance( per_layer_kv_heads = [ model_binding_config.num_kv_heads(i) for i in range(num_attn_layers) ] - num_kv_attention_heads = per_layer_kv_heads[0] if len( - set(per_layer_kv_heads)) == 1 else max(per_layer_kv_heads) + num_kv_attention_heads = max(per_layer_kv_heads) if len(set(per_layer_kv_heads)) > 1: # NOTE: This code-path is currently untested and not validated. Can fail! # This support is tracked in TRTLLM-6561 logger.warning( - f"Non-uniform KV heads per layer detected, using max ({num_kv_attention_heads}) for LoRA. This code-path is currently untested and not validated. May fail!" + f"Non-uniform KV heads per layer detected, using max ({num_kv_attention_heads}) for LoRA. " + "This code-path is currently untested and not validated. May fail!" ) lora_modules = LoraModule.create_lora_modules( diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index a73e6394c59..9f42fdad20d 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -67,7 +67,10 @@ def get_all_nemo_lora_weights( def iterate_hf_lora( - iter_fn, lora_weights: Dict[str, torch.Tensor], hf_modules: set, component: Optional[str] = None + iter_fn, + lora_weights: Dict[str, torch.Tensor], + hf_modules: Set[str], + component: Optional[str] = None, ): """Iterate over HuggingFace LoRA weights and call iterator function for each weight. @@ -129,7 +132,7 @@ def iterate_hf_lora( def get_all_hf_lora_weights( - lora_weights: Dict[str, torch.Tensor], hf_modules: set, component: Optional[str] = None + lora_weights: Dict[str, torch.Tensor], hf_modules: Set[str], component: Optional[str] = None ): """Extract and organize all HuggingFace LoRA weights by layer and module. 
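For orientation, a standalone sketch of the KV-head resolution order that the get_bindings_model_config changes earlier in this patch implement, assuming an HF-style pretrained config object; the function name and signature here are illustrative only.

def resolve_kv_heads(pretrained_config, tp_size, cp_size):
    # Per-layer KV heads (e.g. nemotron-nas style configs) take precedence.
    per_layer = getattr(pretrained_config, "num_kv_heads_per_layer", None)
    raw = getattr(pretrained_config, "num_key_value_heads", None)
    if per_layer is None and isinstance(raw, list):
        # num_key_value_heads can itself be a per-layer list.
        per_layer = raw
    if per_layer is not None:
        # NOTE: the real implementation additionally rejects non-uniform
        # lists when a LoRA config is present.
        return [kv // (tp_size * cp_size) for kv in per_layer]

    # Uniform case: num_key_value_heads -> num_query_groups (NeMo) ->
    # num_attention_heads, then scale by TP * CP.
    if raw is None:
        raw = getattr(pretrained_config, "num_query_groups",
                      pretrained_config.num_attention_heads)
    return raw // (tp_size * cp_size)

The per-layer list corresponds to model_config_cpp.num_kv_heads_per_layer, and the scalar case to model_config_cpp.set_num_kv_heads.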
diff --git a/tests/unittest/llmapi/lora_test_utils.py b/tests/unittest/llmapi/lora_test_utils.py index 1b2323804fa..c8426979682 100644 --- a/tests/unittest/llmapi/lora_test_utils.py +++ b/tests/unittest/llmapi/lora_test_utils.py @@ -1,5 +1,11 @@ +# NeMo LoRA test utilities +import json +import tarfile +import tempfile +from pathlib import Path from typing import OrderedDict, Type +import torch from utils.llm_data import llm_models_root from utils.util import duplicate_list_to_length, flatten_list, similar @@ -114,3 +120,95 @@ def check_llama_7b_multi_lora_from_request_test_harness( for output, ref, key_word in zip(outputs, references, key_words): assert similar(output.outputs[0].text, ref) or key_word in output.outputs[0].text + + +def create_mock_nemo_lora_checkpoint( + lora_dir: Path, + hidden_size: int = 4096, + num_layers: int = 32, + lora_rank: int = 8, + tp_size: int = 1, + num_attention_heads: int = 32, + num_kv_heads: int = None, # If None, defaults to num_attention_heads + dtype: torch.dtype = torch.float16, +) -> Path: + """Create a minimal NeMo LoRA checkpoint for testing. + + This creates a .nemo tarfile with the expected structure: + - model_weights.ckpt containing attn_qkv adapter weights + - model_config.yaml with basic configuration + + Args: + lora_dir: Directory to create the checkpoint in + hidden_size: Model hidden size + num_layers: Number of transformer layers + lora_rank: LoRA rank + tp_size: Tensor parallelism size + num_attention_heads: Number of query attention heads + num_kv_heads: Number of key/value heads (for GQA). If None, equals num_attention_heads + dtype: Data type for the weights (default: torch.float16) + + Returns: + Path to the created .nemo file + """ + # Default to standard MHA if not specified + if num_kv_heads is None: + num_kv_heads = num_attention_heads + + nemo_path = lora_dir / "test_lora.nemo" + + with tempfile.TemporaryDirectory() as temp_dir_str: + temp_dir = Path(temp_dir_str) + + weights_dict = {} + + head_dim = hidden_size // num_attention_heads + kv_hidden_size = head_dim * num_kv_heads + + qkv_output_dim = hidden_size + 2 * kv_hidden_size + + for layer_idx in range(num_layers): + key_prefix = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter" + + # Create linear_in weights [lora_rank, hidden_size] with small random values + linear_in_key = f"{key_prefix}.linear_in.weight" + weights_dict[linear_in_key] = torch.randn( + lora_rank, hidden_size, dtype=dtype) * 0.01 + + # Create linear_out weights [qkv_output_dim, lora_rank] for fused QKV + # This is the key difference for GQA - the output dimension changes + linear_out_key = f"{key_prefix}.linear_out.weight" + weights_dict[linear_out_key] = torch.randn( + qkv_output_dim, lora_rank, dtype=dtype) * 0.01 + + ckpt_path = temp_dir / "model_weights.ckpt" + torch.save(weights_dict, ckpt_path) + + config = { + "precision": "fp16" if dtype == torch.float16 else "bf16", + "trainer": { + "num_nodes": 1, + "devices": tp_size, + }, + "model": { + "hidden_size": hidden_size, + "num_layers": num_layers, + "num_attention_heads": num_attention_heads, + "num_query_groups": num_kv_heads, # This is the key for GQA + }, + "lora": { + "rank": lora_rank, + "target_modules": ["attn_qkv"], + } + } + + config_path = temp_dir / "model_config.yaml" + # Using JSON for simplicity since YAML parsing isn't critical for the test + with open(config_path, 'w') as f: + json.dump(config, f) + + with tarfile.open(nemo_path, 'w') as tar: + tar.add(ckpt_path, arcname="model_weights.ckpt") + 
tar.add(config_path, arcname="model_config.yaml") + + return nemo_path diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 18fea69977f..5e269f9eaa9 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -5,7 +5,7 @@ from tensorrt_llm.sampling_params import SamplingParams # isort: off -from .lora_test_utils import check_llama_7b_multi_unique_lora_adapters_from_request +from .lora_test_utils import check_llama_7b_multi_unique_lora_adapters_from_request, create_mock_nemo_lora_checkpoint from .test_llm import (get_model_path, global_kvcache_config, llama_model_path, llm_get_stats_async_test_harness, llm_get_stats_test_harness, prompts, @@ -29,10 +29,6 @@ from peft import get_peft_model from transformers import AutoModelForCausalLM -import json -import tarfile -from pathlib import Path - # isort: on # NeMo LoRA test data @@ -44,96 +40,6 @@ ] -def create_mock_nemo_lora_checkpoint( - lora_dir: Path, - hidden_size: int = 4096, - num_layers: int = 32, - lora_rank: int = 8, - tp_size: int = 1, - num_attention_heads: int = 32, - num_kv_heads: int = None, # If None, defaults to num_attention_heads -) -> Path: - """Create a minimal NeMo LoRA checkpoint for testing. - - This creates a .nemo tarfile with the expected structure: - - model_weights.ckpt containing attn_qkv adapter weights - - model_config.yaml with basic configuration - - Args: - lora_dir: Directory to create the checkpoint in - hidden_size: Model hidden size - num_layers: Number of transformer layers - lora_rank: LoRA rank - tp_size: Tensor parallelism size - num_attention_heads: Number of query attention heads - num_kv_heads: Number of key/value heads (for GQA). If None, equals num_attention_heads - - Returns: - Path to the created .nemo file - """ - # Default to standard MHA if not specified - if num_kv_heads is None: - num_kv_heads = num_attention_heads - - nemo_path = lora_dir / "test_lora.nemo" - - with tempfile.TemporaryDirectory() as temp_dir_str: - temp_dir = Path(temp_dir_str) - - weights_dict = {} - - head_dim = hidden_size // num_attention_heads - kv_hidden_size = head_dim * num_kv_heads - - qkv_output_dim = hidden_size + 2 * kv_hidden_size - - for layer_idx in range(num_layers): - key_prefix = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter" - - # Create linear_in weights [lora_rank, hidden_size] with small random values - linear_in_key = f"{key_prefix}.linear_in.weight" - weights_dict[linear_in_key] = torch.randn( - lora_rank, hidden_size, dtype=torch.float16) * 0.01 - - # Create linear_out weights [qkv_output_dim, lora_rank] for fused QKV - # This is the key difference for GQA - the output dimension changes - linear_out_key = f"{key_prefix}.linear_out.weight" - weights_dict[linear_out_key] = torch.randn( - qkv_output_dim, lora_rank, dtype=torch.float16) * 0.01 - - ckpt_path = temp_dir / "model_weights.ckpt" - torch.save(weights_dict, ckpt_path) - - config = { - "precision": "fp16", - "trainer": { - "num_nodes": 1, - "devices": tp_size, - }, - "model": { - "hidden_size": hidden_size, - "num_layers": num_layers, - "num_attention_heads": num_attention_heads, - "num_query_groups": num_kv_heads, # This is the key for GQA - }, - "lora": { - "rank": lora_rank, - "target_modules": ["attn_qkv"], - } - } - - config_path = temp_dir / "model_config.yaml" - # Using JSON for simplicity since YAML parsing isn't critical for the test - with open(config_path, 'w') as f: - json.dump(config, f) - - with 
tarfile.open(nemo_path, 'w') as tar: - tar.add(ckpt_path, arcname="model_weights.ckpt") - tar.add(config_path, arcname="model_config.yaml") - - return nemo_path - - @force_ampere def test_tinyllama_logits_processor(): tinyllama_logits_processor_test_harness(backend="pytorch") @@ -594,9 +500,7 @@ def test_nemo_lora_unsupported_modules_validation(tmp_path): @force_ampere def test_gqa_nemo_lora(tmp_path): - """Test NeMo LoRA with GQA using TinyLlama. - - """ + """Test NeMo LoRA with GQA using TinyLlama.""" # TinyLlama's exact GQA configuration hidden_size = 2048 num_layers = 22 @@ -621,35 +525,65 @@ def test_gqa_nemo_lora(tmp_path): model_path = get_model_path("llama-models-v2/TinyLlama-1.1B-Chat-v1.0") + # First, generate without LoRA + llm_no_lora = LLM( + model=model_path, + kv_cache_config=global_kvcache_config, + ) + try: - llm = LLM( - model=model_path, - lora_config=lora_config, - kv_cache_config=global_kvcache_config, - ) + test_prompts = ["The capital of France is"] + sampling_params = SamplingParams(max_tokens=10, temperature=0.0) - test_prompts = ["Test TinyLlama GQA with NeMo LoRA"] + # Generate without LoRA + outputs_no_lora = llm_no_lora.generate(test_prompts, sampling_params) + no_lora_text = outputs_no_lora[0].outputs[0].text + finally: + llm_no_lora.shutdown() + + # Now generate with LoRA + llm = LLM( + model=model_path, + lora_config=lora_config, + kv_cache_config=global_kvcache_config, + ) + try: lora_req = LoRARequest("tinyllama-gqa-test", 0, str(nemo_path), lora_ckpt_source="nemo") - sampling_params = SamplingParams(max_tokens=10, temperature=0.0) outputs = llm.generate(test_prompts, sampling_params, lora_request=[lora_req]) - # Basic validation + # Validate output assert len(outputs) == 1 assert outputs[0].outputs[0] is not None assert len(outputs[0].outputs[0].token_ids) > 0 - print(f" ✓ TinyLlama GQA with NeMo LoRA passed successfully!") + # Compare with and without LoRA + lora_text = outputs[0].outputs[0].text + assert lora_text, "Generated text with LoRA should not be empty" + + # Since LoRA weights are initialized with small values (* 0.01), + # the outputs should start similarly but may diverge + # Check that both outputs start with the expected completion "Paris" + assert "Paris" in lora_text or "paris" in lora_text.lower(), \ + f"LoRA output should contain 'Paris', got: {lora_text}" + assert "Paris" in no_lora_text or "paris" in no_lora_text.lower(), \ + f"No-LoRA output should contain 'Paris', got: {no_lora_text}" + + # For very small LoRA weights, at least the first few tokens should be similar + # Check if the outputs are at least 60% similar or start with the same word + if not similar(lora_text, no_lora_text, threshold=0.6): + # If not similar enough, at least check they start with the same word + first_word_lora = lora_text.split()[0] if lora_text.split() else "" + first_word_no_lora = no_lora_text.split()[0] if no_lora_text.split( + ) else "" + assert first_word_lora.lower() == first_word_no_lora.lower(), \ + f"First words should match: LoRA='{first_word_lora}' vs No-LoRA='{first_word_no_lora}'" - except Exception as e: - # Any error now indicates a real problem since dimensions match - pytest.fail(f"TinyLlama GQA test failed: {e}") finally: - if 'llm' in locals(): - llm.shutdown() + llm.shutdown() From 05d8578a4901c70691b9d39beb74dabbb1a98ae2 Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Tue, 22 Jul 2025 00:07:01 -0700 Subject: [PATCH 12/16] util input validation, et al. 
Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tests/unittest/llmapi/lora_test_utils.py | 11 +++++++++++ tests/unittest/llmapi/test_llm_pytorch.py | 24 +++++------------------ 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/tests/unittest/llmapi/lora_test_utils.py b/tests/unittest/llmapi/lora_test_utils.py index c8426979682..2af9a7feb8f 100644 --- a/tests/unittest/llmapi/lora_test_utils.py +++ b/tests/unittest/llmapi/lora_test_utils.py @@ -151,10 +151,21 @@ def create_mock_nemo_lora_checkpoint( Returns: Path to the created .nemo file """ + + # Validate parameters + if hidden_size % num_attention_heads != 0: + raise ValueError(f"hidden_size ({hidden_size}) must be divisible by " + f"num_attention_heads ({num_attention_heads})") + # Default to standard MHA if not specified if num_kv_heads is None: num_kv_heads = num_attention_heads + if num_attention_heads % num_kv_heads != 0: + raise ValueError( + f"num_attention_heads ({num_attention_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads}) for GQA") + nemo_path = lora_dir / "test_lora.nemo" with tempfile.TemporaryDirectory() as temp_dir_str: diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 5e269f9eaa9..363788d7ce5 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -31,14 +31,6 @@ # isort: on -# NeMo LoRA test data -LORA_RANK_CONFIGS = [ - # (lora_rank, max_lora_rank, description) - (8, 8, "rank_8"), - (16, 16, "rank_16"), - (4, 8, "rank_4_max_8"), -] - @force_ampere def test_tinyllama_logits_processor(): @@ -567,23 +559,17 @@ def test_gqa_nemo_lora(tmp_path): lora_text = outputs[0].outputs[0].text assert lora_text, "Generated text with LoRA should not be empty" - # Since LoRA weights are initialized with small values (* 0.01), - # the outputs should start similarly but may diverge # Check that both outputs start with the expected completion "Paris" assert "Paris" in lora_text or "paris" in lora_text.lower(), \ f"LoRA output should contain 'Paris', got: {lora_text}" assert "Paris" in no_lora_text or "paris" in no_lora_text.lower(), \ f"No-LoRA output should contain 'Paris', got: {no_lora_text}" - # For very small LoRA weights, at least the first few tokens should be similar - # Check if the outputs are at least 60% similar or start with the same word - if not similar(lora_text, no_lora_text, threshold=0.6): - # If not similar enough, at least check they start with the same word - first_word_lora = lora_text.split()[0] if lora_text.split() else "" - first_word_no_lora = no_lora_text.split()[0] if no_lora_text.split( - ) else "" - assert first_word_lora.lower() == first_word_no_lora.lower(), \ - f"First words should match: LoRA='{first_word_lora}' vs No-LoRA='{first_word_no_lora}'" + # Since dummy LoRA weights are initialized with small values, + # the outputs should be similar. 60% was trial and error. + assert similar( + lora_text, no_lora_text, + threshold=0.60), "LoRA and no LoRA outputs should be similar" finally: llm.shutdown() From aeae97464b34b18beae64036b48413f438f7971f Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Tue, 22 Jul 2025 10:38:30 -0700 Subject: [PATCH 13/16] make test deterministic and more robust, et al. 
Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tensorrt_llm/_torch/model_config.py | 14 ++-- tests/unittest/llmapi/lora_test_utils.py | 31 +++++---- tests/unittest/llmapi/test_llm_pytorch.py | 78 ++++++++++------------- 3 files changed, 61 insertions(+), 62 deletions(-) diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index d3faea366a7..3d0175a3c23 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -303,7 +303,7 @@ def get_bindings_model_config(self, 'num_kv_heads_per_layer', None) if num_kv_heads_per_layer is not None: # For models with per-layer KV heads, like nemotron-nas - num_kv_heads_per_layer_raw = self.pretrained_config.num_kv_heads_per_layer + kv_heads_per_layer_raw = num_kv_heads_per_layer use_per_layer_kv_heads = True else: # Check if num_key_value_heads is a list (per-layer) or scalar (uniform) @@ -313,7 +313,7 @@ def get_bindings_model_config(self, if num_kv_heads_raw is not None and isinstance( num_kv_heads_raw, list): # num_key_value_heads is a list - treat as per-layer KV heads - num_kv_heads_per_layer_raw = num_kv_heads_raw + kv_heads_per_layer_raw = num_kv_heads_raw use_per_layer_kv_heads = True else: # num_key_value_heads is scalar or None - treat as uniform KV heads @@ -329,17 +329,15 @@ def get_bindings_model_config(self, if use_per_layer_kv_heads: # TRT-LLM LoRA requires uniform KV heads across layers - if self.lora_config is not None and not all( - kv == num_kv_heads_per_layer_raw[0] - for kv in num_kv_heads_per_layer_raw): - kv_heads_list = self.pretrained_config.num_kv_heads_per_layer + if self.lora_config is not None and len( + set(kv_heads_per_layer_raw)) > 1: raise ValueError( f"TRT-LLM LoRA requires uniform KV heads across layers, " - f"got: {kv_heads_list}") + f"got: {kv_heads_per_layer_raw}") # Apply TP/CP scaling to each layer num_kv_heads_per_layer = [ kv_heads // (self.mapping.tp_size * self.mapping.cp_size) - for kv_heads in num_kv_heads_per_layer_raw + for kv_heads in kv_heads_per_layer_raw ] hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size diff --git a/tests/unittest/llmapi/lora_test_utils.py b/tests/unittest/llmapi/lora_test_utils.py index 2af9a7feb8f..58673aa0699 100644 --- a/tests/unittest/llmapi/lora_test_utils.py +++ b/tests/unittest/llmapi/lora_test_utils.py @@ -1,4 +1,3 @@ -# NeMo LoRA test utilities import json import tarfile import tempfile @@ -123,14 +122,15 @@ def check_llama_7b_multi_lora_from_request_test_harness( def create_mock_nemo_lora_checkpoint( - lora_dir: Path, - hidden_size: int = 4096, - num_layers: int = 32, - lora_rank: int = 8, - tp_size: int = 1, - num_attention_heads: int = 32, - num_kv_heads: int = None, # If None, defaults to num_attention_heads - dtype: torch.dtype = torch.float16, + lora_dir: Path, + hidden_size: int = 4096, + num_layers: int = 32, + lora_rank: int = 8, + tp_size: int = 1, + num_attention_heads: int = 32, + num_kv_heads: int = None, # If None, defaults to num_attention_heads + dtype: torch.dtype = torch.float16, + seed: int = None, # For deterministic weight initialization ) -> Path: """Create a minimal NeMo LoRA checkpoint for testing. 
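For the GQA case the mock NeMo checkpoint helper has to cover, the fused-QKV adapter shapes follow directly from the head layout. A worked example with the TinyLlama configuration used in the test (all values taken from the diffs in this series):

hidden_size = 2048
num_attention_heads = 32
num_kv_heads = 4
lora_rank = 8

head_dim = hidden_size // num_attention_heads      # 64
kv_hidden_size = num_kv_heads * head_dim           # 256
qkv_output_dim = hidden_size + 2 * kv_hidden_size  # 2048 + 512 = 2560

# Per layer, the mock adapter therefore stores:
#   linear_in.weight  : [lora_rank, hidden_size]     -> [8, 2048]
#   linear_out.weight : [qkv_output_dim, lora_rank]  -> [2560, 8]
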
@@ -171,6 +171,10 @@ def create_mock_nemo_lora_checkpoint( with tempfile.TemporaryDirectory() as temp_dir_str: temp_dir = Path(temp_dir_str) + # Set random seed for deterministic weight initialization + if seed is not None: + torch.manual_seed(seed) + weights_dict = {} head_dim = hidden_size // num_attention_heads @@ -178,19 +182,24 @@ def create_mock_nemo_lora_checkpoint( qkv_output_dim = hidden_size + 2 * kv_hidden_size + # NOTE: + # for seed=42, and coefficient=0.02, the expected outputs are hardcoded + # in the test `test_llm_pytorch.py::test_gqa_nemo_lora`. + # Therefore changing "WEIGHTS_COEFFICIENT" or the seed will break the test. + WEIGHTS_COEFFICIENT = 0.02 for layer_idx in range(num_layers): key_prefix = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter" # Create linear_in weights [lora_rank, hidden_size] with small random values linear_in_key = f"{key_prefix}.linear_in.weight" weights_dict[linear_in_key] = torch.randn( - lora_rank, hidden_size, dtype=dtype) * 0.01 + lora_rank, hidden_size, dtype=dtype) * WEIGHTS_COEFFICIENT # Create linear_out weights [qkv_output_dim, lora_rank] for fused QKV # This is the key difference for GQA - the output dimension changes linear_out_key = f"{key_prefix}.linear_out.weight" weights_dict[linear_out_key] = torch.randn( - qkv_output_dim, lora_rank, dtype=dtype) * 0.01 + qkv_output_dim, lora_rank, dtype=dtype) * WEIGHTS_COEFFICIENT ckpt_path = temp_dir / "model_weights.ckpt" torch.save(weights_dict, ckpt_path) diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 363788d7ce5..7e890693e50 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -492,7 +492,18 @@ def test_nemo_lora_unsupported_modules_validation(tmp_path): @force_ampere def test_gqa_nemo_lora(tmp_path): - """Test NeMo LoRA with GQA using TinyLlama.""" + """ + Test NeMo-format LoRA checkpoint loading and GQA support in TinyLlama. + + This test verifies two properties: + 1. That a NeMo-format LoRA checkpoint with GQA (grouped query attention) can be loaded and applied to a TinyLlama model, + and that generation with this LoRA produces a deterministic, expected output for a fixed prompt and temperature=0.0. + 2. That the LoRA weights have a significant effect: generating with LoRA produces a different output than generating + without LoRA, confirming that the LoRA adapter is actually being applied. + + The test uses a deterministic dummy LoRA checkpoint (seed=42) and checks both the positive (LoRA applied) and negative + (no LoRA) cases for output text. + """ # TinyLlama's exact GQA configuration hidden_size = 2048 num_layers = 22 @@ -507,7 +518,11 @@ def test_gqa_nemo_lora(tmp_path): lora_rank=lora_rank, num_attention_heads=num_q_heads, num_kv_heads=num_kv_heads, + seed=42, # NOTE: the seed=42 is important for the test to pass. ) + expected_lora_text_output = "Paris. The capital of France is Paris. 
The" + test_prompts = ["The capital of France is"] + sampling_params = SamplingParams(max_tokens=10, temperature=0.0) lora_config = LoraConfig( lora_dir=[str(nemo_path)], @@ -517,23 +532,6 @@ def test_gqa_nemo_lora(tmp_path): model_path = get_model_path("llama-models-v2/TinyLlama-1.1B-Chat-v1.0") - # First, generate without LoRA - llm_no_lora = LLM( - model=model_path, - kv_cache_config=global_kvcache_config, - ) - - try: - test_prompts = ["The capital of France is"] - sampling_params = SamplingParams(max_tokens=10, temperature=0.0) - - # Generate without LoRA - outputs_no_lora = llm_no_lora.generate(test_prompts, sampling_params) - no_lora_text = outputs_no_lora[0].outputs[0].text - finally: - llm_no_lora.shutdown() - - # Now generate with LoRA llm = LLM( model=model_path, lora_config=lora_config, @@ -546,30 +544,24 @@ def test_gqa_nemo_lora(tmp_path): str(nemo_path), lora_ckpt_source="nemo") - outputs = llm.generate(test_prompts, - sampling_params, - lora_request=[lora_req]) - - # Validate output - assert len(outputs) == 1 - assert outputs[0].outputs[0] is not None - assert len(outputs[0].outputs[0].token_ids) > 0 - - # Compare with and without LoRA - lora_text = outputs[0].outputs[0].text - assert lora_text, "Generated text with LoRA should not be empty" - - # Check that both outputs start with the expected completion "Paris" - assert "Paris" in lora_text or "paris" in lora_text.lower(), \ - f"LoRA output should contain 'Paris', got: {lora_text}" - assert "Paris" in no_lora_text or "paris" in no_lora_text.lower(), \ - f"No-LoRA output should contain 'Paris', got: {no_lora_text}" - - # Since dummy LoRA weights are initialized with small values, - # the outputs should be similar. 60% was trial and error. - assert similar( - lora_text, no_lora_text, - threshold=0.60), "LoRA and no LoRA outputs should be similar" - + lora_outputs = llm.generate(test_prompts, + sampling_params, + lora_request=[lora_req]) + + # For the above deterministic dummy LoRA checkpoint, + # with temperature=0.0, + # the expected output text should always be the same. + assert lora_outputs[0].outputs[0].text == expected_lora_text_output, \ + f"Expected output text: {expected_lora_text_output}, " \ + f"got: {lora_outputs[0].outputs[0].text}" + assert len(lora_outputs) == 1 + + # Generate without LoRA. + # The LoRA weights are tuned/large enough that + # they differ from a no-LoRA run. 
+ base_outputs = llm.generate(test_prompts, sampling_params) + assert base_outputs[0].outputs[0].text != expected_lora_text_output, \ + f"No-LoRA output should differ from expected output text: {expected_lora_text_output}, " \ + f"got: {base_outputs[0].outputs[0].text}" finally: llm.shutdown() From 58a2ca37c6bf7ad6d0ba722a8ffd406f9652c73d Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Tue, 22 Jul 2025 20:22:07 -0700 Subject: [PATCH 14/16] initial commit Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tensorrt_llm/lora_manager.py | 122 ++++++- tests/unittest/test_lora_manager.py | 518 ++++++++++++++++++++++++++++ 2 files changed, 625 insertions(+), 15 deletions(-) create mode 100644 tests/unittest/test_lora_manager.py diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 9f42fdad20d..d3539d7bfde 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -2,6 +2,7 @@ import json import re import tarfile +import warnings from collections import defaultdict from dataclasses import dataclass, field from functools import lru_cache @@ -821,30 +822,102 @@ def load_from_model_file(uid, model_file): if uid not in self._cpp_lora_config: self._cpp_lora_config[uid] = [] # Will be converted to tensor later - _, nemo_weights = unpack_nemo_weights(model_file) + nemo_model_config, nemo_weights = unpack_nemo_weights(model_file) all_lora_weights = get_all_nemo_lora_weights(nemo_weights) + # Extract rank from NeMo model config (more reliable than deriving from tensors) + config_rank = None + if ( + "lora_tuning" in nemo_model_config + and "adapter_dim" in nemo_model_config["lora_tuning"] + ): + config_rank = nemo_model_config["lora_tuning"]["adapter_dim"] + + # If rank not found in config, fall back to tensor-based derivation + if config_rank is None: + warnings.warn( + "Could not find lora_tuning.adapter_dim in NeMo model config, " + "will derive from tensor shapes" + ) + self._lora_uid_to_low_ranks[uid] = {} self._lora_weights_pointers_list[uid] = {} - for layer_idx in sorted(all_lora_weights.keys()): + + # Determine expected number of layers from model config or infer from available weights + num_layers = nemo_model_config.get("num_layers") + if num_layers is None: + # Fallback: infer from available weights + num_layers = max(all_lora_weights.keys()) + 1 if all_lora_weights else 1 + warnings.warn( + f"Could not find num_layers in NeMo model config, " + f"inferring {num_layers} layers from weights" + ) + + # Process all expected layers (not just existing ones) + for layer_idx in range(num_layers): self._lora_uid_to_low_ranks[uid][layer_idx] = {} self._lora_weights_pointers_list[uid][layer_idx] = {} for lora_module in self.lora_target_modules: if lora_module != "attn_qkv": + warnings.warn( + f"LoRA module '{lora_module}' not supported in NeMo loading, skipping." 
+ ) self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = 0 continue - if lora_module == "attn_qkv": - t_in = all_lora_weights[layer_idx]["in"] - t_out = all_lora_weights[layer_idx]["out"] - assert t_out.shape[0] % tp_size == 0 - t_out = torch.split(t_out, t_out.shape[0] // tp_size, dim=0)[ - tp_rank - ].contiguous() - else: - t_in = None - t_out = None + # At this point, lora_module must be "attn_qkv" + # Graceful handling of missing matrices with warnings and zero tensor fallbacks + layer_weights = all_lora_weights.get( + layer_idx, {} + ) # Use get() to handle missing layers + + # Determine rank: prefer config rank, fall back to tensor-based derivation + rank = config_rank + if rank is None: + if "in" in layer_weights: + rank = layer_weights["in"].shape[0] + elif "out" in layer_weights: + rank = layer_weights["out"].shape[1] + else: + # Both matrices missing - look for rank from other layers or use default + for other_layer_idx in sorted(all_lora_weights.keys()): + if ( + other_layer_idx != layer_idx + and "in" in all_lora_weights[other_layer_idx] + ): + rank = all_lora_weights[other_layer_idx]["in"].shape[0] + break + if rank is None: + # Final fallback to a reasonable default rank + rank = 64 + warnings.warn( + f"Layer {layer_idx}: No reference rank found for attn_qkv, " + f"using default rank {rank}" + ) + + # Handle missing "in" matrix (lora_A equivalent) + if "in" not in layer_weights: + warnings.warn( + f"Layer {layer_idx} is missing 'in' matrix for attn_qkv in NeMo LoRA weights, " + f"creating zero tensor" + ) + layer_weights["in"] = torch.zeros(rank, model_config.hidden_size) + + # Handle missing "out" matrix (lora_B equivalent) - 3x larger for fused QKV + if "out" not in layer_weights: + warnings.warn( + f"Layer {layer_idx} is missing 'out' matrix for attn_qkv in NeMo LoRA weights, " + f"creating zero tensor" + ) + layer_weights["out"] = torch.zeros(3 * model_config.hidden_size, rank) + + t_in = layer_weights["in"] + t_out = layer_weights["out"] + assert t_out.shape[0] % tp_size == 0 + t_out = torch.split(t_out, t_out.shape[0] // tp_size, dim=0)[ + tp_rank + ].contiguous() if t_in is not None and t_out is not None: t_in = t_in.cuda().to(str_dtype_to_torch(model_config.dtype)).contiguous() @@ -883,6 +956,8 @@ def load_from_model_file(uid, model_file): load_from_model_file(uid, model_file) release_gc() + if new_uids: + print(f"Successfully loaded NeMo LoRA adapters with UIDs: {new_uids}") return new_uids def load_from_hf( @@ -1022,10 +1097,15 @@ def load_from_model_dir(uid, model_dir, hf_config): for hf_module, module_weights in layer_weights.items(): lora_module = hf_modules_to_trtllm_modules[hf_module] if lora_module not in self.lora_target_modules: + warnings.warn( + f"LoRA module '{lora_module}' not in target modules {self.lora_target_modules}, skipping." 
+ ) self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = 0 continue - if "in" not in module_weights: - is_moe = True + + is_moe = "in" not in module_weights and "out" not in module_weights + + if is_moe: t_in = torch.stack( [ module_weights[expert_idx]["in"] @@ -1044,7 +1124,17 @@ def load_from_model_dir(uid, model_dir, hf_config): raise ValueError("DoRA with MoE is not supported") t_mag = None else: - is_moe = False + if "in" not in module_weights: + warnings.warn( + f"Module {hf_module} is missing 'in' matrix, creating zero tensor" + ) + module_weights["in"] = torch.zeros(rank, model_config.hidden_size) + if "out" not in module_weights: + warnings.warn( + f"Module {hf_module} is missing 'out' matrix, creating zero tensor" + ) + module_weights["out"] = torch.zeros(model_config.hidden_size, rank) + t_in = module_weights["in"] t_out = module_weights["out"] t_mag = module_weights.get("magnitude", None) @@ -1139,6 +1229,8 @@ def load_from_model_dir(uid, model_dir, hf_config): load_from_model_dir(uid, model_dir, hf_config) release_gc() + if new_uids: + print(f"Successfully loaded HF LoRA adapters with UIDs: {new_uids}") return new_uids @property diff --git a/tests/unittest/test_lora_manager.py b/tests/unittest/test_lora_manager.py new file mode 100644 index 00000000000..b517361baa9 --- /dev/null +++ b/tests/unittest/test_lora_manager.py @@ -0,0 +1,518 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import json +import os +import tarfile +import tempfile +import unittest +import warnings +from unittest import mock + +import torch +import yaml +from safetensors import torch as safetensors_torch + +# Import the modules to test +from tensorrt_llm import lora_manager, mapping + +# Constants +DEFAULT_HIDDEN_SIZE = 4096 +DEFAULT_RANK = 32 +DEFAULT_NUM_LAYERS = 4 +DEFAULT_TEST_RANK = 16 + + +class TestLoraManagerBase(unittest.TestCase): + """Base class with common functionality for LoRA manager tests.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.lora_manager = lora_manager.LoraManager() + + def tearDown(self): + """Clean up test fixtures.""" + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def _create_incomplete_hf_checkpoint(self, missing_matrices=None): + """ + Create an incomplete HF LoRA checkpoint for testing. 
+ + Args: + missing_matrices: List of matrix types to exclude (e.g., ['q_proj.lora_A']) + If None, defaults to ['q_proj.lora_A', 'v_proj.lora_A'] + + Returns: + str: Path to the created checkpoint directory + """ + if missing_matrices is None: + missing_matrices = ['q_proj.lora_A', 'v_proj.lora_A'] + + # Create adapter_config.json + adapter_config = { + "r": 32, + "lora_alpha": 32, + "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"], + "lora_dropout": 0.05, + "bias": "none", + "task_type": "CAUSAL_LM" + } + + config_path = os.path.join(self.temp_dir, "adapter_config.json") + with open(config_path, 'w') as f: + json.dump(adapter_config, f) + + # Create incomplete weight tensors + weights = {} + hidden_size = DEFAULT_HIDDEN_SIZE + rank = DEFAULT_RANK + num_layers = DEFAULT_NUM_LAYERS # Use fewer layers for faster tests + + for layer_idx in range(num_layers): + layer_prefix = f"base_model.model.model.layers.{layer_idx}.self_attn" + + # Add weights for all modules except missing ones + for module in ["q_proj", "k_proj", "v_proj", "o_proj"]: + for matrix_type in ["lora_A", "lora_B"]: + key = f"{layer_prefix}.{module}.{matrix_type}.weight" + + # Skip missing matrices + if any(missing in key for missing in missing_matrices): + continue + + if matrix_type == "lora_A": + shape = (rank, hidden_size) + else: # lora_B + shape = (hidden_size, rank) + + weights[key] = torch.randn(shape, dtype=torch.float16) + + # Save to safetensors + safetensors_path = os.path.join(self.temp_dir, + "adapter_model.safetensors") + safetensors_torch.save_file(weights, safetensors_path) + + return self.temp_dir + + def _create_hf_model_config(self): + """Create a LoraModelConfig for HF testing.""" + return self._create_model_config() + + def _create_model_config(self, target_modules=None): + """Create a LoraModelConfig for testing (backward compatibility).""" + if target_modules is None: + target_modules = ['attn_q', 'attn_k', 'attn_v', 'attn_dense'] + + return lora_manager.LoraModelConfig(lora_target_modules=target_modules, + trtllm_modules_to_hf_modules={ + 'attn_q': 'q_proj', + 'attn_k': 'k_proj', + 'attn_v': 'v_proj', + 'attn_dense': 'o_proj' + }, + hidden_size=DEFAULT_HIDDEN_SIZE, + dtype='float16') + + def _create_incomplete_nemo_checkpoint(self, + missing_matrices=None, + include_rank_in_config=True): + """ + Create a NeMo LoRA checkpoint (.nemo archive) for testing. 
+ + Args: + missing_matrices: Dict mapping layer_idx to list of missing matrices + e.g., {0: ['in'], 1: ['out']} + include_rank_in_config: Whether to include adapter_dim in the config + + Returns: + str: Path to the created .nemo file + """ + if missing_matrices is None: + missing_matrices = {} + + # Create model config + model_config = { + "target_modules": ["self_attention.adapter_layer.lora_kqv_adapter"], + "hidden_size": DEFAULT_HIDDEN_SIZE, + "num_layers": DEFAULT_NUM_LAYERS + } + + # Conditionally add lora_tuning with adapter_dim + if include_rank_in_config: + model_config["lora_tuning"] = { + "adapter_dim": DEFAULT_RANK, # This will be used as the rank + "target_modules": ["attention_qkv"], + "alpha": DEFAULT_RANK + } + else: + # Create lora_tuning without adapter_dim to test default rank fallback + model_config["lora_tuning"] = { + "target_modules": ["attention_qkv"], + "alpha": DEFAULT_RANK + } + + # Create model weights + model_weights = {} + rank = DEFAULT_RANK + hidden_size = DEFAULT_HIDDEN_SIZE + + for layer_idx in range(DEFAULT_NUM_LAYERS): + layer_missing = missing_matrices.get(layer_idx, []) + + # Add 'in' matrix unless it's marked as missing + if 'in' not in layer_missing: + key = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter.linear_in.weight" + model_weights[key] = torch.randn(rank, + hidden_size, + dtype=torch.float16) + + # Add 'out' matrix unless it's marked as missing + if 'out' not in layer_missing: + key = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter.linear_out.weight" + # NeMo fused QKV is 3x larger + model_weights[key] = torch.randn(3 * hidden_size, + rank, + dtype=torch.float16) + + # Create .nemo archive + nemo_path = os.path.join(self.temp_dir, "test_lora.nemo") + + with tarfile.open(nemo_path, 'w') as tar: + # Add model_config.yaml + config_str = yaml.dump(model_config) + config_info = tarfile.TarInfo('model_config.yaml') + config_info.size = len(config_str.encode()) + tar.addfile(config_info, io.BytesIO(config_str.encode())) + + # Add model_weights.ckpt + weights_buffer = io.BytesIO() + torch.save(model_weights, weights_buffer) + weights_data = weights_buffer.getvalue() + + weights_info = tarfile.TarInfo('model_weights.ckpt') + weights_info.size = len(weights_data) + tar.addfile(weights_info, io.BytesIO(weights_data)) + + return nemo_path + + def _create_nemo_model_config(self): + """Create a LoraModelConfig for NeMo testing.""" + return lora_manager.LoraModelConfig( + lora_target_modules=['attn_qkv'], + trtllm_modules_to_hf_modules={'attn_qkv': 'attn_qkv'}, + hidden_size=DEFAULT_HIDDEN_SIZE, + dtype='float16') + + def test_missing_matrices_graceful_handling(self): + """Test for graceful handling of missing matrices across checkpoint formats.""" + test_cases = [ + # HF test cases + ("hf", ["q_proj.lora_A", + "v_proj.lora_A"], ["q_proj", "v_proj", "missing", "in"]), + ("hf", ["k_proj.lora_B"], ["k_proj", "missing", "out"]), + ("hf", ["q_proj.lora_A", "k_proj.lora_B", + "v_proj.lora_A"], ["missing"]), + # NeMo test cases + ("nemo", { + 0: ["in"] + }, ["Layer 0", "missing", "in"]), + ("nemo", { + 1: ["out"] + }, ["Layer 1", "missing", "out"]), + ("nemo", { + 0: ["in", "out"] + }, ["missing"]), + ] + + for ckpt_source, missing_matrices, expected_warnings in test_cases: + with self.subTest(ckpt_source=ckpt_source, + missing_matrices=missing_matrices): + if ckpt_source == "hf": + checkpoint_path = self._create_incomplete_hf_checkpoint( + missing_matrices) + model_config = self._create_hf_model_config() + else: 
# nemo + checkpoint_path = self._create_incomplete_nemo_checkpoint( + missing_matrices) + model_config = self._create_nemo_model_config() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + uids = self.lora_manager.load_from_ckpt( + model_dirs_or_files=[checkpoint_path], + model_config=model_config, + runtime_mapping=mapping.Mapping(), + ckpt_source=ckpt_source) + + # Should successfully return UIDs + self.assertEqual(len(uids), 1) + + # Should have generated appropriate warnings + warning_text = ' '.join( + [str(warning.message) for warning in w]) + for expected in expected_warnings: + self.assertIn( + expected, warning_text, + f"Expected '{expected}' in warning text for {ckpt_source} checkpoint" + ) + + def test_complete_checkpoints_no_warnings(self): + """Test that complete checkpoints load without warnings.""" + test_cases = ["hf", "nemo"] + + for ckpt_source in test_cases: + with self.subTest(ckpt_source=ckpt_source): + if ckpt_source == "hf": + checkpoint_path = self._create_incomplete_hf_checkpoint( + []) # No missing matrices + model_config = self._create_hf_model_config() + else: # nemo + checkpoint_path = self._create_incomplete_nemo_checkpoint( + {}) # No missing matrices + model_config = self._create_nemo_model_config() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + uids = self.lora_manager.load_from_ckpt( + model_dirs_or_files=[checkpoint_path], + model_config=model_config, + runtime_mapping=mapping.Mapping(), + ckpt_source=ckpt_source) + + self.assertEqual(len(uids), 1) + + # Should not have any warnings about missing matrices + missing_warnings = [ + warning for warning in w + if 'missing' in str(warning.message) + ] + self.assertEqual( + len(missing_warnings), 0, + f"Complete {ckpt_source} checkpoint should not generate missing matrix warnings" + ) + + +class TestLoraManagerSpecificFeatures(TestLoraManagerBase): + """Tests for specific features that are unique to each checkpoint format.""" + + def test_hf_zero_tensor_dimensions(self): + """Test HF-specific zero tensor dimensions (separate Q/K/V modules).""" + checkpoint_dir = self._create_incomplete_hf_checkpoint( + ['q_proj.lora_A']) + model_config = self._create_model_config(['attn_q' + ]) # Only test one module + + # Mock the zero tensor creation to verify dimensions + original_zeros = torch.zeros + created_tensors = [] + + def mock_zeros(*args, **kwargs): + tensor = original_zeros(*args, **kwargs) + created_tensors.append(tensor.shape) + return tensor + + with mock.patch('torch.zeros', side_effect=mock_zeros): + self.lora_manager.load_from_ckpt( + model_dirs_or_files=[checkpoint_dir], + model_config=model_config, + runtime_mapping=mapping.Mapping(), + ckpt_source='hf') + + # Should have created zero tensors for missing matrices + self.assertGreater( + len(created_tensors), 0, + "Should have created zero tensors for missing matrices") + + # Verify HF tensor dimensions (rank=32, hidden_size=4096 for lora_A) + expected_shape = (DEFAULT_RANK, DEFAULT_HIDDEN_SIZE + ) # lora_A dimensions + self.assertIn( + expected_shape, created_tensors, + f"Expected HF lora_A tensor shape {expected_shape} to be created") + + def test_nemo_zero_tensor_dimensions(self): + """Test NeMo-specific zero tensor dimensions (fused QKV - 3x larger output).""" + # Create checkpoint without rank in config to use default rank (64) + nemo_path = self._create_incomplete_nemo_checkpoint( + {0: ['in', 'out']}, include_rank_in_config=False) + model_config = 
self._create_nemo_model_config() + + # Mock the zero tensor creation to verify dimensions + original_zeros = torch.zeros + created_tensors = [] + + def mock_zeros(*args, **kwargs): + tensor = original_zeros(*args, **kwargs) + created_tensors.append(tensor.shape) + return tensor + + with mock.patch('torch.zeros', side_effect=mock_zeros): + self.lora_manager.load_from_ckpt(model_dirs_or_files=[nemo_path], + model_config=model_config, + runtime_mapping=mapping.Mapping(), + ckpt_source='nemo') + + # Should have created zero tensors + self.assertGreater( + len(created_tensors), 0, + "Should have created zero tensors for missing matrices") + + # Verify NeMo tensor dimensions (rank=32 from other layers, hidden_size=4096, 3x for fused QKV) + expected_in_shape = (DEFAULT_RANK, DEFAULT_HIDDEN_SIZE + ) # 'in' matrix (lora_A equivalent) + expected_out_shape = (3 * DEFAULT_HIDDEN_SIZE, DEFAULT_RANK + ) # 'out' matrix (3x larger for fused QKV) + + self.assertIn( + expected_in_shape, created_tensors, + f"Expected NeMo 'in' tensor shape {expected_in_shape} to be created" + ) + self.assertIn( + expected_out_shape, created_tensors, + f"Expected NeMo 'out' tensor shape {expected_out_shape} to be created" + ) + + def test_nemo_rank_derivation_from_config_and_tensors(self): + """Test NeMo-specific rank derivation: from config first, then from existing tensors.""" + # Create checkpoint with custom rank where only 'in' is missing + rank = DEFAULT_TEST_RANK + hidden_size = DEFAULT_HIDDEN_SIZE + + # Manually create model weights with custom rank + model_weights = { + f"model.layers.0.self_attention.adapter_layer.lora_kqv_adapter.linear_out.weight": + torch.randn(3 * hidden_size, rank, dtype=torch.float16) + } + + # Create .nemo archive + nemo_path = os.path.join(self.temp_dir, "custom_rank.nemo") + model_config_dict = { + "lora_tuning": { + "adapter_dim": rank, # This should be used as primary source + "target_modules": ["attention_qkv"] + }, + "hidden_size": hidden_size + } + + with tarfile.open(nemo_path, 'w') as tar: + # Add config + config_str = yaml.dump(model_config_dict) + config_info = tarfile.TarInfo('model_config.yaml') + config_info.size = len(config_str.encode()) + tar.addfile(config_info, io.BytesIO(config_str.encode())) + + # Add weights + weights_buffer = io.BytesIO() + torch.save(model_weights, weights_buffer) + weights_data = weights_buffer.getvalue() + + weights_info = tarfile.TarInfo('model_weights.ckpt') + weights_info.size = len(weights_data) + tar.addfile(weights_info, io.BytesIO(weights_data)) + + model_config = self._create_nemo_model_config() + + # Mock zero tensor creation to verify correct rank is used + created_tensors = [] + original_zeros = torch.zeros + + def mock_zeros(*args, **kwargs): + tensor = original_zeros(*args, **kwargs) + created_tensors.append(tensor.shape) + return tensor + + with mock.patch('torch.zeros', side_effect=mock_zeros): + self.lora_manager.load_from_ckpt(model_dirs_or_files=[nemo_path], + model_config=model_config, + runtime_mapping=mapping.Mapping(), + ckpt_source='nemo') + + # Should have created 'in' tensor with rank from config (not derived from existing tensor) + expected_in_shape = (rank, hidden_size) + self.assertIn( + expected_in_shape, created_tensors, + f"Expected 'in' tensor with config rank {rank} to be created") + + def test_hf_original_typerror_regression(self): + """Test HF-specific: Ensures original TypeError bug doesn't regress.""" + checkpoint_dir = self._create_incomplete_hf_checkpoint( + ['q_proj.lora_A']) + model_config = 
self._create_model_config(['attn_q']) + + # This test verifies that the current implementation handles the case gracefully + # Before the fix, this would have raised: TypeError: new(): invalid data type 'str' + try: + uids = self.lora_manager.load_from_ckpt( + model_dirs_or_files=[checkpoint_dir], + model_config=model_config, + runtime_mapping=mapping.Mapping(), + ckpt_source='hf') + # Should succeed with the fix in place + self.assertEqual(len(uids), 1) + except TypeError as e: + if "invalid data type 'str'" in str(e): + self.fail( + "The original TypeError bug has regressed - the fix is not working" + ) + else: + # Some other TypeError, re-raise + raise + + def test_nemo_default_rank_fallback(self): + """Test NeMo-specific: Fallback to default rank when both config and tensors unavailable.""" + # Create checkpoint without rank in config and ALL layers missing matrices to trigger default rank fallback + missing_all_layers = { + i: ['in', 'out'] + for i in range(DEFAULT_NUM_LAYERS) + } + nemo_path = self._create_incomplete_nemo_checkpoint( + missing_all_layers, include_rank_in_config=False) + model_config = self._create_nemo_model_config() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + uids = self.lora_manager.load_from_ckpt( + model_dirs_or_files=[nemo_path], + model_config=model_config, + runtime_mapping=mapping.Mapping(), + ckpt_source='nemo') + + self.assertEqual(len(uids), 1) + + # Should have warnings for both missing matrices AND default rank usage + missing_warnings = [ + warning for warning in w if 'missing' in str(warning.message) + ] + self.assertGreaterEqual( + len(missing_warnings), 2, + "Expected warnings for both missing matrices") + + # Should also have a warning about using default rank + rank_warnings = [ + warning for warning in w + if 'default rank' in str(warning.message) + ] + self.assertGreater(len(rank_warnings), 0, + "Expected warning about using default rank") + + +if __name__ == '__main__': + unittest.main() From 3c141f2defc023d250ad44a95ff1da3314371999 Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Tue, 22 Jul 2025 20:24:06 -0700 Subject: [PATCH 15/16] add test to test-list Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tests/integration/test_lists/test-db/l0_a100.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/test_lists/test-db/l0_a100.yml b/tests/integration/test_lists/test-db/l0_a100.yml index b8a846ccff6..b1c474b68fa 100644 --- a/tests/integration/test_lists/test-db/l0_a100.yml +++ b/tests/integration/test_lists/test-db/l0_a100.yml @@ -15,6 +15,7 @@ l0_a100: tests: - unittest/llmapi/test_llm_pytorch.py - unittest/llmapi/test_mpi_session.py # generic tests + - unittest/test_lora_manager.py - condition: ranges: system_gpu_count: From bc9f9dd7ec189bf441077b864863e4075ec6a524 Mon Sep 17 00:00:00 2001 From: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> Date: Mon, 28 Jul 2025 17:17:14 -0700 Subject: [PATCH 16/16] reduce scope of changes Signed-off-by: Venky Ganesh <23023424+venkywonka@users.noreply.github.com> --- tensorrt_llm/lora_manager.py | 34 +- tests/unittest/test_lora_manager.py | 518 ---------------------------- 2 files changed, 20 insertions(+), 532 deletions(-) delete mode 100644 tests/unittest/test_lora_manager.py diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index d3539d7bfde..db71cb08e81 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py 
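The hunks that follow replace the zero-tensor fallbacks with hard failures, so an incomplete adapter is rejected at load time instead of being silently padded. A minimal sketch of the resulting contract for callers; the checkpoint path is hypothetical, and the object names mirror the unit test removed later in this commit.

from tensorrt_llm import lora_manager, mapping

manager = lora_manager.LoraManager()
lora_model_config = lora_manager.LoraModelConfig(
    lora_target_modules=["attn_qkv"],
    trtllm_modules_to_hf_modules={"attn_qkv": "attn_qkv"},
    hidden_size=2048,
    dtype="float16",
)
try:
    manager.load_from_ckpt(
        model_dirs_or_files=["/path/to/incomplete_lora.nemo"],  # hypothetical path
        model_config=lora_model_config,
        runtime_mapping=mapping.Mapping(),
        ckpt_source="nemo",
    )
except ValueError as err:
    # e.g. "Layer 0 is missing required 'in' matrix (lora_A equivalent) for attn_qkv ..."
    print(f"Rejected incomplete LoRA adapter: {err}")
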
@@ -898,19 +898,21 @@ def load_from_model_file(uid, model_file): # Handle missing "in" matrix (lora_A equivalent) if "in" not in layer_weights: - warnings.warn( - f"Layer {layer_idx} is missing 'in' matrix for attn_qkv in NeMo LoRA weights, " - f"creating zero tensor" + raise ValueError( + f"Layer {layer_idx} is missing required 'in' matrix (lora_A equivalent) for attn_qkv " + f"in NeMo LoRA weights from file {model_file}. " + f"LoRA adapters must contain both 'in' and 'out' matrices for all layers. " + f"Please check if the LoRA checkpoint is complete or was corrupted during loading." ) - layer_weights["in"] = torch.zeros(rank, model_config.hidden_size) # Handle missing "out" matrix (lora_B equivalent) - 3x larger for fused QKV if "out" not in layer_weights: - warnings.warn( - f"Layer {layer_idx} is missing 'out' matrix for attn_qkv in NeMo LoRA weights, " - f"creating zero tensor" + raise ValueError( + f"Layer {layer_idx} is missing required 'out' matrix (lora_B equivalent) for attn_qkv " + f"in NeMo LoRA weights from file {model_file}. " + f"LoRA adapters must contain both 'in' and 'out' matrices for all layers. " + f"Please check if the LoRA checkpoint is complete or was corrupted during loading." ) - layer_weights["out"] = torch.zeros(3 * model_config.hidden_size, rank) t_in = layer_weights["in"] t_out = layer_weights["out"] @@ -1125,15 +1127,19 @@ def load_from_model_dir(uid, model_dir, hf_config): t_mag = None else: if "in" not in module_weights: - warnings.warn( - f"Module {hf_module} is missing 'in' matrix, creating zero tensor" + raise ValueError( + f"Module '{hf_module}' in layer {layer_idx} is missing required 'in' matrix (lora_A). " + f"LoRA adapters must contain both 'in' and 'out' matrices for all target modules. " + f"This indicates an incomplete or corrupted LoRA checkpoint in {model_dir}. " + f"Please verify the LoRA adapter was trained and saved correctly." ) - module_weights["in"] = torch.zeros(rank, model_config.hidden_size) if "out" not in module_weights: - warnings.warn( - f"Module {hf_module} is missing 'out' matrix, creating zero tensor" + raise ValueError( + f"Module '{hf_module}' in layer {layer_idx} is missing required 'out' matrix (lora_B). " + f"LoRA adapters must contain both 'in' and 'out' matrices for all target modules. " + f"This indicates an incomplete or corrupted LoRA checkpoint in {model_dir}. " + f"Please verify the LoRA adapter was trained and saved correctly." ) - module_weights["out"] = torch.zeros(model_config.hidden_size, rank) t_in = module_weights["in"] t_out = module_weights["out"] diff --git a/tests/unittest/test_lora_manager.py b/tests/unittest/test_lora_manager.py deleted file mode 100644 index b517361baa9..00000000000 --- a/tests/unittest/test_lora_manager.py +++ /dev/null @@ -1,518 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import io -import json -import os -import tarfile -import tempfile -import unittest -import warnings -from unittest import mock - -import torch -import yaml -from safetensors import torch as safetensors_torch - -# Import the modules to test -from tensorrt_llm import lora_manager, mapping - -# Constants -DEFAULT_HIDDEN_SIZE = 4096 -DEFAULT_RANK = 32 -DEFAULT_NUM_LAYERS = 4 -DEFAULT_TEST_RANK = 16 - - -class TestLoraManagerBase(unittest.TestCase): - """Base class with common functionality for LoRA manager tests.""" - - def setUp(self): - """Set up test fixtures.""" - self.temp_dir = tempfile.mkdtemp() - self.lora_manager = lora_manager.LoraManager() - - def tearDown(self): - """Clean up test fixtures.""" - import shutil - shutil.rmtree(self.temp_dir, ignore_errors=True) - - def _create_incomplete_hf_checkpoint(self, missing_matrices=None): - """ - Create an incomplete HF LoRA checkpoint for testing. - - Args: - missing_matrices: List of matrix types to exclude (e.g., ['q_proj.lora_A']) - If None, defaults to ['q_proj.lora_A', 'v_proj.lora_A'] - - Returns: - str: Path to the created checkpoint directory - """ - if missing_matrices is None: - missing_matrices = ['q_proj.lora_A', 'v_proj.lora_A'] - - # Create adapter_config.json - adapter_config = { - "r": 32, - "lora_alpha": 32, - "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"], - "lora_dropout": 0.05, - "bias": "none", - "task_type": "CAUSAL_LM" - } - - config_path = os.path.join(self.temp_dir, "adapter_config.json") - with open(config_path, 'w') as f: - json.dump(adapter_config, f) - - # Create incomplete weight tensors - weights = {} - hidden_size = DEFAULT_HIDDEN_SIZE - rank = DEFAULT_RANK - num_layers = DEFAULT_NUM_LAYERS # Use fewer layers for faster tests - - for layer_idx in range(num_layers): - layer_prefix = f"base_model.model.model.layers.{layer_idx}.self_attn" - - # Add weights for all modules except missing ones - for module in ["q_proj", "k_proj", "v_proj", "o_proj"]: - for matrix_type in ["lora_A", "lora_B"]: - key = f"{layer_prefix}.{module}.{matrix_type}.weight" - - # Skip missing matrices - if any(missing in key for missing in missing_matrices): - continue - - if matrix_type == "lora_A": - shape = (rank, hidden_size) - else: # lora_B - shape = (hidden_size, rank) - - weights[key] = torch.randn(shape, dtype=torch.float16) - - # Save to safetensors - safetensors_path = os.path.join(self.temp_dir, - "adapter_model.safetensors") - safetensors_torch.save_file(weights, safetensors_path) - - return self.temp_dir - - def _create_hf_model_config(self): - """Create a LoraModelConfig for HF testing.""" - return self._create_model_config() - - def _create_model_config(self, target_modules=None): - """Create a LoraModelConfig for testing (backward compatibility).""" - if target_modules is None: - target_modules = ['attn_q', 'attn_k', 'attn_v', 'attn_dense'] - - return lora_manager.LoraModelConfig(lora_target_modules=target_modules, - trtllm_modules_to_hf_modules={ - 'attn_q': 'q_proj', - 'attn_k': 'k_proj', - 'attn_v': 'v_proj', - 'attn_dense': 'o_proj' - }, - hidden_size=DEFAULT_HIDDEN_SIZE, - dtype='float16') - - def _create_incomplete_nemo_checkpoint(self, - missing_matrices=None, - include_rank_in_config=True): - """ - Create a NeMo LoRA checkpoint (.nemo archive) for testing. 
- - Args: - missing_matrices: Dict mapping layer_idx to list of missing matrices - e.g., {0: ['in'], 1: ['out']} - include_rank_in_config: Whether to include adapter_dim in the config - - Returns: - str: Path to the created .nemo file - """ - if missing_matrices is None: - missing_matrices = {} - - # Create model config - model_config = { - "target_modules": ["self_attention.adapter_layer.lora_kqv_adapter"], - "hidden_size": DEFAULT_HIDDEN_SIZE, - "num_layers": DEFAULT_NUM_LAYERS - } - - # Conditionally add lora_tuning with adapter_dim - if include_rank_in_config: - model_config["lora_tuning"] = { - "adapter_dim": DEFAULT_RANK, # This will be used as the rank - "target_modules": ["attention_qkv"], - "alpha": DEFAULT_RANK - } - else: - # Create lora_tuning without adapter_dim to test default rank fallback - model_config["lora_tuning"] = { - "target_modules": ["attention_qkv"], - "alpha": DEFAULT_RANK - } - - # Create model weights - model_weights = {} - rank = DEFAULT_RANK - hidden_size = DEFAULT_HIDDEN_SIZE - - for layer_idx in range(DEFAULT_NUM_LAYERS): - layer_missing = missing_matrices.get(layer_idx, []) - - # Add 'in' matrix unless it's marked as missing - if 'in' not in layer_missing: - key = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter.linear_in.weight" - model_weights[key] = torch.randn(rank, - hidden_size, - dtype=torch.float16) - - # Add 'out' matrix unless it's marked as missing - if 'out' not in layer_missing: - key = f"model.layers.{layer_idx}.self_attention.adapter_layer.lora_kqv_adapter.linear_out.weight" - # NeMo fused QKV is 3x larger - model_weights[key] = torch.randn(3 * hidden_size, - rank, - dtype=torch.float16) - - # Create .nemo archive - nemo_path = os.path.join(self.temp_dir, "test_lora.nemo") - - with tarfile.open(nemo_path, 'w') as tar: - # Add model_config.yaml - config_str = yaml.dump(model_config) - config_info = tarfile.TarInfo('model_config.yaml') - config_info.size = len(config_str.encode()) - tar.addfile(config_info, io.BytesIO(config_str.encode())) - - # Add model_weights.ckpt - weights_buffer = io.BytesIO() - torch.save(model_weights, weights_buffer) - weights_data = weights_buffer.getvalue() - - weights_info = tarfile.TarInfo('model_weights.ckpt') - weights_info.size = len(weights_data) - tar.addfile(weights_info, io.BytesIO(weights_data)) - - return nemo_path - - def _create_nemo_model_config(self): - """Create a LoraModelConfig for NeMo testing.""" - return lora_manager.LoraModelConfig( - lora_target_modules=['attn_qkv'], - trtllm_modules_to_hf_modules={'attn_qkv': 'attn_qkv'}, - hidden_size=DEFAULT_HIDDEN_SIZE, - dtype='float16') - - def test_missing_matrices_graceful_handling(self): - """Test for graceful handling of missing matrices across checkpoint formats.""" - test_cases = [ - # HF test cases - ("hf", ["q_proj.lora_A", - "v_proj.lora_A"], ["q_proj", "v_proj", "missing", "in"]), - ("hf", ["k_proj.lora_B"], ["k_proj", "missing", "out"]), - ("hf", ["q_proj.lora_A", "k_proj.lora_B", - "v_proj.lora_A"], ["missing"]), - # NeMo test cases - ("nemo", { - 0: ["in"] - }, ["Layer 0", "missing", "in"]), - ("nemo", { - 1: ["out"] - }, ["Layer 1", "missing", "out"]), - ("nemo", { - 0: ["in", "out"] - }, ["missing"]), - ] - - for ckpt_source, missing_matrices, expected_warnings in test_cases: - with self.subTest(ckpt_source=ckpt_source, - missing_matrices=missing_matrices): - if ckpt_source == "hf": - checkpoint_path = self._create_incomplete_hf_checkpoint( - missing_matrices) - model_config = self._create_hf_model_config() - else: 
# nemo - checkpoint_path = self._create_incomplete_nemo_checkpoint( - missing_matrices) - model_config = self._create_nemo_model_config() - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - uids = self.lora_manager.load_from_ckpt( - model_dirs_or_files=[checkpoint_path], - model_config=model_config, - runtime_mapping=mapping.Mapping(), - ckpt_source=ckpt_source) - - # Should successfully return UIDs - self.assertEqual(len(uids), 1) - - # Should have generated appropriate warnings - warning_text = ' '.join( - [str(warning.message) for warning in w]) - for expected in expected_warnings: - self.assertIn( - expected, warning_text, - f"Expected '{expected}' in warning text for {ckpt_source} checkpoint" - ) - - def test_complete_checkpoints_no_warnings(self): - """Test that complete checkpoints load without warnings.""" - test_cases = ["hf", "nemo"] - - for ckpt_source in test_cases: - with self.subTest(ckpt_source=ckpt_source): - if ckpt_source == "hf": - checkpoint_path = self._create_incomplete_hf_checkpoint( - []) # No missing matrices - model_config = self._create_hf_model_config() - else: # nemo - checkpoint_path = self._create_incomplete_nemo_checkpoint( - {}) # No missing matrices - model_config = self._create_nemo_model_config() - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - uids = self.lora_manager.load_from_ckpt( - model_dirs_or_files=[checkpoint_path], - model_config=model_config, - runtime_mapping=mapping.Mapping(), - ckpt_source=ckpt_source) - - self.assertEqual(len(uids), 1) - - # Should not have any warnings about missing matrices - missing_warnings = [ - warning for warning in w - if 'missing' in str(warning.message) - ] - self.assertEqual( - len(missing_warnings), 0, - f"Complete {ckpt_source} checkpoint should not generate missing matrix warnings" - ) - - -class TestLoraManagerSpecificFeatures(TestLoraManagerBase): - """Tests for specific features that are unique to each checkpoint format.""" - - def test_hf_zero_tensor_dimensions(self): - """Test HF-specific zero tensor dimensions (separate Q/K/V modules).""" - checkpoint_dir = self._create_incomplete_hf_checkpoint( - ['q_proj.lora_A']) - model_config = self._create_model_config(['attn_q' - ]) # Only test one module - - # Mock the zero tensor creation to verify dimensions - original_zeros = torch.zeros - created_tensors = [] - - def mock_zeros(*args, **kwargs): - tensor = original_zeros(*args, **kwargs) - created_tensors.append(tensor.shape) - return tensor - - with mock.patch('torch.zeros', side_effect=mock_zeros): - self.lora_manager.load_from_ckpt( - model_dirs_or_files=[checkpoint_dir], - model_config=model_config, - runtime_mapping=mapping.Mapping(), - ckpt_source='hf') - - # Should have created zero tensors for missing matrices - self.assertGreater( - len(created_tensors), 0, - "Should have created zero tensors for missing matrices") - - # Verify HF tensor dimensions (rank=32, hidden_size=4096 for lora_A) - expected_shape = (DEFAULT_RANK, DEFAULT_HIDDEN_SIZE - ) # lora_A dimensions - self.assertIn( - expected_shape, created_tensors, - f"Expected HF lora_A tensor shape {expected_shape} to be created") - - def test_nemo_zero_tensor_dimensions(self): - """Test NeMo-specific zero tensor dimensions (fused QKV - 3x larger output).""" - # Create checkpoint without rank in config to use default rank (64) - nemo_path = self._create_incomplete_nemo_checkpoint( - {0: ['in', 'out']}, include_rank_in_config=False) - model_config = 
self._create_nemo_model_config() - - # Mock the zero tensor creation to verify dimensions - original_zeros = torch.zeros - created_tensors = [] - - def mock_zeros(*args, **kwargs): - tensor = original_zeros(*args, **kwargs) - created_tensors.append(tensor.shape) - return tensor - - with mock.patch('torch.zeros', side_effect=mock_zeros): - self.lora_manager.load_from_ckpt(model_dirs_or_files=[nemo_path], - model_config=model_config, - runtime_mapping=mapping.Mapping(), - ckpt_source='nemo') - - # Should have created zero tensors - self.assertGreater( - len(created_tensors), 0, - "Should have created zero tensors for missing matrices") - - # Verify NeMo tensor dimensions (rank=32 from other layers, hidden_size=4096, 3x for fused QKV) - expected_in_shape = (DEFAULT_RANK, DEFAULT_HIDDEN_SIZE - ) # 'in' matrix (lora_A equivalent) - expected_out_shape = (3 * DEFAULT_HIDDEN_SIZE, DEFAULT_RANK - ) # 'out' matrix (3x larger for fused QKV) - - self.assertIn( - expected_in_shape, created_tensors, - f"Expected NeMo 'in' tensor shape {expected_in_shape} to be created" - ) - self.assertIn( - expected_out_shape, created_tensors, - f"Expected NeMo 'out' tensor shape {expected_out_shape} to be created" - ) - - def test_nemo_rank_derivation_from_config_and_tensors(self): - """Test NeMo-specific rank derivation: from config first, then from existing tensors.""" - # Create checkpoint with custom rank where only 'in' is missing - rank = DEFAULT_TEST_RANK - hidden_size = DEFAULT_HIDDEN_SIZE - - # Manually create model weights with custom rank - model_weights = { - f"model.layers.0.self_attention.adapter_layer.lora_kqv_adapter.linear_out.weight": - torch.randn(3 * hidden_size, rank, dtype=torch.float16) - } - - # Create .nemo archive - nemo_path = os.path.join(self.temp_dir, "custom_rank.nemo") - model_config_dict = { - "lora_tuning": { - "adapter_dim": rank, # This should be used as primary source - "target_modules": ["attention_qkv"] - }, - "hidden_size": hidden_size - } - - with tarfile.open(nemo_path, 'w') as tar: - # Add config - config_str = yaml.dump(model_config_dict) - config_info = tarfile.TarInfo('model_config.yaml') - config_info.size = len(config_str.encode()) - tar.addfile(config_info, io.BytesIO(config_str.encode())) - - # Add weights - weights_buffer = io.BytesIO() - torch.save(model_weights, weights_buffer) - weights_data = weights_buffer.getvalue() - - weights_info = tarfile.TarInfo('model_weights.ckpt') - weights_info.size = len(weights_data) - tar.addfile(weights_info, io.BytesIO(weights_data)) - - model_config = self._create_nemo_model_config() - - # Mock zero tensor creation to verify correct rank is used - created_tensors = [] - original_zeros = torch.zeros - - def mock_zeros(*args, **kwargs): - tensor = original_zeros(*args, **kwargs) - created_tensors.append(tensor.shape) - return tensor - - with mock.patch('torch.zeros', side_effect=mock_zeros): - self.lora_manager.load_from_ckpt(model_dirs_or_files=[nemo_path], - model_config=model_config, - runtime_mapping=mapping.Mapping(), - ckpt_source='nemo') - - # Should have created 'in' tensor with rank from config (not derived from existing tensor) - expected_in_shape = (rank, hidden_size) - self.assertIn( - expected_in_shape, created_tensors, - f"Expected 'in' tensor with config rank {rank} to be created") - - def test_hf_original_typerror_regression(self): - """Test HF-specific: Ensures original TypeError bug doesn't regress.""" - checkpoint_dir = self._create_incomplete_hf_checkpoint( - ['q_proj.lora_A']) - model_config = 
self._create_model_config(['attn_q']) - - # This test verifies that the current implementation handles the case gracefully - # Before the fix, this would have raised: TypeError: new(): invalid data type 'str' - try: - uids = self.lora_manager.load_from_ckpt( - model_dirs_or_files=[checkpoint_dir], - model_config=model_config, - runtime_mapping=mapping.Mapping(), - ckpt_source='hf') - # Should succeed with the fix in place - self.assertEqual(len(uids), 1) - except TypeError as e: - if "invalid data type 'str'" in str(e): - self.fail( - "The original TypeError bug has regressed - the fix is not working" - ) - else: - # Some other TypeError, re-raise - raise - - def test_nemo_default_rank_fallback(self): - """Test NeMo-specific: Fallback to default rank when both config and tensors unavailable.""" - # Create checkpoint without rank in config and ALL layers missing matrices to trigger default rank fallback - missing_all_layers = { - i: ['in', 'out'] - for i in range(DEFAULT_NUM_LAYERS) - } - nemo_path = self._create_incomplete_nemo_checkpoint( - missing_all_layers, include_rank_in_config=False) - model_config = self._create_nemo_model_config() - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - uids = self.lora_manager.load_from_ckpt( - model_dirs_or_files=[nemo_path], - model_config=model_config, - runtime_mapping=mapping.Mapping(), - ckpt_source='nemo') - - self.assertEqual(len(uids), 1) - - # Should have warnings for both missing matrices AND default rank usage - missing_warnings = [ - warning for warning in w if 'missing' in str(warning.message) - ] - self.assertGreaterEqual( - len(missing_warnings), 2, - "Expected warnings for both missing matrices") - - # Should also have a warning about using default rank - rank_warnings = [ - warning for warning in w - if 'default rank' in str(warning.message) - ] - self.assertGreater(len(rank_warnings), 0, - "Expected warning about using default rank") - - -if __name__ == '__main__': - unittest.main()
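Taken together, the NeMo LoRA adapters accepted by the loader in this series are tar archives with two members: model_config.yaml, carrying the adapter rank under lora_tuning.adapter_dim, and model_weights.ckpt, carrying per-layer linear_in/linear_out tensors under the lora_kqv_adapter keys. A short inspection sketch follows; the unpack_nemo_weights import path is an assumption, everything else mirrors the code above.

from tensorrt_llm.lora_manager import unpack_nemo_weights  # import path assumed

# model_weights.ckpt is expected to contain, for each layer i:
#   model.layers.<i>.self_attention.adapter_layer.lora_kqv_adapter.linear_in.weight   [rank, hidden_size]
#   model.layers.<i>.self_attention.adapter_layer.lora_kqv_adapter.linear_out.weight  [hidden_size + 2 * kv_hidden_size, rank]
nemo_config, nemo_weights = unpack_nemo_weights("test_lora.nemo")
print("adapter rank:", nemo_config.get("lora_tuning", {}).get("adapter_dim"))
for name, tensor in sorted(nemo_weights.items()):
    print(name, tuple(tensor.shape))
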