Skip to content

Commit 8aaa828

Browse files
Luodiankcz358
authored and committed
Add model_name parameter to Llava constructor
1 parent 7847dc4 commit 8aaa828

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

Diff for: lmms_eval/models/llava.py

+15-2
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,7 @@ def __init__(
5757
batch_size: Optional[Union[int, str]] = 1,
5858
trust_remote_code: Optional[bool] = False,
5959
revision=None,
60+
model_name=None,
6061
attn_implementation=best_fit_attn_implementation,
6162
use_flash_attention_2=True,
6263
device_map="auto",
@@ -83,8 +84,20 @@ def __init__(
8384
llava_model_args["attn_implementation"] = attn_implementation
8485
if customized_config:
8586
llava_model_args["customized_config"] = customized_config
86-
llava_model_args["use_flash_attention_2"] = False
87-
self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, get_model_name_from_path(pretrained), device_map=self.device_map, **llava_model_args)
87+
if attn_implementation is not None:
88+
llava_model_args["attn_implementation"] = attn_implementation
89+
if "use_flash_attention_2" in kwargs:
90+
llava_model_args["use_flash_attention_2"] = kwargs["use_flash_attention_2"]
91+
92+
model_name = model_name if model_name is not None else get_model_name_from_path(pretrained)
93+
try:
94+
# Try to load the model with the multimodal argument
95+
self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
96+
except TypeError:
97+
# for older versions of LLaVA that don't have multimodal argument
98+
llava_model_args.pop("multimodal", None)
99+
self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args)
100+
88101
self._config = self._model.config
89102
self.model.eval()
90103
self.model.tie_weights()

0 commit comments

Comments
 (0)