
Error with new LLaMA V2 in 4-bit: mat1 and mat2 shapes cannot be multiplied #610

@dimaischenko

Description

Error with the new Llama V2 model and 4-bit inference:

bitsandbytes: 0.40.2
transformers: 4.31.0
torch: 2.0.1
GPU: RTX 3090 (24 GB)
CUDA: 11.7

Load model:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-hf",
    use_safetensors=True,
    load_in_4bit=True,
    device_map="auto",
)

Try to generate:

text = "Hello how are you?"
input_ids = tokenizer(text, return_tensors='pt').input_ids.cuda()
output = model.generate(inputs=input_ids, max_new_tokens=10)

Error:

File /usr/local/lib/python3.8/dist-packages/transformers/models/llama/modeling_llama.py:295, in <listcomp>(.0)
    292 key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
    293 value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
--> 295 query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
    296 query_states = torch.cat(query_states, dim=-1)
    298 key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (6x5120 and 1x2560)
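The mismatch seems to come from the tensor-parallel branch in LlamaAttention: meta-llama/Llama-2-13b-hf ships with pretraining_tp: 2 in its config, so transformers 4.31.0 splits q_proj.weight along dim 0 and calls F.linear on each slice. With load_in_4bit=True the weight is a bitsandbytes-packed tensor rather than a regular (5120, 5120) matrix, so each slice has shape (2560, 1), which matches the 1x2560 mat2 in the error (mat1 is 6x5120: 6 prompt tokens times hidden size 5120). Assuming that diagnosis, a minimal workaround sketch (not a confirmed fix) is to force pretraining_tp=1 on the config before the model is built, so the quantized Linear4bit modules are called directly:

from transformers import AutoConfig, AutoModelForCausalLM

# Sketch of a possible workaround: override pretraining_tp before model init so
# LlamaAttention/LlamaMLP skip the weight-splitting branch. Assumes the shape
# error is caused by splitting the packed 4-bit weights.
config = AutoConfig.from_pretrained("meta-llama/Llama-2-13b-hf")
config.pretraining_tp = 1

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-hf",
    config=config,
    use_safetensors=True,
    load_in_4bit=True,
    device_map="auto",
)

Setting model.config.pretraining_tp = 1 after loading would likely not help with 4.31.0, since the attention and MLP modules cache the value at construction time.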


Labels: Bug (Something isn't working), High Priority (first issues that will be worked on)
