
Commit 2d2b8ba

feat: TRTLLM-5574 Add phi-4-multimodal pytorch-backend support (#5644)
Signed-off-by: Wanli Jiang <[email protected]>
1 parent e09e409 commit 2d2b8ba

20 files changed: +1277 −56 lines

examples/llm-api/quickstart_advanced.py

Lines changed: 4 additions & 2 deletions

@@ -145,7 +145,7 @@ def parse_arguments():
     return args


-def setup_llm(args):
+def setup_llm(args, **kwargs):
     kv_cache_config = KvCacheConfig(
         enable_block_reuse=not args.disable_kv_cache_reuse,
         free_gpu_memory_fraction=args.kv_cache_fraction,
@@ -222,7 +222,9 @@ def setup_llm(args):
         speculative_config=spec_config,
         trust_remote_code=args.trust_remote_code,
         gather_generation_logits=args.return_generation_logits,
-        max_beam_width=args.max_beam_width)
+        max_beam_width=args.max_beam_width,
+        **kwargs,
+    )

     sampling_params = SamplingParams(
         max_tokens=args.max_tokens,
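
The new **kwargs passthrough means any extra keyword handed to setup_llm() is forwarded verbatim to the LLM constructor, so callers can opt into features such as LoRA without widening the function signature. A minimal caller-side sketch, assuming the forwarded keyword (here lora_config) is one the LLM constructor accepts, as quickstart_multimodal.py below relies on:

# Hypothetical caller; my_lora_config is a placeholder built elsewhere.
llm, sampling_params = setup_llm(args, lora_config=my_lora_config)
# ...ends up as LLM(..., max_beam_width=args.max_beam_width, lora_config=my_lora_config)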

examples/llm-api/quickstart_multimodal.py

Lines changed: 85 additions & 23 deletions

@@ -7,24 +7,56 @@
 from tensorrt_llm.inputs import (ALL_SUPPORTED_MULTIMODAL_MODELS,
                                  default_multimodal_input_loader)

-example_images = [
-    "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
-    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
-    "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
-]
-example_image_prompts = [
-    "Describe the natural environment in the image.",
-    "Describe the object and the weather condition in the image.",
-    "Describe the traffic condition on the road in the image.",
-]
-example_videos = [
-    "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
-    "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
-]
-example_video_prompts = [
-    "Tell me what you see in the video briefly.",
-    "Describe the scene in the video briefly.",
-]
+example_medias_and_prompts = {
+    "image": {
+        "media": [
+            "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
+            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
+            "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
+        ],
+        "prompt": [
+            "Describe the natural environment in the image.",
+            "Describe the object and the weather condition in the image.",
+            "Describe the traffic condition on the road in the image.",
+        ]
+    },
+    "video": {
+        "media": [
+            "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
+            "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
+        ],
+        "prompt": [
+            "Tell me what you see in the video briefly.",
+            "Describe the scene in the video briefly.",
+        ]
+    },
+    "audio": {
+        "media": [
+            "https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_the_traffic_sign_in_the_image.wav",
+            "https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav",
+        ],
+        "prompt": [
+            "Transcribe the audio clip into text, please don't add other text.",
+            "Transcribe the audio clip into text, please don't add other text.",
+        ]
+    },
+    "image_audio": {
+        "media": [
+            [
+                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
+                "https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav"
+            ],
+            [
+                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
+                "https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav"
+            ],
+        ],
+        "prompt": [
+            "Describe the scene in the image briefly.",
+            "",
+        ]
+    }
+}


 def add_multimodal_args(parser):
@@ -34,7 +66,7 @@ def add_multimodal_args(parser):
                         help="Model type.")
     parser.add_argument("--modality",
                         type=str,
-                        choices=["image", "video"],
+                        choices=["image", "video", "audio", "image_audio"],
                         default="image",
                         help="Media type.")
     parser.add_argument("--media",
@@ -53,11 +85,24 @@ def add_multimodal_args(parser):
     return parser


+def add_lora_args(parser):
+    parser.add_argument("--load_lora",
+                        default=False,
+                        action='store_true',
+                        help="Whether to load the LoRA model.")
+    parser.add_argument("--auto_model_name",
+                        type=str,
+                        default=None,
+                        help="The auto model name in TRTLLM repo.")
+    return parser
+
+
 def parse_arguments():
     parser = argparse.ArgumentParser(
         description="Multimodal models with the PyTorch workflow.")
     parser = add_llm_args(parser)
     parser = add_multimodal_args(parser)
+    parser = add_lora_args(parser)
     args = parser.parse_args()

     args.disable_kv_cache_reuse = True  # kv cache reuse does not work for multimodal, force overwrite
@@ -71,11 +116,19 @@ def main():
     args = parse_arguments()
     # set prompts and media to example prompts and images if they are not provided
     if args.prompt is None:
-        args.prompt = example_image_prompts if args.modality == "image" else example_video_prompts
+        args.prompt = example_medias_and_prompts[args.modality]["prompt"]
     if args.media is None:
-        args.media = example_images if args.modality == "image" else example_videos
+        args.media = example_medias_and_prompts[args.modality]["media"]
+
+    lora_config = None
+    if args.load_lora:
+        assert args.auto_model_name is not None, "Please provide the auto model name to load LoRA config."
+        import importlib
+        models_module = importlib.import_module('tensorrt_llm._torch.models')
+        model_class = getattr(models_module, args.auto_model_name)
+        lora_config = model_class.lora_config(args.model_dir)

-    llm, sampling_params = setup_llm(args)
+    llm, sampling_params = setup_llm(args, lora_config=lora_config)

     image_format = args.image_format
     if args.model_type is not None:
@@ -96,7 +149,16 @@ def main():
         num_frames=args.num_frames,
         device=device)

-    outputs = llm.generate(inputs, sampling_params)
+    lora_request = None
+    if args.load_lora:
+        lora_request = model_class.lora_request(len(inputs), args.modality,
+                                                llm._hf_model_dir)
+
+    outputs = llm.generate(
+        inputs,
+        sampling_params,
+        lora_request=lora_request,
+    )

     for i, output in enumerate(outputs):
         prompt = args.prompt[i]
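
With these changes the script can exercise the Phi-4-multimodal path via --modality audio or --modality image_audio, and optionally --load_lora --auto_model_name Phi4MMForCausalLM (the class this commit registers) to pull model-specific LoRA settings. A condensed, self-contained sketch of the dynamic LoRA wiring from the diff above; model_dir, num_inputs, modality, and hf_model_dir are placeholders, and the lora_config()/lora_request() class helpers are assumed to exist on the resolved class, as the script requires:

import importlib

# Resolve the class named by --auto_model_name from the TRT-LLM model registry.
models_module = importlib.import_module('tensorrt_llm._torch.models')
model_class = getattr(models_module, 'Phi4MMForCausalLM')

# The two helpers feed different stages: lora_config() goes to setup_llm()/LLM(...),
# lora_request() goes to llm.generate(...).
lora_config = model_class.lora_config(model_dir)
lora_request = model_class.lora_request(num_inputs, modality, hf_model_dir)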

requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -59,3 +59,4 @@ ninja
 etcd3
 blake3
 llguidance==0.7.29
+soundfile
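
The new soundfile dependency presumably backs the audio inputs added to the quickstart; a minimal sketch of decoding one of the example clips with it (the local filename is a placeholder for a downloaded copy):

import soundfile as sf

# Returns the samples as a float array plus the sample rate in Hz.
samples, sample_rate = sf.read("what_is_shown_in_this_image.wav")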

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 18 additions & 3 deletions

@@ -351,6 +351,8 @@ class RopeParams:
     beta_slow: int = 1
     mscale: float = 1.0
     mscale_all_dim: float = 0.0
+    short_factor: Optional[Tuple[float]] = None
+    long_factor: Optional[Tuple[float]] = None

     @staticmethod
     def from_config(config) -> "RopeParams":
@@ -386,12 +388,18 @@ def from_config(config) -> "RopeParams":
                 "low_freq_factor", 1.0)
             rope_params.high_freq_factor = rope_scaling.get(
                 "high_freq_factor", 4.0)
-            rope_params.original_max_positions = rope_scaling.get(
-                "original_max_position_embeddings", 1024)
+            rope_params.original_max_positions = getattr(
+                config,
+                "original_max_position_embeddings", None) or rope_scaling.get(
+                    "original_max_position_embeddings", None) or 1024
             rope_params.beta_fast = rope_scaling.get("beta_fast", 32)
             rope_params.beta_slow = rope_scaling.get("beta_slow", 1)
             rope_params.mscale = rope_scaling.get("mscale", 1.0)
             rope_params.mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
+            if "short_factor" in rope_scaling:
+                rope_params.short_factor = tuple(rope_scaling["short_factor"])
+            if "long_factor" in rope_scaling:
+                rope_params.long_factor = tuple(rope_scaling["long_factor"])
         # Workaround for DeepSeek V3 Lite since its rope_scaling is null in config.json.
         elif config.model_type == "deepseek_v3":
             rope_params.scale_type = RotaryScalingType.yarn
@@ -428,7 +436,14 @@ def create_rope_const_params(self, interleave: bool = True):
                 self.mscale_all_dim,
             )
         elif self.scale_type == RotaryScalingType.longrope:
-            raise NotImplementedError("Long RoPE is not supported.")
+            rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_long_rope_for_attention_plugin(
+                num_pos=self.max_positions,
+                dim=self.dim,
+                theta=self.theta,
+                original_max_pos=self.original_max_positions,
+                short_factor=self.short_factor,
+                long_factor=self.long_factor,
+            )
         else:
             rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                 self.max_positions,
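
Two things change here: original_max_position_embeddings is now resolved through a fallback chain (top-level config attribute, then the rope_scaling dict, then 1024), and the longrope branch picks up the per-dimension short_factor/long_factor lists that Phi-style configs carry instead of raising NotImplementedError. A small self-contained sketch of that resolution logic, with a stand-in config object and placeholder factor values:

from types import SimpleNamespace

config = SimpleNamespace(original_max_position_embeddings=4096)  # stand-in for the HF config
rope_scaling = {"type": "longrope",
                "short_factor": [1.02, 1.05],   # placeholder values
                "long_factor": [2.30, 2.45]}

# Fallback chain introduced by this change: config attribute -> rope_scaling entry -> 1024.
original_max_positions = (getattr(config, "original_max_position_embeddings", None)
                          or rope_scaling.get("original_max_position_embeddings", None)
                          or 1024)

# The factors become tuples on RopeParams and are later handed to the long-RoPE
# sinusoid builder used by create_rope_const_params().
short_factor = tuple(rope_scaling["short_factor"]) if "short_factor" in rope_scaling else None
long_factor = tuple(rope_scaling["long_factor"]) if "long_factor" in rope_scaling else None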

tensorrt_llm/_torch/models/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -15,6 +15,8 @@
 from .modeling_nemotron import NemotronForCausalLM
 from .modeling_nemotron_h import NemotronHForCausalLM
 from .modeling_nemotron_nas import NemotronNASForCausalLM
+from .modeling_phi3 import Phi3ForCausalLM
+from .modeling_phi4mm import Phi4MMForCausalLM
 from .modeling_qwen import (Qwen2ForCausalLM, Qwen2ForProcessRewardModel,
                             Qwen2ForRewardModel)
 from .modeling_qwen2vl import Qwen2_5_VLModel, Qwen2VLModel
@@ -42,6 +44,8 @@
     "NemotronForCausalLM",
     "NemotronHForCausalLM",
     "NemotronNASForCausalLM",
+    "Phi3ForCausalLM",
+    "Phi4MMForCausalLM",
     "Qwen2ForCausalLM",
     "Qwen2ForProcessRewardModel",
     "Qwen2ForRewardModel",

tensorrt_llm/_torch/models/modeling_multimodal_utils.py

Lines changed: 1 addition & 0 deletions

@@ -64,6 +64,7 @@ def fuse_input_embeds(
         mm_token_mask = input_ids >= vocab_size
         text_token_mask = input_ids < vocab_size
     else:
+        mm_token_ids = mm_token_ids.to(input_ids.device)
         mm_token_mask = torch.isin(input_ids, mm_token_ids)
         text_token_mask = ~mm_token_mask
     text_token_indices = torch.where(text_token_mask)[0]
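
The one-line fix moves mm_token_ids onto the same device as input_ids before torch.isin, which otherwise can fail with a device-mismatch error when the multimodal token ids were built on CPU while the input ids live on GPU. A small CPU-only illustration of the masking logic, with placeholder token ids:

import torch

# Placeholder ids: 1001 and 1002 stand in for multimodal tokens inside a text sequence.
input_ids = torch.tensor([5, 1001, 7, 1002, 9])
mm_token_ids = torch.tensor([1001, 1002])

# Same pattern as fuse_input_embeds(): align devices, then split text vs. multimodal positions.
mm_token_ids = mm_token_ids.to(input_ids.device)
mm_token_mask = torch.isin(input_ids, mm_token_ids)
text_token_mask = ~mm_token_mask
text_token_indices = torch.where(text_token_mask)[0]   # -> tensor([0, 2, 4])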
