I encountered the following error when loading Valley2-7b with transformers.
Code:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("luoruipu1/Valley2-7b", cache_dir='./')
Error:
Traceback (most recent call last):
  File "/remote-home/zhubin/A_LVLM/Valley/tmp.py", line 3, in <module>
    model = AutoModelForCausalLM.from_pretrained("luoruipu1/Valley2-7b", cache_dir='./')
  File "/root/anaconda3/envs/valley/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 482, in from_pretrained
    config, kwargs = AutoConfig.from_pretrained(
  File "/root/anaconda3/envs/valley/lib/python3.10/site-packages/transformers/models/auto/configuration_auto.py", line 1022, in from_pretrained
    config_class = CONFIG_MAPPING[config_dict["model_type"]]
  File "/root/anaconda3/envs/valley/lib/python3.10/site-packages/transformers/models/auto/configuration_auto.py", line 723, in __getitem__
    raise KeyError(key)
KeyError: 'valley'
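The last frames of the traceback show where the load fails: AutoConfig reads the model_type field from the checkpoint's config.json and uses it as a key into CONFIG_MAPPING, and 'valley' is not one of the model types built into transformers. The toy registry below is a simplified stand-in written for illustration, not transformers' actual code; it reproduces the same KeyError and shows why registering the type makes the lookup succeed. The real library exposes AutoConfig.register(model_type, config_class) and AutoModelForCausalLM.register(config_class, model_class) for this purpose; the class name ValleyConfig used here is hypothetical.

```python
import json
import os
import tempfile

# Toy stand-in for transformers' CONFIG_MAPPING (model_type -> config class name).
# The real mapping is much larger and imports classes lazily; this only mirrors the lookup.
CONFIG_MAPPING = {"llama": "LlamaConfig", "gpt2": "GPT2Config"}


def register(model_type, config_class_name):
    """Toy analogue of AutoConfig.register: add a new model_type to the mapping."""
    if model_type in CONFIG_MAPPING:
        raise ValueError(f"'{model_type}' is already registered")
    CONFIG_MAPPING[model_type] = config_class_name


def resolve_config_class(model_dir):
    """Read config.json and look up its model_type, the step that fails in the traceback."""
    with open(os.path.join(model_dir, "config.json"), encoding="utf-8") as f:
        config_dict = json.load(f)
    # An unregistered model_type raises KeyError, just like KeyError: 'valley'
    return CONFIG_MAPPING[config_dict["model_type"]]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as model_dir:
        with open(os.path.join(model_dir, "config.json"), "w", encoding="utf-8") as f:
            json.dump({"model_type": "valley"}, f)
        try:
            resolve_config_class(model_dir)
        except KeyError as err:
            print("KeyError:", err)  # prints KeyError: 'valley'
        register("valley", "ValleyConfig")  # hypothetical config class name
        print(resolve_config_class(model_dir))  # prints ValleyConfig
```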
Because the valley model type is not registered with transformers, AutoModelForCausalLM cannot load it. Download the model weights locally and then load them with the following code:
import torch
from transformers import AutoTokenizer
from valley.model.valley import ValleyLlamaForCausalLM
# Assumption: the DEFAULT_* special-token constants are defined in valley.util.config
from valley.util.config import (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_VI_START_TOKEN,
                                DEFAULT_VI_END_TOKEN, DEFAULT_VIDEO_FRAME_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN)

def init_vision_token(model, tokenizer):
    # Store the IDs of Valley's image/video special tokens on the vision tower config
    vision_config = model.get_model().vision_tower.config
    vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
    vision_config.vi_start_token, vision_config.vi_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_VI_START_TOKEN, DEFAULT_VI_END_TOKEN])
    vision_config.vi_frame_token = tokenizer.convert_tokens_to_ids(DEFAULT_VIDEO_FRAME_TOKEN)
    vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# float16 is meant for GPU; many CPU kernels in torch 2.0 lack half-precision support
dtype = torch.float16 if device.type == 'cuda' else torch.float32
# input the query
query = "Describe the video concisely."
# input the system prompt
system_prompt = "You are Valley, a large language and vision assistant trained by ByteDance. You are able to understand the visual content or video that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail."
model_path = "THEMODELPATH"  # local directory containing the downloaded weights
video_file = "THEVIDEOPATH"  # placeholder: path to the input video
model = ValleyLlamaForCausalLM.from_pretrained(model_path, torch_dtype=dtype)
tokenizer = AutoTokenizer.from_pretrained(model_path)
init_vision_token(model, tokenizer)
model = model.to(device)
model.eval()

# we support openai format input
message = [{"role": "system", "content": system_prompt},
           {"role": "user", "content": 'Hi!'},
           {"role": "assistant", "content": 'Hi there! How can I help you today?'},
           {"role": "user", "content": query}]
gen_kwargs = dict(
    do_sample=True,
    temperature=0.2,
    max_new_tokens=1024,
)
response = model.completion(tokenizer, video_file, message, gen_kwargs, device)
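The reply passes the conversation in the OpenAI chat format, where the reply role is spelled "assistant". As a hedged sketch that is not part of Valley, a small check like the following can flag a misspelled role (such as "assistent") or non-string content before the list reaches model.completion; it assumes plain-string contents as in the example above.

```python
# Roles used by the OpenAI-style chat format in the example above
VALID_ROLES = {"system", "user", "assistant"}


def validate_messages(messages):
    """Return a list of problems in an OpenAI-style message list (empty if none found).

    Assumes every message is a dict with a "role" and a plain-string "content".
    """
    problems = []
    for i, msg in enumerate(messages):
        role = msg.get("role")
        if role not in VALID_ROLES:
            problems.append(f"message {i}: unknown role {role!r}")
        if not isinstance(msg.get("content"), str):
            problems.append(f"message {i}: content must be a string")
    return problems


if __name__ == "__main__":
    message = [
        {"role": "system", "content": "You are Valley."},
        {"role": "user", "content": "Hi!"},
        {"role": "assistent", "content": "Hi there! How can I help you today?"},
    ]
    print(validate_messages(message))  # ["message 2: unknown role 'assistent'"]
```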
pip list (excerpt):
torch              2.0.1
torchvision        0.15.2
tqdm               4.66.1
transformers       4.32.1
triton             2.0.0
typing_extensions  4.7.1
tzdata             2023.3
uc-micro-py        1.0.2
urllib3            2.0.4
uvicorn            0.23.2
valley             0.1.0         Valley
wandb              0.15.8
wavedrom           2.0.3.post3
wcwidth            0.2.6
websockets         11.0.3
wheel              0.38.4
yarl               1.9.2
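For a loading error like this one, the torch, transformers and valley versions are the most relevant lines of the environment. A minimal standard-library sketch (not a pip or Valley tool) that reports just those packages, printing "not installed" for any that are missing:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is not installed."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(["torch", "torchvision", "transformers", "valley"]).items():
        # Left-align names in a 15-character column, like a trimmed pip list
        print(f"{name:<15}{ver or 'not installed'}")
```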