@@ -57,6 +57,7 @@ def __init__(
         batch_size: Optional[Union[int, str]] = 1,
         trust_remote_code: Optional[bool] = False,
         revision=None,
+        model_name=None,
         attn_implementation=best_fit_attn_implementation,
         use_flash_attention_2=True,
         device_map="auto",
@@ -83,8 +84,20 @@ def __init__(
         llava_model_args["attn_implementation"] = attn_implementation
         if customized_config:
             llava_model_args["customized_config"] = customized_config
-        llava_model_args["use_flash_attention_2"] = False
-        self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self.device_map, **llava_model_args)
+        if attn_implementation is not None:
+            llava_model_args["attn_implementation"] = attn_implementation
+        if "use_flash_attention_2" in kwargs:
+            llava_model_args["use_flash_attention_2"] = kwargs["use_flash_attention_2"]
+
+        model_name = model_name if model_name is not None else get_model_name_from_path(pretrained)
+        try:
+            # Try to load the model with the multimodal argument
+            self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
+        except TypeError:
+            # for older versions of LLaVA that don't have multimodal argument
+            llava_model_args.pop("multimodal", None)
+            self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
+
         self._config = self._model.config
         self.model.eval()
         self.model.tie_weights()
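For reviewers, a minimal self-contained sketch of the fallback pattern this hunk introduces: call the loader with the newer `multimodal` keyword first, and retry without it when an older LLaVA build rejects it with a `TypeError`. `old_loader` and `load_with_fallback` are illustrative stand-ins, not part of this PR or of `llava.model.builder`.

```python
def old_loader(model_path, model_base, model_name, device_map="auto"):
    # Stand-in for an older LLaVA build whose load_pretrained_model
    # does not accept a "multimodal" keyword argument.
    return ("tokenizer", "model", "image_processor", 2048)


def load_with_fallback(load_fn, pretrained, model_name, device_map, **llava_model_args):
    try:
        # Newer builds accept the extra kwargs (e.g. multimodal=True).
        return load_fn(pretrained, None, model_name, device_map=device_map, **llava_model_args)
    except TypeError:
        # Older builds raise TypeError on the unknown kwarg; drop it and retry.
        llava_model_args.pop("multimodal", None)
        return load_fn(pretrained, None, model_name, device_map=device_map, **llava_model_args)


tokenizer, model, image_processor, max_length = load_with_fallback(
    old_loader, "liuhaotian/llava-v1.5-7b", "llava-v1.5-7b", "auto", multimodal=True
)
```

Note that `except TypeError` also catches a `TypeError` raised inside the loader itself, not only a signature mismatch, so keeping the retried call identical apart from the dropped kwarg keeps the fallback easy to reason about.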