
Convert a safetensor checkpoint from Hugging Face hub #1662

Merged · 32 commits · Jun 24, 2024

Conversation

ariG23498 (Collaborator)

Problem Statement

With Keras NLP, you can download models from the Hugging Face Hub using the from_preset() method, provided the model is in a Keras-specific format.
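For contrast, a minimal sketch of that existing path (the preset name below is just an illustrative built-in Keras preset):

import keras_nlp

# Loads a preset that is already stored in the Keras-specific format.
causal_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")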

Using any model from the Hub requires the following workflow (a rough sketch follows the list):

  1. The model (backbone) architecture should be implemented in both Hugging Face Transformers and Keras NLP.
  2. We should be able to build a Keras NLP model from the configuration file on the Hugging Face Hub.
  3. We should also be able to port the weights (safetensors) from the Hugging Face Hub into the Keras NLP model.
  4. The tokenizer should be ported as well.
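
A rough sketch of what this workflow could look like end to end (all helper names, config keys, and the single-file checkpoint assumption are illustrative, not the actual APIs added in this PR):

import json

from huggingface_hub import hf_hub_download
from safetensors import safe_open


def convert_from_hf(backbone_cls, repo_id):
    # 1. Read the Hugging Face config.json and map it to backbone arguments.
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    with open(config_path) as f:
        hf_config = json.load(f)
    backbone = backbone_cls(
        vocabulary_size=hf_config["vocab_size"],
        num_layers=hf_config["num_hidden_layers"],
        hidden_dim=hf_config["hidden_size"],
        # ... remaining architecture arguments ...
    )

    # 2. Port the safetensors weights into the Keras variables
    #    (assumes a single-file checkpoint for simplicity).
    weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    with safe_open(weights_path, framework="np") as st:
        backbone.token_embedding.embeddings.assign(
            st.get_tensor("model.embed_tokens.weight")
        )
        # ... port the remaining layers, then the tokenizer ...

    return backbone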

Why is this important?

This opens up a lot of opportunities. If we have a model architecture defined in both Keras NLP and Hugging Face, we are not tied to either platform. One can fine-tune the model with Hugging Face and upload it to the Hub, and then download the fine-tuned model and use it as a Keras NLP model with any backend (TensorFlow, PyTorch, or JAX).

This PR makes it possible to load checkpoints from the Hugging Face Hub into Keras NLP in a format-agnostic manner.

How to use

With the current state of the PR, one can use the following code to load any Gemma or Llama 3 checkpoint as a Keras NLP model.

# Install the PR branch
!pip install -U -q git+https://github.com/ariG23498/keras-nlp@aritra/hf-port

import keras_nlp

# Load a Transformers-format Gemma checkpoint directly from the Hub
causal_lm = keras_nlp.models.GemmaCausalLM.from_preset(
    "hf://google/gemma-7b",
)

# Load a Transformers-format Llama 3 checkpoint
causal_lm = keras_nlp.models.Llama3CausalLM.from_preset(
    "hf://meta-llama/Meta-Llama-3-8B",
)

Acknowledgements

Thanks to Matthew Carrigan for his early feedback on the APIs and ideas.

This PR builds on top of @mattdangerw's work. One can find his code here.

CC: @mattdangerw

@github-actions github-actions bot added the Gemma Gemma model specific issues label Jun 7, 2024
@mattdangerw (Member) left a comment


This is awesome! Left some initial comments. The main thing we need is a testing plan.

Are there some small checkpoints we could just load over the network from Hugging Face? We'd annotate these tests as "large" tests; we do this elsewhere.

Other ideas? We definitely want something that will raise alarm bells here if we break the conversion path.
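
One possible shape for such a test, just to make the idea concrete (the pytest large marker follows the existing convention mentioned above; the tiny checkpoint repo name is purely hypothetical):

import pytest

import keras_nlp


@pytest.mark.large  # hits the network and downloads a real checkpoint
def test_load_tiny_gemma_from_hf():
    # Hypothetical tiny Transformers-format checkpoint kept around just for testing.
    causal_lm = keras_nlp.models.GemmaCausalLM.from_preset(
        "hf://some-org/tiny-gemma-for-testing"
    )
    output = causal_lm.generate("What is Keras?", max_length=16)
    assert isinstance(output, str)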

keras_nlp/src/utils/transformers_model_utils/__init__.py (review thread, resolved)
keras_nlp/src/models/backbone.py (review thread, resolved)
(kokoro:force-run Runs Tests on GPU label added by @mattdangerw and removed by @kokoro-team, Jun 17–18, 2024.)
@mattdangerw (Member) left a comment


Looks great! Left some comments, but they are all pretty small nits.

Tests are running now. We can ignore the Keras 2 failures; they are unrelated. But any other failures that pop up would be worth digging into.

Thanks for this!

keras_nlp/src/utils/preset_utils.py (review thread, resolved)
keras_nlp/src/utils/transformers/convert_gemma.py (review thread, resolved)
keras_nlp/src/utils/transformers/convert_gemma_test.py (review thread, resolved)
keras_nlp/src/utils/transformers/convert_gemma_test.py (review thread, resolved)
keras_nlp/src/utils/transformers/convert_llama3.py (review thread, resolved)
),
)
port_weight(
keras_variable=decoder_layer._self_attention_layer._key_dense.variables[
Member:
General comment for all of these: can't you use kernel instead of variables[0]? And bias, etc. The variable names would be a lot more readable.

ariG23498 (Collaborator, Author) commented on Jun 19, 2024:

The only reason I was shying away from using kernel and bias is that they are not consistent across all the layers.
An embedding layer has embeddings, a normalization layer has scale, and a dense layer has kernel. If you want me to make the changes, I would need to run the scripts and be sure about the variable names.

What would you like me to do?


Member:

All good to leave as is then! Maybe something we can do as a follow up. Should be pretty easy to do safely now that we have testing in place.
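
For reference, a small standalone sketch of the attribute-name inconsistency discussed in this thread, using stock Keras layers (the Gemma RMS norm variable name scale is taken from the comment above):

import numpy as np
import keras

dense = keras.layers.Dense(4)
dense(np.zeros((1, 8)))  # build the layer
embedding = keras.layers.Embedding(input_dim=16, output_dim=4)
embedding(np.zeros((1, 2), dtype="int32"))
norm = keras.layers.LayerNormalization()
norm(np.zeros((1, 4)))

# Positional indexing works for every layer but hides which weight is meant:
print(dense.variables[0].shape)    # the kernel, but you have to know the order

# Named attributes are readable, but the attribute differs by layer type:
print(dense.kernel.shape)          # Dense exposes kernel (and bias)
print(embedding.embeddings.shape)  # Embedding exposes embeddings
print(norm.gamma.shape)            # LayerNormalization exposes gamma/beta;
                                   # the Gemma RMS norm in Keras NLP uses scale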

keras_nlp/src/utils/transformers/safetensor_utils.py (review thread, resolved)
@mattdangerw changed the title from "[WIP] loading a safetensor checkpoint from Hugging Face hub" to "Convert a safetensor checkpoint from Hugging Face hub" on Jun 18, 2024
@ariG23498 (Collaborator, Author)

@mattdangerw the tests that fail are likely due to the safetensors package not being installed on the machines. How do you want to bypass that?

@mattdangerw (Member)

> @mattdangerw the tests that fail are likely due to the safetensors package not being installed on the machines. How do you want to bypass that?

Adding to requirements-common.txt should do it I think?
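
As a side note, if a hard dependency were not desired, the tests themselves could guard on the package instead (pytest.importorskip is a standard pytest utility); this is only an alternative sketch, the suggestion above is simply to add safetensors to requirements-common.txt:

import pytest

# Skip these conversion tests entirely when safetensors is not installed.
safetensors = pytest.importorskip("safetensors")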

(Several cycles of @ariG23498 adding the kokoro:force-run Runs Tests on GPU label and @kokoro-team removing it, Jun 21–24, 2024.)
@ariG23498 (Collaborator, Author)

@mattdangerw all the tests pass 🥳

@SamanehSaadat (Member) left a comment


Thanks, Aritra! Looks great! Just left some minor comments!

    return load_gemma_backbone(cls, preset, load_weights)
if cls.__name__ == "Llama3Backbone":
    return load_llama3_backbone(cls, preset, load_weights)
raise ValueError(f"No conversion huggingface/transformers to {cls}")
Member:

If a user doesn't know that conversion is required to load a Transformers checkpoint in Keras, and tries to load a Transformers checkpoint for which no conversion exists, they'll end up here, right? Similar to #1574.
In that case, I think it'd be nice to have an error message telling the user that, if conversion is not supported, they can switch to loading a Keras checkpoint if one is available.

ariG23498 (Collaborator, Author):

I have modified the ValueError message. Let me know if that is what you wanted.
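
For illustration, the kind of message being suggested might look something like this (wording assumed, not necessarily what the PR ends up with):

raise ValueError(
    f"There is no conversion from a huggingface/transformers checkpoint "
    f"to {cls}. If the model is also published as a Keras preset, try "
    f"loading that preset instead of the hf:// handle."
)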

keras_nlp/src/utils/transformers/convert.py (review thread, resolved)
    Returns:
        backbone: Initialized Keras model backbone.
    """
    transformers_config = load_config(preset, "config.json")
Member:

We have a constant for config.json here. We plan to change the name of this file in the future, so using the constant would make future changes easier.

ariG23498 (Collaborator, Author):

Here the config.json comes from the Hugging Face repository. I have added another constant for this file name and am now using it. Does the current implementation look good?
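
Roughly, the idea is a named constant in place of the hard-coded string; the constant name below is hypothetical, not necessarily what was added:

# e.g. in preset_utils.py
HF_CONFIG_FILE = "config.json"

# ...and in the converter:
transformers_config = load_config(preset, HF_CONFIG_FILE)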

keras_nlp/src/utils/transformers/convert_gemma.py (review thread, resolved)
keras_nlp/src/utils/transformers/convert_llama3.py (review thread, resolved)
(More kokoro:force-run Runs Tests on GPU label add/remove cycles on Jun 24, 2024.)
@SamanehSaadat SamanehSaadat merged commit c459519 into keras-team:master Jun 24, 2024
8 checks passed
@sayakpaul commented on Jun 25, 2024:

The PR title gives the impression that any checkpoint with the safetensors extension would be supported by this PR.

But from what I understand, it is probably not doing that. It is adding support for loading Transformers (the library) formatted checkpoints that have the safetensors extension from the Hub, subject to certain constraints.

If so, it might be nice to modify the title accordingly. I am happy to stand corrected if my understanding is wrong.

Labels: Gemma (Gemma model specific issues)

6 participants