
Deprecated shard_checkpoint's replacement save_torch_state_dict does not save tied embeddings #35080

Closed
casper-hansen opened this issue Dec 4, 2024 · 13 comments

casper-hansen commented Dec 4, 2024

System Info

- `transformers` version: 4.46.3
- Platform: macOS-14.4-arm64-arm-64bit
- Python version: 3.10.13
- Huggingface_hub version: 0.26.3
- Safetensors version: 0.4.2
- Accelerate version: 0.26.1
- Accelerate config:    not found
- PyTorch version (GPU?): 2.2.2 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>

Who can help?

@SunMarc @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The shard_checkpoint import shown below has been deprecated in favor of save_torch_state_dict, which is why I updated the saving mechanism in AutoAWQ to use the new method in casper-hansen/AutoAWQ#644. However, it seems there is a problem where tied embeddings are not saved correctly, which causes problems at load time in vLLM and potentially other places not yet identified.

from transformers.modeling_utils import shard_checkpoint
from huggingface_hub import save_torch_state_dict
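
For reference, a minimal self-contained sketch of the behavior (this is not the AutoAWQ code; the output directory name is made up, and the smaller 0.5B checkpoint is used so the save fits in a single model.safetensors file):

import os

import safetensors
from huggingface_hub import save_torch_state_dict
from transformers import AutoModelForCausalLM

# Qwen2.5 ties lm_head.weight to model.embed_tokens.weight.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Save the raw state_dict roughly the way AutoAWQ >= 0.2.7 does.
save_torch_state_dict(state_dict=model.state_dict(), save_directory="qwen-repro")

# Inspect which of the two tied keys survived deduplication.
with safetensors.safe_open(
    os.path.join("qwen-repro", "model.safetensors"), framework="pt", device="cpu"
) as f:
    keys = list(f.keys())
print("model.embed_tokens.weight" in keys, "lm_head.weight" in keys)
# With the affected versions, model.embed_tokens.weight can be the key that is
# dropped, which is what breaks vLLM at load time.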

Overview from casper-hansen/AutoAWQ#665, where you can also see the full reproduction scripts and the issues they cause:

| Model / Files | model.embed_tokens | lm_head |
| --- | --- | --- |
| Qwen/Qwen2.5-1.5B-Instruct | yes | no |
| transformers==4.46.3 load and save | yes | no |
| autoawq==0.2.6 (shard_checkpoint) | yes | yes |
| autoawq==0.2.7.post2 (save_torch_state_dict) | no | yes |

Expected behavior

shard_checkpoint seems to have saved the tied weights, which are important for a lot of engines compatible with Hugging Face transformers. The expected behavior is therefore that save_torch_state_dict does the same, since we are being migrated to this new method.

SunMarc commented Dec 4, 2024

Thanks for the report! The problem is that save_torch_state_dict removes duplicate keys arbitrarily, unlike transformers. Hence, as you can see, it removed model.embed_tokens and kept lm_head.
To fix the issue, we need to clean the state_dict correctly. To do that, we have model._tied_weights_keys, which returns the preferred duplicate keys to discard (e.g. ["lm_head"]).

You can either do the cleaning step yourself (copy _clean_state_dict_for_safetensors from huggingface_hub + pass discard_name in _remove_duplicate_names), or we can add a new argument to save_torch_state_dict, but that would require the latest version of huggingface_hub. cc @Wauplin
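
For illustration, a rough sketch of that manual cleaning route, using plain dict operations instead of the private huggingface_hub helpers, and assuming the entries of model._tied_weights_keys are literal key names (as they are for models like Qwen2); the helper name is made up:

from huggingface_hub import save_torch_state_dict


def save_state_dict_without_tied_duplicates(model, save_dir, **kwargs):
    """Drop the duplicate keys transformers prefers to discard, then save."""
    state_dict = model.state_dict()
    # _tied_weights_keys lists the preferred duplicates to remove,
    # e.g. ["lm_head.weight"] for causal LMs with tied embeddings.
    for key in getattr(model, "_tied_weights_keys", None) or []:
        state_dict.pop(key, None)
    save_torch_state_dict(state_dict=state_dict, save_directory=save_dir, **kwargs)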

Also, I wanted to know why you decided not to use the save_pretrained method?

@casper-hansen

> You can either do the cleaning step yourself (copy _clean_state_dict_for_safetensors from huggingface_hub + pass discard_name in _remove_duplicate_names), or we can add a new argument to save_torch_state_dict, but that would require the latest version of huggingface_hub. cc @Wauplin

I would prefer a fix in huggingface hub. Why was the behavior changed from shard_checkpoint to save_torch_state_dict?

I'm generally not sure it's feasible for me to use the workaround you suggested. Doesn't that require me to know, for each model, which keys should not be discarded?

> Also, I wanted to know why you decided not to use the save_pretrained method?

It's been a long time since I implemented save_pretrained, going all the way back to when I created AutoAWQ. So I honestly cannot remember, but there must have been some issue, since that would have been the easiest way.

Wauplin commented Dec 4, 2024

cc @hanouticelina as you've been working on this lately. Would it be possible to check what we can do in save_torch_state_dict to handle this in a clean way?


hanouticelina commented Dec 4, 2024

A fix would be to modify save_torch_state_dict to remove duplicates only when we explicitly know which tensors should be discarded (typically when we have access to model._tied_weights_keys). This can be achieved by adding an argument to specify these tensors. When this information isn't available, I think the safest thing to do is to keep all tensors, as shard_checkpoint did.

The trade-off is that the resulting safetensors file might be larger. Also, as far as I know, some frameworks (like TensorFlow) don't support shared tensors, which could limit cross-framework compatibility.

@casper-hansen

> When this information isn't available, I think the safest thing to do is to keep all tensors, as shard_checkpoint did.

I would appreciate it if this change could make it into the new method :)

@hanouticelina

After reviewing the safetensors implementation in more detail, I suggest a different solution from what I proposed in my previous comment. It turns out it's not possible to serialize a state_dict with duplicate tensors.

Looking at how AutoAWQ uses save_torch_state_dict, we have access to the model itself. So to fix this issue, we will add a helper function in huggingface_hub to handle the priority for discarding duplicate keys (similar to transformers' save_pretrained()), and this helper will be called by huggingface_hub.save_torch_model().

Once this is fixed in huggingface_hub, you'll be able to simply use save_torch_model() instead of save_torch_state_dict():

+ from huggingface_hub import save_torch_model
...
- save_torch_state_dict(
-     state_dict=self.model.state_dict(),
+ save_torch_model(
+     model=self.model,
      save_directory=save_dir,
      max_shard_size=shard_size,
      safe_serialization=safetensors,
      force_contiguous=True,
)

Would this solution work for you?

@casper-hansen

@hanouticelina Yes, that would work for me, as long as the model is saved correctly in the case of tied weights. AutoAWQ quantization has been broken since version 0.2.7 because of this issue, so users have to take extra steps, as shown in casper-hansen/AutoAWQ#665, until this fix lands in AutoAWQ.

SunMarc commented Dec 6, 2024

Could you test with this PR to see if it solves the issue, @casper-hansen?

@hanouticelina

@casper-hansen After reviewing this further, we decided not to add the duplicate-key handling logic directly in huggingface_hub, since this logic is framework-specific and should be handled by the user. You can find an example in the PR description using a transformers.PreTrainedModel, which you can use directly here.
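
For reference, a sketch of what that could look like (the exact keyword argument name comes from the linked PR, which isn't quoted here; I'm assuming it is shared_tensors_to_discard):

from huggingface_hub import save_torch_state_dict
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

save_torch_state_dict(
    state_dict=model.state_dict(),
    save_directory="qwen-out",
    # Assumed name of the new argument: tells huggingface_hub which duplicate
    # keys to drop, e.g. ["lm_head.weight"] taken from model._tied_weights_keys.
    shared_tensors_to_discard=getattr(model, "_tied_weights_keys", None),
)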

@casper-hansen

@hanouticelina @SunMarc this looks good. Is my understanding correct that passing in model._tied_weights_keys will work in my AutoAWQ case of saving the model?

@casper-hansen

OK, so I validated that this works when I pass in model._tied_weights_keys. I would appreciate a quick patch release of huggingface_hub, if possible, once this PR has landed. Until then, I will put the huggingface_hub PR in the autoawq main-branch requirements.

import os

import safetensors

# Check whether the tied embedding weight made it into the quantized checkpoint.
quant_path = "Qwen2.5-0.5B-Instruct-awq"
with safetensors.safe_open(
    os.path.join(quant_path, "model.safetensors"), framework="pt", device="cpu"
) as f:
    print("model.embed_tokens.weight" in f.keys())

@hanouticelina

@casper-hansen we've just released a patch for huggingface_hub; you can now update your requirements to huggingface_hub>=0.26.5.

Feel free to ping us if there are any additional questions or issues related to this!

@casper-hansen

Thanks @hanouticelina for the quick fix + release.
