Model finetuned using finetune_adapter not directly usable in generate/chat... How to convert? #78

@RDouglasSharp

Description

I used the finetune_adapter.py script to produce a tuned model. When I try loading that tuned checkpoint back into chat.py, I get the following error on load:

RuntimeError: Error(s) in loading state_dict for Parrot:
        Missing key(s) in state_dict: "lm_head.weight", "transformer.wte.weight", "transformer.h.0.norm_1.weight", "transformer.h.0.norm_1.bias", "transformer.h.0.attn.attn.weight", "transformer.h.0.attn.attn.bias", "transformer.h.0.attn.proj.weight",
"transformer.h.0.attn.proj.bias", "transformer.h.0.norm_2.weight", "transformer.h.0.norm_2.bias", "transformer.h.0.mlp.fc.weight", "transformer.h.0.mlp.fc.bias", "transformer.h.0.mlp.proj.weight", "transformer.h.0.mlp.proj.bias", "transformer.h.1.norm_1.weight",
"transformer.h.1.norm_1.bias", "transformer.h.1.attn.attn.weight", "transformer.h.1.attn.attn.bias", "transformer.h.1.attn.proj.weight", "transformer.h.1.attn.proj.bias", "transformer.h.1.norm_2.weight", "transformer.h.1.norm_2.bias", "transformer.h.1.mlp.fc.weight",
"transformer.h.1.mlp.fc.bias", "transformer.h.1.mlp.proj.weight", "transformer.h.1.mlp.proj.bias", "transformer.h.2.norm_1.weight", "transformer.h.2.norm_1.bias", "transformer.h.2.attn.attn.weight", "transformer.h.2.attn.attn.bias", "transformer.h.2.attn.proj.weight",
"transformer.h.2.attn.proj.bias", "transformer.h.2.norm_2.weight", "transformer.h.2.norm_2.bias", "transformer.h.2.mlp.fc.weight", "transformer.h.2.mlp.fc.bias", "transformer.h.2.mlp.proj.weight", "transformer.h.2.mlp.proj.bias", "transformer.h.3.norm_1.weight",
"transformer.h.3.norm_1.bias", "transformer.h.3.attn.attn.weight", "transformer.h.3.attn.attn.bias", "transformer.h.3.attn.proj.weight", "transformer.h.3.attn.proj.bias", "transformer.h.3.norm_2.weight", "transformer.h.3.norm_2.bias", "transformer.h.3.mlp.fc.weight",
"transformer.h.3.mlp.fc.bias", "transformer.h.3.mlp.proj.weight", "transformer.h.3.mlp.proj.bias", "transformer.h.4.norm_1.weight", "transformer.h.4.norm_1.bias", "transformer.h.4.attn.attn.weight", "transformer.h.4.attn.attn.bias", "transformer.h.4.attn.proj.weight",
"transformer.h.4.attn.proj.bias", "transformer.h.4.norm_2.weight", "transformer.h.4.norm_2.bias", "transformer.h.4.mlp.fc.weight", "transformer.h.4.mlp.fc.bias", "transformer.h.4.mlp.proj.weight", "transformer.h.4.mlp.proj.bias", "transformer.h.5.norm_1.weight",
"transformer.h.5.norm_1.bias", "transformer.h.5.attn.attn.weight", "transformer.h.5.attn.attn.bias", "transformer.h.5.attn.proj.weight", "transformer.h.5.attn.proj.bias", "transformer.h.5.norm_2.weight", "transformer.h.5.norm_2.bias", "transformer.h.5.mlp.fc.weight",
"transformer.h.5.mlp.fc.bias", "transformer.h.5.mlp.proj.weight", "transformer.h.5.mlp.proj.bias", "transformer.h.6.norm_1.weight", "transformer.h.6.norm_1.bias", "transformer.h.6.attn.attn.weight", "transformer.h.6.attn.attn.bias", "transformer.h.6.attn.proj.weight",
"transformer.h.6.attn.proj.bias", "transformer.h.6.norm_2.weight", "transformer.h.6.norm_2.bias", "transformer.h.6.mlp.fc.weight", "transformer.h.6.mlp.fc.bias", "transformer.h.6.mlp.proj.weight", "transformer.h.6.mlp.proj.bias", "transformer.h.7.norm_1.weight",
"transformer.h.7.norm_1.bias", "transformer.h.7.attn.attn.weight", "transformer.h.7.attn.attn.bias", "transformer.h.7.attn.proj.weight", "transformer.h.7.attn.proj.bias", "transformer.h.7.norm_2.weight", "transformer.h.7.norm_2.bias", "transformer.h.7.mlp.fc.weight",
"transformer.h.7.mlp.fc.bias", "transformer.h.7.mlp.proj.weight", "transformer.h.7.mlp.proj.bias", "transformer.h.8.norm_1.weight", "transformer.h.8.norm_1.bias", "transformer.h.8.attn.attn.weight", "transformer.h.8.attn.attn.bias", "transformer.h.8.attn.proj.weight",
"transformer.h.8.attn.proj.bias", "transformer.h.8.norm_2.weight", "transformer.h.8.norm_2.bias", "transformer.h.8.mlp.fc.weight", "transformer.h.8.mlp.fc.bias", "transformer.h.8.mlp.proj.weight", "transformer.h.8.mlp.proj.bias", "transformer.h.9.norm_1.weight",
"transformer.h.9.norm_1.bias", "transformer.h.9.attn.attn.weight", "transformer.h.9.attn.attn.bias", "transformer.h.9.attn.proj.weight", "transformer.h.9.attn.proj.bias", "transformer.h.9.norm_2.weight", "transformer.h.9.norm_2.bias", "transformer.h.9.mlp.fc.weight",
"transformer.h.9.mlp.fc.bias", "transformer.h.9.mlp.proj.weight", "transformer.h.9.mlp.proj.bias", "transformer.h.10.norm_1.weight", "transformer.h.10.norm_1.bias", "transformer.h.10.attn.attn.weight", "transformer.h.10.attn.attn.bias",
"transformer.h.10.attn.proj.weight", "transformer.h.10.attn.proj.bias", "transformer.h.10.norm_2.weight", "transformer.h.10.norm_2.bias", "transformer.h.10.mlp.fc.weight", "transformer.h.10.mlp.fc.bias", "transformer.h.10.mlp.proj.weight",
"transformer.h.10.mlp.proj.bias", "transformer.h.11.norm_1.weight", "transformer.h.11.norm_1.bias", "transformer.h.11.attn.attn.weight", "transformer.h.11.attn.attn.bias", "transformer.h.11.attn.proj.weight", "transformer.h.11.attn.proj.bias",
"transformer.h.11.norm_2.weight", "transformer.h.11.norm_2.bias", "transformer.h.11.mlp.fc.weight", "transformer.h.11.mlp.fc.bias", "transformer.h.11.mlp.proj.weight", "transformer.h.11.mlp.proj.bias", "transformer.h.12.norm_1.weight", "transformer.h.12.norm_1.bias",
"transformer.h.12.attn.attn.weight", "transformer.h.12.attn.attn.bias", "transformer.h.12.attn.proj.weight", "transformer.h.12.attn.proj.bias", "transformer.h.12.norm_2.weight", "transformer.h.12.norm_2.bias", "transformer.h.12.mlp.fc.weight",
"transformer.h.12.mlp.fc.bias", "transformer.h.12.mlp.proj.weight", "transformer.h.12.mlp.proj.bias", "transformer.h.13.norm_1.weight", "transformer.h.13.norm_1.bias", "transformer.h.13.attn.attn.weight", "transformer.h.13.attn.attn.bias",
"transformer.h.13.attn.proj.weight", "transformer.h.13.attn.proj.bias", "transformer.h.13.norm_2.weight", "transformer.h.13.norm_2.bias", "transformer.h.13.mlp.fc.weight", "transformer.h.13.mlp.fc.bias", "transformer.h.13.mlp.proj.weight",
"transformer.h.13.mlp.proj.bias", "transformer.h.14.norm_1.weight", "transformer.h.14.norm_1.bias", "transformer.h.14.attn.attn.weight", "transformer.h.14.attn.attn.bias", "transformer.h.14.attn.proj.weight", "transformer.h.14.attn.proj.bias",
"transformer.h.14.norm_2.weight", "transformer.h.14.norm_2.bias", "transformer.h.14.mlp.fc.weight", "transformer.h.14.mlp.fc.bias", "transformer.h.14.mlp.proj.weight", "transformer.h.14.mlp.proj.bias", "transformer.h.15.norm_1.weight", "transformer.h.15.norm_1.bias",
"transformer.h.15.attn.attn.weight", "transformer.h.15.attn.attn.bias", "transformer.h.15.attn.proj.weight", "transformer.h.15.attn.proj.bias", "transformer.h.15.norm_2.weight", "transformer.h.15.norm_2.bias", "transformer.h.15.mlp.fc.weight",
"transformer.h.15.mlp.fc.bias", "transformer.h.15.mlp.proj.weight", "transformer.h.15.mlp.proj.bias", "transformer.ln_f.weight", "transformer.ln_f.bias".
        Unexpected key(s) in state_dict: "transformer.h.2.attn.gating_factor", "transformer.h.2.attn.adapter_wte.weight", "transformer.h.3.attn.gating_factor", "transformer.h.3.attn.adapter_wte.weight", "transformer.h.4.attn.gating_factor",
"transformer.h.4.attn.adapter_wte.weight", "transformer.h.5.attn.gating_factor", "transformer.h.5.attn.adapter_wte.weight", "transformer.h.6.attn.gating_factor", "transformer.h.6.attn.adapter_wte.weight", "transformer.h.7.attn.gating_factor",
"transformer.h.7.attn.adapter_wte.weight", "transformer.h.8.attn.gating_factor", "transformer.h.8.attn.adapter_wte.weight", "transformer.h.9.attn.gating_factor", "transformer.h.9.attn.adapter_wte.weight", "transformer.h.10.attn.gating_factor",
"transformer.h.10.attn.adapter_wte.weight", "transformer.h.11.attn.gating_factor", "transformer.h.11.attn.adapter_wte.weight", "transformer.h.12.attn.gating_factor", "transformer.h.12.attn.adapter_wte.weight", "transformer.h.13.attn.gating_factor",
"transformer.h.13.attn.adapter_wte.weight", "transformer.h.14.attn.gating_factor", "transformer.h.14.attn.adapter_wte.weight", "transformer.h.15.attn.gating_factor", "transformer.h.15.attn.adapter_wte.weight".

What am I doing wrong? How do I convert a tuned model checkpoint to what is expected by generate / chat?
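For what it's worth, the error message itself points at the mismatch: the checkpoint written by finetune_adapter.py appears to contain only the adapter parameters (`gating_factor`, `adapter_wte.weight`), while chat.py expects a full state dict with all the base transformer weights. A common fix for this kind of mismatch is to load the pretrained base checkpoint and overlay the adapter weights on top before loading into an adapter-aware model class. Below is a minimal sketch of that merge using plain dicts as stand-ins for the real state dicts; the key names are taken from the error above, but the merge code is illustrative and not this repository's API:

```python
# Stand-ins for the two checkpoints: the base checkpoint holds the full
# model weights, the adapter checkpoint holds only the tuned adapter params.
base_sd = {
    "lm_head.weight": "base tensor",
    "transformer.h.2.attn.attn.weight": "base tensor",
}
adapter_sd = {
    "transformer.h.2.attn.gating_factor": "tuned tensor",
    "transformer.h.2.attn.adapter_wte.weight": "tuned tensor",
}

# Loading the adapter checkpoint alone fails because every base key is
# missing -- exactly the "Missing key(s)" list in the RuntimeError.
missing = set(base_sd) - set(adapter_sd)

# Overlaying the adapter weights on the base weights yields a complete
# state dict that a model class with adapter modules can load, e.g.
# model.load_state_dict(merged).
merged = {**base_sd, **adapter_sd}
```

In practice this would use `torch.load()` on both checkpoint files, and the merged dict must be loaded into the *adapter* variant of the model (the one defined in the adapter module), since the plain model class has no `gating_factor`/`adapter_wte` parameters and would raise the mirror-image "Unexpected key(s)" error.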
