from tensorrt_llm.inputs import (ALL_SUPPORTED_MULTIMODAL_MODELS,
                                 default_multimodal_input_loader)

-example_images = [
-    "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
-    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
-    "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
-]
-example_image_prompts = [
-    "Describe the natural environment in the image.",
-    "Describe the object and the weather condition in the image.",
-    "Describe the traffic condition on the road in the image.",
-]
-example_videos = [
-    "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
-    "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
-]
-example_video_prompts = [
-    "Tell me what you see in the video briefly.",
-    "Describe the scene in the video briefly.",
-]
+example_medias_and_prompts = {
+    "image": {
+        "media": [
+            "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
+            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
+            "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
+        ],
+        "prompt": [
+            "Describe the natural environment in the image.",
+            "Describe the object and the weather condition in the image.",
+            "Describe the traffic condition on the road in the image.",
+        ]
+    },
+    "video": {
+        "media": [
+            "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
+            "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
+        ],
+        "prompt": [
+            "Tell me what you see in the video briefly.",
+            "Describe the scene in the video briefly.",
+        ]
+    },
+    "audio": {
+        "media": [
+            "https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_the_traffic_sign_in_the_image.wav",
+            "https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav",
+        ],
+        "prompt": [
+            "Transcribe the audio clip into text, please don't add other text.",
+            "Transcribe the audio clip into text, please don't add other text.",
+        ]
+    },
+    "image_audio": {
+        "media": [
+            [
+                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
+                "https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav"
+            ],
+            [
+                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
+                "https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav"
+            ],
+        ],
+        "prompt": [
+            "Describe the scene in the image briefly.",
+            "",
+        ]
+    }
+}


def add_multimodal_args(parser):
@@ -34,7 +66,7 @@ def add_multimodal_args(parser):
                        help="Model type.")
    parser.add_argument("--modality",
                        type=str,
-                        choices=["image", "video"],
+                        choices=["image", "video", "audio", "image_audio"],
                        default="image",
                        help="Media type.")
    parser.add_argument("--media",
@@ -53,11 +85,24 @@ def add_multimodal_args(parser):
    return parser


+def add_lora_args(parser):
+    parser.add_argument("--load_lora",
+                        default=False,
+                        action='store_true',
+                        help="Whether to load the LoRA model.")
+    parser.add_argument("--auto_model_name",
+                        type=str,
+                        default=None,
+                        help="The auto model name in TRTLLM repo.")
+    return parser
+
+
def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Multimodal models with the PyTorch workflow.")
    parser = add_llm_args(parser)
    parser = add_multimodal_args(parser)
+    parser = add_lora_args(parser)
    args = parser.parse_args()

    args.disable_kv_cache_reuse = True  # kv cache reuse does not work for multimodal, force overwrite
@@ -71,11 +116,19 @@ def main():
    args = parse_arguments()
    # set prompts and media to example prompts and images if they are not provided
    if args.prompt is None:
-        args.prompt = example_image_prompts if args.modality == "image" else example_video_prompts
+        args.prompt = example_medias_and_prompts[args.modality]["prompt"]
    if args.media is None:
-        args.media = example_images if args.modality == "image" else example_videos
+        args.media = example_medias_and_prompts[args.modality]["media"]
+
+    lora_config = None
+    if args.load_lora:
+        assert args.auto_model_name is not None, "Please provide the auto model name to load LoRA config."
+        import importlib
+        models_module = importlib.import_module('tensorrt_llm._torch.models')
+        model_class = getattr(models_module, args.auto_model_name)
+        lora_config = model_class.lora_config(args.model_dir)

-    llm, sampling_params = setup_llm(args)
+    llm, sampling_params = setup_llm(args, lora_config=lora_config)

    image_format = args.image_format
    if args.model_type is not None:
@@ -96,7 +149,16 @@ def main():
                                             num_frames=args.num_frames,
                                             device=device)

-    outputs = llm.generate(inputs, sampling_params)
+    lora_request = None
+    if args.load_lora:
+        lora_request = model_class.lora_request(len(inputs), args.modality,
+                                                llm._hf_model_dir)
+
+    outputs = llm.generate(
+        inputs,
+        sampling_params,
+        lora_request=lora_request,
+    )

    for i, output in enumerate(outputs):
        prompt = args.prompt[i]
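
For reference, the new LoRA flags are exercised at runtime roughly as in the minimal sketch below. The sketch is not part of the commit: the class name Phi4MMForCausalLM and the checkpoint path are assumptions, standing in for whatever class exported from tensorrt_llm._torch.models provides the lora_config()/lora_request() helpers that main() now expects via --auto_model_name.

    import importlib

    # Resolve the auto model class named by --auto_model_name (assumed name here).
    models_module = importlib.import_module("tensorrt_llm._torch.models")
    model_class = getattr(models_module, "Phi4MMForCausalLM")  # assumption, not confirmed by the diff

    # Build the LoRA config from the HF checkpoint directory (placeholder path);
    # main() then forwards it through setup_llm(args, lora_config=lora_config)
    # and builds a per-input lora_request before llm.generate().
    lora_config = model_class.lora_config("/path/to/hf/checkpoint")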