
Commit 2a0ce00

[5464088][fix] Enhance LoRA support in PyTorch model configuration
- Added logging for dtype casting in LoraLayer to ensure compatibility with FP16/BF16.
- Updated model configuration to derive the number of LoRA adapters from the model label, improving flexibility in adapter management.

Signed-off-by: Venky Ganesh <[email protected]>
1 parent c4535e6 commit 2a0ce00

File tree

2 files changed: +21 -1 lines changed

tensorrt_llm/_torch/peft/lora/layer.py

Lines changed: 11 additions & 0 deletions
@@ -3,6 +3,8 @@
 
 import torch
 
+from tensorrt_llm._utils import logger
+
 
 class LoraModuleType(IntEnum):
     """Enum class representing different types of modules that can have LoRA adapters.
@@ -119,6 +121,15 @@ def forward(
         if len(active_lora_module_ids) == 0:
             return None
         else:
+            # Guard: LoRA custom op only supports FP16/BF16 activations.
+            # If upstream produced FP8 (e.g., FP8 SwiGLU), cast here to avoid runtime failure.
+            if x.dtype not in (torch.float16, torch.bfloat16):
+                target_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported(
+                ) else torch.float16
+                logger.debug(
+                    f"lora_grouped_gemm supports only FP16/BF16. Casting input from {x.dtype} to {target_dtype}."
+                )
+                x = x.to(target_dtype).contiguous()
             lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
                 x,
                 lora_params['host_request_types'][:num_seqs],
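
For reference, the same dtype guard as a standalone sketch (a minimal illustration using only public PyTorch APIs; the helper name ensure_lora_activation_dtype is hypothetical and not part of this commit):

import torch

def ensure_lora_activation_dtype(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the guard above: the LoRA grouped GEMM
    # op accepts only FP16/BF16 activations, so anything else (e.g. FP8
    # coming out of an FP8 SwiGLU) is cast before the call.
    if x.dtype in (torch.float16, torch.bfloat16):
        return x
    # Prefer BF16 when a CUDA device supports it, otherwise fall back to FP16.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    target_dtype = torch.bfloat16 if use_bf16 else torch.float16
    return x.to(target_dtype).contiguous()

# An FP32 activation gets cast; an FP16 one passes through unchanged.
print(ensure_lora_activation_dtype(torch.randn(2, 4)).dtype)
print(ensure_lora_activation_dtype(torch.randn(2, 4, dtype=torch.float16)).dtype)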

tests/integration/defs/perf/pytorch_model_config.py

Lines changed: 10 additions & 1 deletion
@@ -181,10 +181,19 @@ def get_model_yaml_config(model_label: str,
 
     # lora-specific change for pytorch
     if 'pytorch' in model_label and 'loras' in model_label:
+        # Derive the requested number of adapters from model_label (segment like "loras:X")
+        lora_count = 1
+        for part in model_label.split('-'):
+            if part.startswith('loras:'):
+                lora_count = max(1, int(part.split(':', 1)[1]))
+                break
+
         lora_config = {
             'lora_config': {
                 'lora_dir': lora_dirs if lora_dirs is not None else [],
-                'max_lora_rank': 64
+                'max_lora_rank': 64,
+                'max_loras': lora_count,
+                'max_cpu_loras': lora_count,
             }
         }
     if 'phi_4_multimodal_instruct' in model_label:
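
As an illustration of the label parsing added above (the sample labels below are made up, not taken from the perf test list), the loop turns a "loras:X" segment of the dash-separated model label into the adapter count, defaulting to 1:

def lora_count_from_label(model_label: str) -> int:
    # Same logic as in get_model_yaml_config above: scan the dash-separated
    # label for a "loras:X" segment and clamp the count to at least one.
    lora_count = 1
    for part in model_label.split('-'):
        if part.startswith('loras:'):
            lora_count = max(1, int(part.split(':', 1)[1]))
            break
    return lora_count

# Hypothetical labels for illustration only:
print(lora_count_from_label('llama_v3-pytorch-bfloat16-loras:4'))  # -> 4
print(lora_count_from_label('llama_v3-pytorch-bfloat16'))          # -> 1 (default)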
