
Commit 927fdff

save initial changes

1 parent: 16febef

2 files changed: 55 additions & 0 deletions

tensorrt_llm/_torch/models/modeling_gemma3.py
1 addition & 0 deletions

@@ -117,6 +117,7 @@ def forward(
             attn_metadata, FlashInferAttentionMetadata
         ), "Only FlashInfer backend supports custom attention mask currently."
         assert attention_mask == CustomAttentionMask.CUSTOM
+        print("lora_params: ", lora_params)
         return super().forward(position_ids=position_ids,
                                hidden_states=hidden_states,
                                attn_metadata=attn_metadata,

tests/unittest/llmapi/test_llm_pytorch.py
54 additions & 0 deletions

@@ -1,6 +1,7 @@
 import pytest
 
 from tensorrt_llm import LLM
+from tensorrt_llm.llmapi import KvCacheConfig
 from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
 from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
 from tensorrt_llm.sampling_params import SamplingParams
@@ -492,6 +493,59 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
     assert len(outputs) == 2
 
 
+def test_gemma3_1b_instruct_multi_lora() -> None:
+    model_dir = f"{llm_models_root()}/gemma/gemma-3-1b-it"
+
+    target_modules = ['attn_q', 'attn_k', 'attn_v']
+
+    # Set up temporary directory for LoRA adapters
+    with tempfile.TemporaryDirectory() as lora_dir:
+        print("Creating dummy LoRAs...")
+
+        model = AutoModelForCausalLM.from_pretrained(model_dir,
+                                                     torch_dtype=torch.bfloat16,
+                                                     device_map="auto")
+        hf_modules = ["q_proj", "k_proj", "v_proj"]
+        peft_lora_config = PeftLoraConfig(r=8,
+                                          target_modules=hf_modules,
+                                          bias="none",
+                                          task_type="CAUSAL_LM")
+        lora_paths = []
+        for i in range(2):
+            lora_model = get_peft_model(model, peft_lora_config)
+            for param in lora_model.parameters():
+                param.data.zero_()
+            lora_path = f"{lora_dir}/lora_{i}"
+            lora_model.save_pretrained(lora_path)
+            lora_paths.append(lora_path)
+
+        trtllm_lora_config = LoraConfig(lora_dir=lora_paths,
+                                        lora_target_modules=target_modules,
+                                        max_lora_rank=8,
+                                        max_loras=2,
+                                        max_cpu_loras=2)
+        kv_cache_config = KvCacheConfig(
+            enable_block_reuse=False,
+            enable_partial_reuse=False,
+        )
+        llm = LLM(model_dir, lora_config=trtllm_lora_config, kv_cache_config=kv_cache_config)
+
+        prompts = [
+            "Is it ok to fill diesel in a petrol car?",
+            "What is the capital of France?",
+        ]
+        lora_req1 = LoRARequest("lora-1", 0, lora_paths[0])
+        lora_req2 = LoRARequest("lora-2", 1, lora_paths[1])
+        lora_requests = [lora_req1, lora_req2]
+        sampling_params = SamplingParams(max_tokens=200)
+
+        outputs = llm.generate(prompts,
+                               sampling_params,
+                               lora_request=lora_requests)
+
+        assert len(outputs) == 2
+
+
 @pytest.mark.parametrize(
     "lora_rank,max_lora_rank,description",
     [
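For reference, a minimal sketch of invoking only the new test (not part of the commit; it assumes pytest is installed in a working TensorRT-LLM environment and that llm_models_root() resolves to a local model cache containing gemma/gemma-3-1b-it):

# Sketch: run just the new multi-LoRA test and keep its print output visible.
# Assumes a working TensorRT-LLM install and local availability of the
# gemma/gemma-3-1b-it checkpoint referenced by llm_models_root().
import pytest

pytest.main([
    "tests/unittest/llmapi/test_llm_pytorch.py::test_gemma3_1b_instruct_multi_lora",
    "-s",  # disable output capture so the debug prints are shown
])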
