diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 0703a442007..045abaf6bbc 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -523,15 +523,23 @@ def get_bindings_model_config(self,
 
         # For kv cache size calculation: set size_per_head
         head_dim_names = ["head_size", "head_dim"]
+        head_size = None
         for head_dim_name in head_dim_names:
-            if head_dim_name in self.pretrained_config:
-                head_size = getattr(self.pretrained_config, head_dim_name)
-                break
-        else:
+            if hasattr(self.pretrained_config, head_dim_name):
+                value = getattr(self.pretrained_config, head_dim_name)
+                if value is not None:
+                    head_size = value
+                    break
+
+        if head_size is None:
+            assert hidden_size % num_heads == 0, (
+                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
+            )
+            calculated_head_size = hidden_size // num_heads
             logger.warning(
-                f"head_size/head_dim is not set, using default value {hidden_size // num_heads}"
+                f"head_size/head_dim is not set or None, using default value {calculated_head_size}"
             )
-            head_size = hidden_size // num_heads
+            head_size = calculated_head_size
 
         model_config_cpp.mlp_hidden_size = mlp_hidden_size
         model_config_cpp.size_per_head = head_size
diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py
index 7ee2dbbaa7f..114d0ef743d 100644
--- a/tests/unittest/llmapi/test_llm_pytorch.py
+++ b/tests/unittest/llmapi/test_llm_pytorch.py
@@ -25,14 +25,13 @@
                        prompts, run_llm_abort_request,
                        run_llm_with_postprocess_parallel_and_result_handler,
                        tinyllama_logits_processor_test_harness)
-from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb,
+from utils.util import (force_ampere, similar, skip_fp8_pre_ada,
+                        skip_gpu_memory_less_than_40gb,
                         skip_gpu_memory_less_than_80gb,
                         skip_gpu_memory_less_than_138gb)
 from utils.llm_data import llm_models_root
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.executor.request import LoRARequest
-from tensorrt_llm.models.modeling_utils import QuantConfig
-from tensorrt_llm.quantization.mode import QuantAlgo
 
 import tempfile
 import torch
@@ -496,68 +495,36 @@ def test_nemotron_nas_lora() -> None:
 
 
 @skip_gpu_memory_less_than_80gb
-@pytest.mark.skip(reason="https://nvbugs/5521949")
-def test_codellama_fp8_with_bf16_lora() -> None:
-    model_dir = f"{llm_models_root()}/codellama/CodeLlama-7b-Instruct-hf/"
-    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
-                               kv_cache_quant_algo=QuantAlgo.FP8)
-
-    target_modules = ['attn_q', 'attn_k', 'attn_v']
-
-    # Set up temporary directory for LoRA adapters
-    with tempfile.TemporaryDirectory() as lora_dir:
-        print("Creating dummy LoRAs...")
-
-        model = AutoModelForCausalLM.from_pretrained(
-            model_dir,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            trust_remote_code=True,
-        )
-
-        hf_modules = ["q_proj", "k_proj", "v_proj"]
-
-        lora_config = PeftLoraConfig(r=8,
-                                     target_modules=hf_modules,
-                                     bias="none",
-                                     task_type="CAUSAL_LM")
-
-        lora_paths = []
-        for i in range(2):
-            lora_model = get_peft_model(model, lora_config)
-            for param in lora_model.parameters():
-                param.data.zero_()
-            lora_path = f"{lora_dir}/lora_{i}"
-            lora_model.save_pretrained(lora_path)
-            lora_paths.append(lora_path)
-
-        lora_config = LoraConfig(lora_dir=lora_paths,
-                                 lora_target_modules=target_modules,
-                                 max_lora_rank=8,
-                                 max_loras=2,
-                                 max_cpu_loras=2)
-
-        llm = LLM(model_dir, quant_config=quant_config, lora_config=lora_config)
-
-        prompts = [
-            "Write a function that calculates the Fibonacci sequence.",
-            "Convert this C++ code to Python: int x = 0; x++;",
-        ]
-
-        lora_req1 = LoRARequest("lora-1", 0, lora_paths[0])
-        lora_req2 = LoRARequest("lora-2", 1, lora_paths[1])
-        lora_requests = [lora_req1, lora_req2]
-        sampling_params = SamplingParams(max_tokens=200)
+def test_llama_3_1_8b_fp8_with_bf16_lora() -> None:
+    skip_fp8_pre_ada(use_fp8=True)
+    model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
+    lora_dir = f"{llm_models_root()}/lora/llama-3-chinese-8b-instruct-v2-lora"
+    prompt = "美国的首都是哪里?"
+    reference = "华盛顿特区。华盛顿特区是美国的首都和一个行政区"
+
+    lora_config = LoraConfig(lora_dir=[lora_dir],
+                             max_lora_rank=64,
+                             max_loras=2,
+                             max_cpu_loras=2)
+    lora_req = LoRARequest("lora-chinese", 0, lora_dir)
 
-        outputs = llm.generate(prompts,
-                               sampling_params,
-                               lora_request=lora_requests)
+    llm = LLM(
+        model_dir,
+        lora_config=lora_config,
+        # Disable CUDA graph
+        # TODO: remove this once we have a proper fix for CUDA graph in LoRA
+        cuda_graph_config=None)
 
-        assert len(outputs) == 2
+    try:
+        output = llm.generate(prompt,
+                              SamplingParams(max_tokens=20),
+                              lora_request=[lora_req])
+    finally:
+        llm.shutdown()
+    assert similar(output.outputs[0].text, reference)
 
 
 @skip_gpu_memory_less_than_80gb
-@pytest.mark.skip(reason="https://nvbugs/5521949")
 def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
     model_dir = f"{llm_models_root()}/Bielik-11B-v2.2-Instruct"
 
@@ -584,12 +551,16 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
             lora_model.save_pretrained(lora_path)
             lora_paths.append(lora_path)
 
-        trtllm_lora_config = LoraConfig(lora_dir=lora_paths,
-                                        lora_target_modules=target_modules,
+        trtllm_lora_config = LoraConfig(lora_target_modules=target_modules,
                                         max_lora_rank=8,
                                         max_loras=2,
                                         max_cpu_loras=2)
-        llm = LLM(model_dir, lora_config=trtllm_lora_config)
+        llm = LLM(
+            model_dir,
+            lora_config=trtllm_lora_config,
+            # Disable CUDA graph
+            # TODO: remove this once we have a proper fix for CUDA graph in LoRA
+            cuda_graph_config=None)
 
         prompts = [
             "Kim był Mikołaj Kopernik i z czego zasłynął?",