20 changes: 14 additions & 6 deletions tensorrt_llm/_torch/model_config.py
@@ -523,15 +523,23 @@ def get_bindings_model_config(self,
 
         # For kv cache size calculation: set size_per_head
         head_dim_names = ["head_size", "head_dim"]
+        head_size = None
         for head_dim_name in head_dim_names:
-            if head_dim_name in self.pretrained_config:
-                head_size = getattr(self.pretrained_config, head_dim_name)
-                break
-        else:
+            if hasattr(self.pretrained_config, head_dim_name):
+                value = getattr(self.pretrained_config, head_dim_name)
+                if value is not None:
+                    head_size = value
+                    break
+
+        if head_size is None:
+            assert hidden_size % num_heads == 0, (
+                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
+            )
+            calculated_head_size = hidden_size // num_heads
             logger.warning(
-                f"head_size/head_dim is not set, using default value {hidden_size // num_heads}"
+                f"head_size/head_dim is not set or None, using default value {calculated_head_size}"
             )
-            head_size = hidden_size // num_heads
+            head_size = calculated_head_size
 
         model_config_cpp.mlp_hidden_size = mlp_hidden_size
         model_config_cpp.size_per_head = head_size
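Note on the hunk above: an attribute that exists on pretrained_config but is set to None is now treated the same as a missing attribute, and the hidden_size // num_heads fallback asserts divisibility before it is used. A minimal, self-contained sketch of that resolution logic (resolve_head_size and the SimpleNamespace stand-in config are illustrative, not part of this PR):

from types import SimpleNamespace


def resolve_head_size(pretrained_config, hidden_size, num_heads):
    # Illustrative sketch: mirrors the updated fallback in
    # get_bindings_model_config() -- accept head_size/head_dim only when the
    # attribute exists AND is not None, otherwise derive it from hidden_size.
    head_size = None
    for name in ("head_size", "head_dim"):
        if hasattr(pretrained_config, name):
            value = getattr(pretrained_config, name)
            if value is not None:
                head_size = value
                break

    if head_size is None:
        assert hidden_size % num_heads == 0, (
            f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})")
        head_size = hidden_size // num_heads
    return head_size


# A config whose head_dim is explicitly None now falls back to the derived value.
print(resolve_head_size(SimpleNamespace(head_dim=None), 4096, 32))  # -> 128
print(resolve_head_size(SimpleNamespace(head_dim=64), 4096, 32))    # -> 64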
97 changes: 34 additions & 63 deletions tests/unittest/llmapi/test_llm_pytorch.py
@@ -25,14 +25,13 @@
                        prompts, run_llm_abort_request,
                        run_llm_with_postprocess_parallel_and_result_handler,
                        tinyllama_logits_processor_test_harness)
-from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb,
+from utils.util import (force_ampere, similar, skip_fp8_pre_ada,
+                        skip_gpu_memory_less_than_40gb,
                         skip_gpu_memory_less_than_80gb,
                         skip_gpu_memory_less_than_138gb)
 from utils.llm_data import llm_models_root
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.executor.request import LoRARequest
-from tensorrt_llm.models.modeling_utils import QuantConfig
-from tensorrt_llm.quantization.mode import QuantAlgo
 import tempfile
 
 import torch
@@ -496,68 +495,36 @@ def test_nemotron_nas_lora() -> None:
 
 
 @skip_gpu_memory_less_than_80gb
-@pytest.mark.skip(reason="https://nvbugs/5521949")
-def test_codellama_fp8_with_bf16_lora() -> None:
-    model_dir = f"{llm_models_root()}/codellama/CodeLlama-7b-Instruct-hf/"
-    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
-                               kv_cache_quant_algo=QuantAlgo.FP8)
-
-    target_modules = ['attn_q', 'attn_k', 'attn_v']
-
-    # Set up temporary directory for LoRA adapters
-    with tempfile.TemporaryDirectory() as lora_dir:
-        print("Creating dummy LoRAs...")
-
-        model = AutoModelForCausalLM.from_pretrained(
-            model_dir,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            trust_remote_code=True,
-        )
-
-        hf_modules = ["q_proj", "k_proj", "v_proj"]
-
-        lora_config = PeftLoraConfig(r=8,
-                                     target_modules=hf_modules,
-                                     bias="none",
-                                     task_type="CAUSAL_LM")
-
-        lora_paths = []
-        for i in range(2):
-            lora_model = get_peft_model(model, lora_config)
-            for param in lora_model.parameters():
-                param.data.zero_()
-            lora_path = f"{lora_dir}/lora_{i}"
-            lora_model.save_pretrained(lora_path)
-            lora_paths.append(lora_path)
-
-        lora_config = LoraConfig(lora_dir=lora_paths,
-                                 lora_target_modules=target_modules,
-                                 max_lora_rank=8,
-                                 max_loras=2,
-                                 max_cpu_loras=2)
-
-        llm = LLM(model_dir, quant_config=quant_config, lora_config=lora_config)
-
-        prompts = [
-            "Write a function that calculates the Fibonacci sequence.",
-            "Convert this C++ code to Python: int x = 0; x++;",
-        ]
-
-        lora_req1 = LoRARequest("lora-1", 0, lora_paths[0])
-        lora_req2 = LoRARequest("lora-2", 1, lora_paths[1])
-        lora_requests = [lora_req1, lora_req2]
-        sampling_params = SamplingParams(max_tokens=200)
+def test_llama_3_1_8b_fp8_with_bf16_lora() -> None:
+    skip_fp8_pre_ada(use_fp8=True)
+    model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
+    lora_dir = f"{llm_models_root()}/lora/llama-3-chinese-8b-instruct-v2-lora"
+    prompt = "美国的首都是哪里?"
+    reference = "华盛顿特区。华盛顿特区是美国的首都和一个行政区"
+
+    lora_config = LoraConfig(lora_dir=[lora_dir],
+                             max_lora_rank=64,
+                             max_loras=2,
+                             max_cpu_loras=2)
+    lora_req = LoRARequest("lora-chinese", 0, lora_dir)
 
-        outputs = llm.generate(prompts,
-                               sampling_params,
-                               lora_request=lora_requests)
+    llm = LLM(
+        model_dir,
+        lora_config=lora_config,
+        # Disable CUDA graph
+        # TODO: remove this once we have a proper fix for CUDA graph in LoRA
+        cuda_graph_config=None)
 
-        assert len(outputs) == 2
+    try:
+        output = llm.generate(prompt,
+                              SamplingParams(max_tokens=20),
+                              lora_request=[lora_req])
+    finally:
+        llm.shutdown()
+    assert similar(output.outputs[0].text, reference)
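For reference, the flow the new test above exercises, pulled out of pytest: an FP8-quantized base checkpoint served with a BF16 LoRA adapter attached per request. This is an illustrative sketch only; the checkpoint/adapter paths are placeholders mirroring the test's assumptions, and the top-level tensorrt_llm imports are where I believe LLM and SamplingParams are exposed (the LoraConfig and LoRARequest imports are taken from the diff above):

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.lora_helper import LoraConfig

# Placeholder paths: an FP8 base model and a BF16 LoRA adapter directory.
model_dir = "Llama-3.1-8B-Instruct-FP8"
lora_dir = "llama-3-chinese-8b-instruct-v2-lora"

lora_config = LoraConfig(lora_dir=[lora_dir],
                         max_lora_rank=64,
                         max_loras=2,
                         max_cpu_loras=2)
llm = LLM(model_dir, lora_config=lora_config,
          cuda_graph_config=None)  # CUDA graph disabled, as in the test
try:
    # "美国的首都是哪里?" -- "What is the capital of the United States?"
    output = llm.generate("美国的首都是哪里?",
                          SamplingParams(max_tokens=20),
                          lora_request=[LoRARequest("lora-chinese", 0, lora_dir)])
    print(output.outputs[0].text)
finally:
    llm.shutdown()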


 @skip_gpu_memory_less_than_80gb
-@pytest.mark.skip(reason="https://nvbugs/5521949")
 def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
     model_dir = f"{llm_models_root()}/Bielik-11B-v2.2-Instruct"

@@ -584,12 +551,16 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
             lora_model.save_pretrained(lora_path)
             lora_paths.append(lora_path)
 
-        trtllm_lora_config = LoraConfig(lora_dir=lora_paths,
-                                        lora_target_modules=target_modules,
+        trtllm_lora_config = LoraConfig(lora_target_modules=target_modules,
                                         max_lora_rank=8,
                                         max_loras=2,
                                         max_cpu_loras=2)
-        llm = LLM(model_dir, lora_config=trtllm_lora_config)
+        llm = LLM(
+            model_dir,
+            lora_config=trtllm_lora_config,
+            # Disable CUDA graph
+            # TODO: remove this once we have a proper fix for CUDA graph in LoRA
+            cuda_graph_config=None)
 
         prompts = [
             "Kim był Mikołaj Kopernik i z czego zasłynął?",
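The Bielik hunk above keeps the adapters out of LoraConfig — the config now only sizes the LoRA cache via max_lora_rank / max_loras / max_cpu_loras — and, like the other tests, disables CUDA graphs. A hedged sketch of the resulting multi-adapter batching pattern, assembled from the APIs already used in this diff; the model name, adapter paths, and duplicated prompt are placeholders:

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.lora_helper import LoraConfig

lora_paths = ["/tmp/lora_0", "/tmp/lora_1"]  # placeholder adapter directories

# Capacity-only LoRA config: no lora_dir up front, adapters arrive per request.
lora_config = LoraConfig(lora_target_modules=['attn_q', 'attn_k', 'attn_v'],
                         max_lora_rank=8,
                         max_loras=2,
                         max_cpu_loras=2)
llm = LLM("Bielik-11B-v2.2-Instruct", lora_config=lora_config,
          cuda_graph_config=None)  # CUDA graph disabled, as in the tests above
try:
    # "Kim był Mikołaj Kopernik i z czego zasłynął?" -- "Who was Nicolaus
    # Copernicus and what is he famous for?"
    prompts = ["Kim był Mikołaj Kopernik i z czego zasłynął?"] * 2
    lora_requests = [LoRARequest("lora-1", 0, lora_paths[0]),
                     LoRARequest("lora-2", 1, lora_paths[1])]
    outputs = llm.generate(prompts,
                           SamplingParams(max_tokens=200),
                           lora_request=lora_requests)
    assert len(outputs) == 2  # one output per prompt, each with its own adapter
finally:
    llm.shutdown()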