20 changes: 14 additions & 6 deletions tensorrt_llm/_torch/model_config.py
@@ -525,15 +525,23 @@ def get_bindings_model_config(self,

# For kv cache size calculation: set size_per_head
head_dim_names = ["head_size", "head_dim"]
head_size = None
for head_dim_name in head_dim_names:
if head_dim_name in self.pretrained_config:
head_size = getattr(self.pretrained_config, head_dim_name)
break
else:
if hasattr(self.pretrained_config, head_dim_name):
value = getattr(self.pretrained_config, head_dim_name)
if value is not None:
head_size = value
break

if head_size is None:
assert hidden_size % num_heads == 0, (
f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
)
calculated_head_size = hidden_size // num_heads
logger.warning(
f"head_size/head_dim is not set, using default value {hidden_size // num_heads}"
f"head_size/head_dim is not set or None, using default value {calculated_head_size}"
)
head_size = hidden_size // num_heads
head_size = calculated_head_size

model_config_cpp.mlp_hidden_size = mlp_hidden_size
model_config_cpp.size_per_head = head_size
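A minimal standalone sketch of the fallback behavior this hunk implements, using a hypothetical `FakeConfig` class (not the real pretrained-config object): an attribute such as `head_dim` that exists but is `None` no longer stops the lookup, and `head_size` falls back to `hidden_size // num_heads`.

```python
# Hypothetical stand-in for a pretrained config; the real object differs.
class FakeConfig:
    hidden_size = 4096
    num_attention_heads = 32
    head_dim = None  # attribute is present but unset, as in some HF configs


def resolve_head_size(cfg, hidden_size, num_heads):
    head_size = None
    for name in ("head_size", "head_dim"):
        if hasattr(cfg, name):
            value = getattr(cfg, name)
            if value is not None:  # skip attributes that exist but are None
                head_size = value
                break
    if head_size is None:
        assert hidden_size % num_heads == 0
        head_size = hidden_size // num_heads  # fallback used by the new code
    return head_size


print(resolve_head_size(FakeConfig(), 4096, 32))  # -> 128
```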
207 changes: 183 additions & 24 deletions tests/integration/defs/common.py
@@ -22,6 +22,11 @@

from packaging import version

from tensorrt_llm import LLM as LLM_torch
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.sampling_params import SamplingParams

from .trt_test_alternative import check_call, check_output, exists, is_windows


@@ -739,12 +744,28 @@ def generate_dummy_loras(
from transformers import AutoModelForCausalLM

print("Creating pseudo LoRAs...")
model = AutoModelForCausalLM.from_pretrained(
hf_model_dir,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
)

# Avoid meta tensors by loading model to CPU first (ensures all parameters are materialized)
try:
model = AutoModelForCausalLM.from_pretrained(
hf_model_dir,
torch_dtype=torch.float16,
device_map=None, # Load everything to CPU first
trust_remote_code=True,
low_cpu_mem_usage=False,
)
except Exception:
# Fallback to auto device mapping if CPU loading fails
print(
"Warning: Loading model to CPU failed, falling back to auto device mapping"
)
model = AutoModelForCausalLM.from_pretrained(
hf_model_dir,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
)

lora_config = LoraConfig(r=lora_rank,
target_modules=target_modules,
bias="none",
@@ -755,12 +776,57 @@ def generate_dummy_loras(
if zero_weights:
for param in lora_model.parameters():
param.data.zero_()

pseudo_lora_dir = f"{lora_output_dir}/pseudo_lora_{lora_idx}"
lora_model.save_pretrained(pseudo_lora_dir)
lora_output_paths.append(pseudo_lora_dir)
return lora_output_paths


def get_test_prompts(use_code_prompts: bool = False) -> list[str]:
"""Get test prompts for LoRA testing.

Args:
use_code_prompts: If True, return code-related prompts. If False, return general prompts.

Returns:
List of test prompts.
"""
if use_code_prompts:
return [
"Write a function that outputs the fibonacci sequence.",
"Convert the following C++ code to Python: x = 0;x++;",
"Find the largest prime factor of 42.",
"write a unit test for this function: $(cat fib.py)",
"# A simple python function to remove whitespace from a string:",
"How to load CodeLlama from HuggingFace?",
]
else:
return [
"Hey how are you doing today?",
"How is the weather in Seattle, WA?",
"Is it ok to fill diesel in a petrol car?",
"Can you check the top 5 trending songs on spotify?",
"What is the capital of France?",
"How to load CodeLlama from HuggingFace?",
]


def get_test_prompts_for_torch() -> list[str]:
"""Get test prompts for LoRA Torch testing.

Returns:
List of test prompts.
"""
return [
"Hey how are you doing today?",
"How is the weather in Seattle, WA?",
"Is it ok to fill diesel in a petrol car?",
"Can you check the top 5 trending songs on spotify?",
"What is the capital of France?",
]


def test_multi_lora_support(
hf_model_dir,
tllm_ckpt_dir,
@@ -815,24 +881,7 @@ def test_multi_lora_support(
print(
f"Build engines completed in {(build_end - build_start):.2f} seconds.")

if use_code_prompts:
input_prompts = [
"Write a function that outputs the fibonacci sequence.",
"Convert the following C++ code to Python: x = 0;x++;",
"Find the largest prime factor of 42.",
"write a unit test for this function: $(cat fib.py)",
"# A simple python function to remove whitespace from a string:",
"How to load CodeLlama from HuggingFace?",
]
else:
input_prompts = [
"Hey how are you doing today?",
"How is the weather in Seattle, WA?",
"Is it ok to fill diesel in a petrol car?",
"Can you check the top 5 trending songs on spotify?",
"What is the capital of France?",
"How to load CodeLlama from HuggingFace?",
]
input_prompts = get_test_prompts(use_code_prompts)

print("Run inference with C++ runtime with pybind...")
inference_start = time.time()
@@ -867,6 +916,116 @@ def test_multi_lora_support(
)


def test_llm_torch_multi_lora_support(
hf_model_dir,
llm_venv,
num_loras=2,
lora_rank=8,
target_hf_modules=["q_proj", "k_proj", "v_proj"],
target_trtllm_modules=["attn_q", "attn_k", "attn_v"],
zero_lora_weights=True,
tensor_parallel_size=1,
pipeline_parallel_size=1,
expected_outputs=None):
"""Test multi-LoRA support with LLM-API Torch backend."""

if expected_outputs is None:
raise ValueError("expected_outputs must be provided for exact validation")

start_time = time.time()
print("Creating dummy LoRAs...")
lora_start = time.time()

lora_paths = generate_dummy_loras(
hf_model_dir=hf_model_dir,
lora_output_dir=llm_venv.get_working_directory(),
num_loras=num_loras,
lora_rank=lora_rank,
target_modules=target_hf_modules,
zero_weights=zero_lora_weights)
lora_end = time.time()
print(
f"Creating dummy LoRAs completed in {(lora_end - lora_start):.2f} seconds."
)

print("Initializing LLM_torch with LoRA support...")
init_start = time.time()

lora_config = LoraConfig(lora_dir=lora_paths,
max_lora_rank=lora_rank,
max_loras=num_loras,
max_cpu_loras=num_loras,
lora_target_modules=target_trtllm_modules)

input_prompts = get_test_prompts_for_torch()

with LLM_torch(
model=hf_model_dir,
lora_config=lora_config,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
dtype="bfloat16",
max_batch_size=8, # From original test
max_input_len=512, # From original test
max_seq_len=562, # From original test
max_beam_width=1 # From original test
) as llm:

init_end = time.time()
print(
f"LLM_torch initialization completed in {(init_end - init_start):.2f} seconds."
)

print("Running inference with LLM-API Torch backend...")
inference_start = time.time()

# Create LoRA requests for different adapters
lora_requests = []
for i in range(len(input_prompts)):
if i % 2 == 1: # Add some requests without LoRA
lora_requests.append(None)
else: # With LoRA
lora_requests.append(
LoRARequest(f"lora-{i}", i,
lora_paths[i % len(lora_paths)]))

sampling_params = SamplingParams(max_tokens=30,
top_p=0.5,
top_k=0,
temperature=0.0)

outputs = llm.generate(input_prompts,
sampling_params=sampling_params,
lora_request=lora_requests)

inference_end = time.time()
print(
f"Inference completed in {(inference_end - inference_start):.2f} seconds."
)

# Validate exact outputs
print("Validating exact outputs...")
assert len(outputs) == len(expected_outputs), \
f"Expected {len(expected_outputs)} outputs, got {len(outputs)}"

for i, (output, expected) in enumerate(zip(outputs, expected_outputs)):
actual_text = output.outputs[0].text
print(f"Prompt {i+1}: {input_prompts[i]}")
print(
f"LoRA: {lora_requests[i].lora_int_id if lora_requests[i] else 'None'}"
)
print(f"Expected: {expected}")
print(f"Actual: {actual_text}")
print("-" * 50)

# Exact string comparison
assert actual_text == expected, \
f"Output {i+1} mismatch:\nExpected: {expected!r}\nActual: {actual_text!r}"

total_time = time.time() - start_time
print(f"Total test execution time: {total_time:.2f} seconds")


def get_dummy_spec_decoding_heads(hf_model_dir,
save_dir,
mode='medusa',
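For context, a hedged sketch of how an integration test might call the new helper; the test name, import path, and golden strings below are illustrative placeholders, not part of this PR, and the `llama_model_root` / `llm_venv` fixtures are assumed to come from `conftest.py`.

```python
# Hypothetical caller, not part of this PR.
from defs.common import test_llm_torch_multi_lora_support


def test_llama_multi_lora_torch(llama_model_root, llm_venv):
    # Golden strings are placeholders; a real test pins outputs from a
    # known-good run, one per prompt returned by get_test_prompts_for_torch().
    expected = ["<golden output 1>"] * 5
    test_llm_torch_multi_lora_support(
        hf_model_dir=llama_model_root,
        llm_venv=llm_venv,
        num_loras=2,
        lora_rank=8,
        expected_outputs=expected,
    )
```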
19 changes: 19 additions & 0 deletions tests/integration/defs/conftest.py
@@ -1015,6 +1015,9 @@ def llama_model_root(request):
elif request.param == "llama-3.1-8b-instruct-hf-fp8":
llama_model_root = os.path.join(models_root, "llama-3.1-model",
"Llama-3.1-8B-Instruct-FP8")
elif request.param == "llama-3.1-8b-instruct":
llama_model_root = os.path.join(models_root, "llama-3.1-model",
"Llama-3.1-8B-Instruct")
elif request.param == "llama-3.1-8b-hf-nvfp4":
llama_model_root = os.path.join(models_root, "nvfp4-quantized",
"Meta-Llama-3.1-8B")
@@ -1024,9 +1027,18 @@
elif request.param == "llama-3.2-1b":
llama_model_root = os.path.join(models_root, "llama-3.2-models",
"Llama-3.2-1B")
elif request.param == "llama-3.2-1b-instruct":
llama_model_root = os.path.join(models_root, "llama-3.2-models",
"Llama-3.2-1B-Instruct")
elif request.param == "llama-3.2-3b":
llama_model_root = os.path.join(models_root, "llama-3.2-models",
"Llama-3.2-3B")
elif request.param == "llama-3.2-3b-instruct":
llama_model_root = os.path.join(models_root, "llama-3.2-models",
"Llama-3.2-3B-Instruct")
elif request.param == "llama-3.3-70b-instruct":
llama_model_root = os.path.join(models_root, "llama-3.3-models",
"Llama-3.3-70B-Instruct")
assert os.path.exists(
llama_model_root
), f"{llama_model_root} does not exist under NFS LLM_MODELS_ROOT dir"
@@ -1323,6 +1335,11 @@ def llm_lora_model_root(request):
elif item == "komt-mistral-7b-v1-lora":
model_root_list.append(
os.path.join(models_root, "komt-mistral-7b-v1-lora"))
elif item == "Llama-3_3-Nemotron-Super-49B-v1-lora-adapter_NIM_r32":
model_root_list.append(
os.path.join(
models_root, "nemotron-nas",
"Llama-3_3-Nemotron-Super-49B-v1-lora-adapter_NIM_r32"))

return ",".join(model_root_list)

@@ -1363,6 +1380,8 @@ def llm_mistral_model_root(request):
model_root = os.path.join(models_root, "mistral-7b-v0.1")
if request.param == "mistral-7b-v0.1":
model_root = os.path.join(models_root, "mistral-7b-v0.1")
if request.param == "mistral-nemo-instruct-2407":
model_root = os.path.join(models_root, "Mistral-Nemo-Instruct-2407")
if request.param == "komt-mistral-7b-v1":
model_root = os.path.join(models_root, "komt-mistral-7b-v1")
if request.param == "mistral-7b-v0.3":
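The new fixture parameters added above are typically selected through pytest's indirect parametrization. A minimal sketch, assuming the `llama_model_root` fixture defined in this `conftest.py`; the test name and assertion are illustrative only.

```python
import pytest


# Illustrative only: selects the new "llama-3.2-3b-instruct" branch added above.
@pytest.mark.parametrize("llama_model_root", ["llama-3.2-3b-instruct"],
                         indirect=True)
def test_uses_llama_32_3b_instruct(llama_model_root):
    # The fixture resolves to <LLM_MODELS_ROOT>/llama-3.2-models/Llama-3.2-3B-Instruct
    assert llama_model_root.endswith("Llama-3.2-3B-Instruct")
```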