diff --git a/tensorrt_llm/_torch/models/modeling_gemma3.py b/tensorrt_llm/_torch/models/modeling_gemma3.py
index 10acffae9d6..9ed0a71da1f 100644
--- a/tensorrt_llm/_torch/models/modeling_gemma3.py
+++ b/tensorrt_llm/_torch/models/modeling_gemma3.py
@@ -14,7 +14,6 @@
 from ..attention_backend.interface import (AttentionMask, CustomAttentionMask,
                                            PositionalEmbeddingParams,
                                            PredefinedAttentionMask, RopeParams)
-from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
 from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
@@ -105,9 +104,6 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
-        mrope_config: Optional[dict] = None,
-        all_reduce_params: Optional[AllReduceParams] = None,
-        lora_params: Optional[dict] = None,
         attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -121,9 +117,6 @@ def forward(
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             attention_mask=attention_mask,
-            mrope_config=mrope_config,
-            all_reduce_params=all_reduce_params,
-            lora_params=lora_params,
             attention_window_size=self.attention_window_size,
             attention_mask_data=attention_mask_data,
             **kwargs)
@@ -209,7 +202,6 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor] = None,
         attention_mask_data: Optional[torch.Tensor] = None,
-        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
 
@@ -222,14 +214,14 @@ def forward(
             attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
             is not None else PredefinedAttentionMask.CAUSAL,
             attention_mask_data=attention_mask_data,
-            lora_params=lora_params,
             **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = residual + hidden_states
         residual = hidden_states
         hidden_states = self.pre_feedforward_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states, lora_params=lora_params)
+        hidden_states = self.mlp(hidden_states,
+                                 lora_params=kwargs.get("lora_params", None))
         hidden_states = self.post_feedforward_layernorm(hidden_states)
         hidden_states = residual + hidden_states
 
@@ -270,7 +262,6 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         local_attention_mask_data: Optional[torch.Tensor] = None,
         global_attention_mask_data: Optional[torch.Tensor] = None,
-        lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -291,7 +282,7 @@ def forward(
                attention_mask_data=local_attention_mask_data
                if decoder_layer.self_attn.is_sliding else
                global_attention_mask_data,
-                lora_params=lora_params,
+                **kwargs,
             )
 
         hidden_states = self.norm(hidden_states)
@@ -465,6 +456,7 @@ def forward(
             inputs_embeds=inputs_embeds,
             local_attention_mask_data=local_attention_mask_data,
             global_attention_mask_data=global_attention_mask_data,
+            **kwargs,
         )
 
         return self.logits_processor.forward(
diff --git a/tests/integration/test_lists/qa/examples_test_list.txt b/tests/integration/test_lists/qa/examples_test_list.txt
index c93c81a169a..f9ea731fa88 100644
--- a/tests/integration/test_lists/qa/examples_test_list.txt
+++ b/tests/integration/test_lists/qa/examples_test_list.txt
@@ -551,6 +551,7 @@ test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/
 test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
 test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]
 test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
+unittest/llmapi/test_llm_pytorch.py::test_gemma3_1b_instruct_multi_lora
 examples/test_medusa.py::test_codellama_medusa_1gpu[CodeLlama-7b-Instruct]
 examples/test_medusa.py::test_mistral_medusa_1gpu[mistral-7b-v0.1]
 examples/test_medusa.py::test_qwen_medusa_1gpu[qwen_7b_chat]
diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py
index c9e53286908..16f43c3885d 100644
--- a/tests/unittest/llmapi/test_llm_pytorch.py
+++ b/tests/unittest/llmapi/test_llm_pytorch.py
@@ -1,6 +1,7 @@
 import pytest
 
 from tensorrt_llm import LLM
+from tensorrt_llm.llmapi import KvCacheConfig
 from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
 from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
 from tensorrt_llm.sampling_params import SamplingParams
@@ -492,6 +493,62 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
         assert len(outputs) == 2
 
 
+def test_gemma3_1b_instruct_multi_lora() -> None:
+    model_dir = f"{llm_models_root()}/gemma/gemma-3-1b-it"
+
+    target_modules = ['attn_q', 'attn_k', 'attn_v']
+
+    # Set up temporary directory for LoRA adapters
+    with tempfile.TemporaryDirectory() as lora_dir:
+        print("Creating dummy LoRAs...")
+
+        model = AutoModelForCausalLM.from_pretrained(model_dir,
+                                                     torch_dtype=torch.bfloat16,
+                                                     device_map="auto")
+        hf_modules = ["q_proj", "k_proj", "v_proj"]
+        peft_lora_config = PeftLoraConfig(r=8,
+                                          target_modules=hf_modules,
+                                          bias="none",
+                                          task_type="CAUSAL_LM")
+        lora_paths = []
+        for i in range(2):
+            lora_model = get_peft_model(model, peft_lora_config)
+            for param in lora_model.parameters():
+                param.data.zero_()
+            lora_path = f"{lora_dir}/lora_{i}"
+            lora_model.save_pretrained(lora_path)
+            lora_paths.append(lora_path)
+
+        trtllm_lora_config = LoraConfig(lora_dir=lora_paths,
+                                        lora_target_modules=target_modules,
+                                        max_lora_rank=8,
+                                        max_loras=2,
+                                        max_cpu_loras=2)
+        # Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
+        kv_cache_config = KvCacheConfig(
+            enable_block_reuse=False,
+            enable_partial_reuse=False,
+        )
+        llm = LLM(model_dir,
+                  lora_config=trtllm_lora_config,
+                  kv_cache_config=kv_cache_config)
+
+        prompts = [
+            "Is it ok to fill diesel in a petrol car?",
+            "What is the capital of France?",
+        ]
+        lora_req1 = LoRARequest("lora-1", 0, lora_paths[0])
+        lora_req2 = LoRARequest("lora-2", 1, lora_paths[1])
+        lora_requests = [lora_req1, lora_req2]
+        sampling_params = SamplingParams(max_tokens=200)
+
+        outputs = llm.generate(prompts,
+                               sampling_params,
+                               lora_request=lora_requests)
+
+        assert len(outputs) == 2
+
+
 @pytest.mark.parametrize(
     "lora_rank,max_lora_rank,description",
     [
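Usage note (not part of the diff): the sketch below mirrors the new test_gemma3_1b_instruct_multi_lora test and shows how the kwargs-forwarded lora_params path is exercised end to end through the LLM API. It is a minimal sketch under stated assumptions: the model and adapter paths are placeholders, and the import locations of LoraConfig and LoRARequest, as well as the output access pattern, are assumed from the surrounding test module rather than taken from this diff.

# Minimal sketch: Gemma3-1B with two LoRA adapters on the PyTorch backend.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.sampling_params import SamplingParams
from tensorrt_llm.lora_manager import LoraConfig      # assumed import path
from tensorrt_llm.executor import LoRARequest         # assumed import path

# Placeholder paths for a local Gemma3 checkpoint and two LoRA adapters.
model_dir = "/models/gemma/gemma-3-1b-it"
lora_paths = ["/loras/lora_0", "/loras/lora_1"]

lora_config = LoraConfig(lora_dir=lora_paths,
                         lora_target_modules=['attn_q', 'attn_k', 'attn_v'],
                         max_lora_rank=8,
                         max_loras=2,
                         max_cpu_loras=2)

# KV cache block reuse is disabled, matching the workaround in the test for
# Gemma3's non-inclusive sliding window size.
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                enable_partial_reuse=False)

llm = LLM(model_dir, lora_config=lora_config, kv_cache_config=kv_cache_config)

prompts = ["What is the capital of France?"]
lora_requests = [LoRARequest("lora-1", 0, lora_paths[0])]  # one request per prompt
outputs = llm.generate(prompts,
                       SamplingParams(max_tokens=64),
                       lora_request=lora_requests)
print(outputs[0].outputs[0].text)  # assumed RequestOutput access pattern

As in the test, lora_request carries one LoRARequest per prompt, so different adapters can be applied within a single generate call.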