
Commit 262e13d

[None][chore] Add unit test for Gemma3 lora

Signed-off-by: Balaram Buddharaju <[email protected]>

1 parent 16febef · commit 262e13d

File tree

3 files changed: +62 −12 lines

tensorrt_llm/_torch/models/modeling_gemma3.py

Lines changed: 4 additions & 12 deletions
@@ -14,7 +14,6 @@
 from ..attention_backend.interface import (AttentionMask, CustomAttentionMask,
                                            PositionalEmbeddingParams,
                                            PredefinedAttentionMask, RopeParams)
-from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
 from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer

@@ -105,9 +104,6 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
-        mrope_config: Optional[dict] = None,
-        all_reduce_params: Optional[AllReduceParams] = None,
-        lora_params: Optional[dict] = None,
         attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:

@@ -121,9 +117,6 @@ def forward(
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            attention_mask=attention_mask,
-           mrope_config=mrope_config,
-           all_reduce_params=all_reduce_params,
-           lora_params=lora_params,
            attention_window_size=self.attention_window_size,
            attention_mask_data=attention_mask_data,
            **kwargs)

@@ -209,7 +202,6 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor] = None,
         attention_mask_data: Optional[torch.Tensor] = None,
-        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:


@@ -222,14 +214,14 @@ def forward(
            attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
            is not None else PredefinedAttentionMask.CAUSAL,
            attention_mask_data=attention_mask_data,
-           lora_params=lora_params,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
-       hidden_states = self.mlp(hidden_states, lora_params=lora_params)
+       hidden_states = self.mlp(hidden_states,
+                                lora_params=kwargs.get("lora_params", None))
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states


@@ -270,7 +262,6 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         local_attention_mask_data: Optional[torch.Tensor] = None,
         global_attention_mask_data: Optional[torch.Tensor] = None,
-        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):

@@ -291,7 +282,7 @@ def forward(
                attention_mask_data=local_attention_mask_data
                if decoder_layer.self_attn.is_sliding else
                global_attention_mask_data,
-               lora_params=lora_params,
+               **kwargs,
            )

        hidden_states = self.norm(hidden_states)

@@ -465,6 +456,7 @@ def forward(
            inputs_embeds=inputs_embeds,
            local_attention_mask_data=local_attention_mask_data,
            global_attention_mask_data=global_attention_mask_data,
+           **kwargs,
        )

        return self.logits_processor.forward(
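
Net effect of the modeling change: lora_params (together with mrope_config and all_reduce_params) is removed from Gemma3's explicit forward signatures and now travels through **kwargs, with the one remaining direct consumer, the MLP call, recovering it via kwargs.get("lora_params", None). Below is a minimal, self-contained sketch of that forwarding pattern; the toy classes are illustrative stand-ins, not the actual TensorRT-LLM modules.

from typing import Optional


class ToyMLP:

    def __call__(self, hidden_states, lora_params: Optional[dict] = None):
        # A real MLP would apply the low-rank LoRA update when lora_params is
        # set; here we only show that the value arrives.
        print("mlp received lora_params:", lora_params)
        return hidden_states


class ToyDecoderLayer:

    def __init__(self):
        self.mlp = ToyMLP()

    def forward(self, hidden_states, **kwargs):
        # lora_params is no longer a named parameter of forward(); callers
        # that pass it place it in kwargs, callers that don't simply omit it.
        return self.mlp(hidden_states,
                        lora_params=kwargs.get("lora_params", None))


layer = ToyDecoderLayer()
layer.forward("h")                              # mlp receives None
layer.forward("h", lora_params={"adapter": 0})  # dict forwarded via kwargs

Keeping per-feature optional inputs out of each signature lets them flow from the top-level model down to the layers without touching every forward() along the way, which is what the added **kwargs lines in the hunks above do.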

tests/integration/test_lists/qa/examples_test_list.txt

Lines changed: 1 addition & 0 deletions
@@ -551,6 +551,7 @@ test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/
 test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
 test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]
 test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
+unittest/llmapi/test_llm_pytorch.py::test_gemma3_1b_instruct_multi_lora
 examples/test_medusa.py::test_codellama_medusa_1gpu[CodeLlama-7b-Instruct]
 examples/test_medusa.py::test_mistral_medusa_1gpu[mistral-7b-v0.1]
 examples/test_medusa.py::test_qwen_medusa_1gpu[qwen_7b_chat]

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 57 additions & 0 deletions
@@ -1,6 +1,7 @@
 import pytest

 from tensorrt_llm import LLM
+from tensorrt_llm.llmapi import KvCacheConfig
 from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
 from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
 from tensorrt_llm.sampling_params import SamplingParams

@@ -492,6 +493,62 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
     assert len(outputs) == 2


+def test_gemma3_1b_instruct_multi_lora() -> None:
+    model_dir = f"{llm_models_root()}/gemma/gemma-3-1b-it"
+
+    target_modules = ['attn_q', 'attn_k', 'attn_v']
+
+    # Set up temporary directory for LoRA adapters
+    with tempfile.TemporaryDirectory() as lora_dir:
+        print("Creating dummy LoRAs...")
+
+        model = AutoModelForCausalLM.from_pretrained(model_dir,
+                                                     torch_dtype=torch.bfloat16,
+                                                     device_map="auto")
+        hf_modules = ["q_proj", "k_proj", "v_proj"]
+        peft_lora_config = PeftLoraConfig(r=8,
+                                          target_modules=hf_modules,
+                                          bias="none",
+                                          task_type="CAUSAL_LM")
+        lora_paths = []
+        for i in range(2):
+            lora_model = get_peft_model(model, peft_lora_config)
+            for param in lora_model.parameters():
+                param.data.zero_()
+            lora_path = f"{lora_dir}/lora_{i}"
+            lora_model.save_pretrained(lora_path)
+            lora_paths.append(lora_path)
+
+        trtllm_lora_config = LoraConfig(lora_dir=lora_paths,
+                                        lora_target_modules=target_modules,
+                                        max_lora_rank=8,
+                                        max_loras=2,
+                                        max_cpu_loras=2)
+        # Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
+        kv_cache_config = KvCacheConfig(
+            enable_block_reuse=False,
+            enable_partial_reuse=False,
+        )
+        llm = LLM(model_dir,
+                  lora_config=trtllm_lora_config,
+                  kv_cache_config=kv_cache_config)
+
+        prompts = [
+            "Is it ok to fill diesel in a petrol car?",
+            "What is the capital of France?",
+        ]
+        lora_req1 = LoRARequest("lora-1", 0, lora_paths[0])
+        lora_req2 = LoRARequest("lora-2", 1, lora_paths[1])
+        lora_requests = [lora_req1, lora_req2]
+        sampling_params = SamplingParams(max_tokens=200)
+
+        outputs = llm.generate(prompts,
+                               sampling_params,
+                               lora_request=lora_requests)
+
+        assert len(outputs) == 2
+
+
 @pytest.mark.parametrize(
     "lora_rank,max_lora_rank,description",
     [
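
Distilled from the new test, the end-to-end multi-LoRA flow through the LLM API is: point a LoraConfig at HF-format LoRA adapter directories, disable KV-cache block reuse for Gemma3 (the WAR called out in the test comment), and hand one LoRARequest per prompt to generate(). The sketch below is a hedged condensation of the test, not additional repo code: the adapter and model paths are placeholders, and the import paths marked as assumed are guesses; inside the test module the same names are already imported outside this hunk.

# Imports for LLM, KvCacheConfig and SamplingParams match the ones visible in
# the diff above; the LoraConfig / LoRARequest paths below are assumptions.
from tensorrt_llm import LLM
from tensorrt_llm.executor import LoRARequest  # assumed import path
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.lora_manager import LoraConfig  # assumed import path
from tensorrt_llm.sampling_params import SamplingParams

lora_paths = ["/path/to/lora_adapter_0", "/path/to/lora_adapter_1"]  # placeholders

lora_config = LoraConfig(lora_dir=lora_paths,
                         lora_target_modules=['attn_q', 'attn_k', 'attn_v'],
                         max_lora_rank=8,
                         max_loras=2,
                         max_cpu_loras=2)
# Same WAR as the test: keep block reuse off until kernels cover Gemma3's
# non-inclusive sliding window size.
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                enable_partial_reuse=False)

llm = LLM("/path/to/gemma-3-1b-it",  # placeholder model directory
          lora_config=lora_config,
          kv_cache_config=kv_cache_config)

# One LoRARequest per prompt: prompt i is served with adapter i.
outputs = llm.generate(
    ["Is it ok to fill diesel in a petrol car?",
     "What is the capital of France?"],
    SamplingParams(max_tokens=200),
    lora_request=[LoRARequest("lora-1", 0, lora_paths[0]),
                  LoRARequest("lora-2", 1, lora_paths[1])])
assert len(outputs) == 2

Because the test zeroes the dummy adapters' weights, the LoRA delta is effectively a no-op; the assertion is a smoke check that both adapter-tagged requests run end to end, not a quality check on the outputs.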
