-
Notifications
You must be signed in to change notification settings - Fork 350
Open
Description
def get_model_class(self, model_id: str, model_init_kwargs: dict):
    """Return the model class for an InternVL checkpoint and adapt its init kwargs.

    Loads the checkpoint's config into ``self.model_config`` and rewrites
    ``model_init_kwargs`` **in place** so the caller can pass it on when
    instantiating the returned class:

    * ``trust_remote_code`` is forced to ``True``.
    * ``use_cache`` is removed if present.
    * an ``attn_implementation`` value containing ``"flash_attention_2"`` is
      replaced by ``use_flash_attn=True``; any other value is left untouched.

    Args:
        model_id: Model identifier handed to ``AutoConfig.from_pretrained``;
            must contain the substring ``"InternVL"``.
        model_init_kwargs: Keyword arguments for model construction. Mutated
            in place.

    Returns:
        ``AutoModel``.

    Raises:
        AssertionError: If ``model_id`` does not contain ``"InternVL"``.
    """
    assert "InternVL" in model_id, f"model_id must contain 'InternVL', but got {model_id}"
    self.model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

    # InternVL should be loaded with "trust_remote_code=True".
    model_init_kwargs["trust_remote_code"] = True
    # "use_cache" should be removed.
    model_init_kwargs.pop("use_cache", None)

    # "flash_attention_2" should be translated to "use_flash_attn" for InternVL.
    # `.get(key, "")` only covers a missing key; an explicit
    # `attn_implementation=None` would make the `in` test raise TypeError,
    # so coalesce None to an empty string first.
    attn_implementation = model_init_kwargs.get("attn_implementation") or ""
    if "flash_attention_2" in attn_implementation:
        model_init_kwargs["use_flash_attn"] = True
        model_init_kwargs.pop("attn_implementation")

    # The model class of InternVL when being mapped has been determined by its config.
    return AutoModel
Metadata
Metadata
Assignees
Labels
No labels