
Question about saving peft checkpoint #565

Open
nhanph opened this issue Oct 13, 2023 · 2 comments
Labels
bug Something isn't working

Comments

nhanph commented Oct 13, 2023

🐛 Describe the bug

From my understanding, when saving checkpoints for peft models (see here), trlx removes pytorch_model.bin before calling save_pretrained, which makes the removal useless in my opinion.

Is this intentional, or should we move the removal code to after save_pretrained is called?

Here is an example of a directory resulting from save_pretrained:

adapter_config.json  adapter_model.bin  optimizer.bin  pytorch_model.bin  random_states_0.pkl  special_tokens_map.json  spiece.model  tokenizer_config.json  tokenizer.json
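
For clarity, this is roughly the sequence I am referring to, written as a standalone sketch rather than the actual trlx code (the function name and exact arguments are my own):

import os

def save_peft_checkpoint(accelerator, model, dst_dir, **kwargs):
    """Sketch of the checkpoint sequence discussed above (not the actual trlx code)."""
    # 1. accelerator.save_state writes the full training state, including a
    #    pytorch_model.bin holding the whole (base + adapter + heads) state dict.
    accelerator.save_state(dst_dir, **kwargs)

    # 2. The removal in question: delete the large pytorch_model.bin ...
    model_path = os.path.join(dst_dir, "pytorch_model.bin")
    if os.path.exists(model_path):
        os.remove(model_path)

    # 3. ... then save_pretrained writes adapter_config.json / adapter_model.bin
    #    and, as shown in the listing above, a pytorch_model.bin again.
    accelerator.unwrap_model(model).save_pretrained(dst_dir)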

Which trlX version are you using?

0.7.0

Additional system and package information

Python 3.10.12

maxreciprocate (Collaborator) commented

Hello @nhanph!

No, the removal is not useless. If you check the contents of pytorch_model.bin immediately after this line:

self.accelerator.save_state(dst_dir, **kwargs)

you will see that it contains the whole model state dictionary, which is not needed. After deleting it and recreating it with model.save_pretrained with heads_only=True, only the value heads are kept there.

>>> list(before_deletion_state_dict.keys())[:32]
['v_head.0.weight', 'v_head.0.bias', 'v_head.2.weight', 'v_head.2.bias', 'base_model.model.transformer.wte.weight', 'base_model.model.transformer.wpe.weight', 'base_model.model.transformer.h.0.ln_1.weight', 'base_model.model.transformer.h.0.ln_1.bias', 'base_model.model.transformer.h.0.attn.c_attn.weight', 'base_model.model.transformer.h.0.attn.c_attn.bias', 'base_model.model.transformer.h.0.attn.c_attn.lora_A.default.weight', 'base_model.model.transformer.h.0.attn.c_attn.lora_B.default.weight', 'base_model.model.transformer.h.0.attn.c_proj.weight', 'base_model.model.transformer.h.0.attn.c_proj.bias', 'base_model.model.transformer.h.0.ln_2.weight', 'base_model.model.transformer.h.0.ln_2.bias', 'base_model.model.transformer.h.0.mlp.c_fc.weight', 'base_model.model.transformer.h.0.mlp.c_fc.bias', 'base_model.model.transformer.h.0.mlp.c_proj.weight', 'base_model.model.transformer.h.0.mlp.c_proj.bias', 'base_model.model.transformer.h.1.ln_1.weight', 'base_model.model.transformer.h.1.ln_1.bias', 'base_model.model.transformer.h.1.attn.c_attn.weight', 'base_model.model.transformer.h.1.attn.c_attn.bias', 'base_model.model.transformer.h.1.attn.c_attn.lora_A.default.weight', 'base_model.model.transformer.h.1.attn.c_attn.lora_B.default.weight', 'base_model.model.transformer.h.1.attn.c_proj.weight', 'base_model.model.transformer.h.1.attn.c_proj.bias', 'base_model.model.transformer.h.1.ln_2.weight', 'base_model.model.transformer.h.1.ln_2.bias', 'base_model.model.transformer.h.1.mlp.c_fc.weight', 'base_model.model.transformer.h.1.mlp.c_fc.bias']
>>> list(after_save_pretrained_state_dict.keys())[:32]
['v_head.0.weight', 'v_head.0.bias', 'v_head.2.weight', 'v_head.2.bias']
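
For reference, the comparison above can be reproduced along these lines (the paths are placeholders; the idea is to snapshot pytorch_model.bin right after accelerator.save_state and load both files):

import torch

# Copy of pytorch_model.bin taken right after accelerator.save_state:
# contains the full base_model.* / lora_* / v_head.* state dict.
before = torch.load("checkpoint/pytorch_model.bin.before", map_location="cpu")
print(len(before), list(before.keys())[:8])

# pytorch_model.bin as rewritten by save_pretrained with heads_only=True:
# only the v_head.* parameters remain.
after = torch.load("checkpoint/pytorch_model.bin", map_location="cpu")
print(len(after), list(after.keys()))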

nhanph (Author) commented Oct 18, 2023

Thank you @maxreciprocate, I get the point about saving the model's value head now.

My original question comes from an observation while running the ILQL training script: I see a pytorch_model.bin whose size is comparable to the original model, so I suspect the base model is also being saved. Is the heads_only flag set to True somewhere during checkpointing when peft_config is used? I cannot find it set anywhere.
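
For context, this is roughly how I checked the ILQL checkpoint (the directory name is a placeholder for my run's output):

import os
import torch

ckpt_dir = "ckpts/ilql_run/checkpoint_01000"  # placeholder path
path = os.path.join(ckpt_dir, "pytorch_model.bin")

state = torch.load(path, map_location="cpu")
base_keys = [k for k in state if k.startswith("base_model.")]

print(f"size: {os.path.getsize(path) / 1e6:.1f} MB")
print(f"tensors: {len(state)}, of which base_model.*: {len(base_keys)}")
# A large file with many base_model.* keys suggests the base weights were
# saved as well, i.e. heads_only did not take effect for this checkpoint.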
