Convert a safetensors checkpoint from the Hugging Face Hub
#1662
Conversation
This is awesome! Left some initial comments. The main thing we need is a testing plan.
Are there some small checkpoints we could just load over the network from Hugging Face? We'd annotate these as "large" tests, as we do elsewhere.
Other ideas? We definitely want something that will raise alarm bells if we break the conversion path.
… to match the keras weights
Looks great! Left some comments, but they are all pretty small nits.
Tests are running now. We can ignore Keras 2 failures, they are unrelated. But any other failures that pop up would be worth digging into.
Thanks for this!
    ),
)
port_weight(
    keras_variable=decoder_layer._self_attention_layer._key_dense.variables[
General comment for all of these: can't you use kernel instead of variables[0]? And bias, etc. The variable names will be a lot more readable.
The only reason I was shying away from using kernel and bias is that it is not consistent across all the layers. An embedding layer has embeddings, a normalization layer has scale, and a dense layer has kernel. If you want me to make the changes, I would need to run the scripts and be sure about the variable names.
What would you like me to do?
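The trade-off discussed above can be sketched with a small helper that prefers a readable named attribute and falls back to positional access. This is a sketch, not code from the PR: `named_weight` is a hypothetical helper, and the attribute names (`kernel`, `embeddings`, `scale`) follow common Keras conventions as described in the comment above. Dummy classes stand in for real Keras layers so the example runs without Keras installed.

```python
def named_weight(layer, name, index):
    """Return layer.<name> when the layer exposes it (e.g. a Dense layer's
    kernel, an Embedding layer's embeddings, a norm layer's scale),
    otherwise fall back to positional access via layer.variables[index]."""
    weight = getattr(layer, name, None)
    return weight if weight is not None else layer.variables[index]


# Minimal stand-ins to illustrate both code paths without requiring Keras.
class DenseLike:
    def __init__(self):
        self.kernel = "kernel-weight"
        self.variables = [self.kernel]


class OpaqueLayer:
    def __init__(self):
        self.variables = ["first-variable"]


print(named_weight(DenseLike(), "kernel", 0))    # readable named access
print(named_weight(OpaqueLayer(), "kernel", 0))  # positional fallback
```

A helper like this would let the conversion script stay readable for layers with well-known attribute names while still working for layers where only the variables list is known.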
All good to leave as is then! Maybe something we can do as a follow up. Should be pretty easy to do safely now that we have testing in place.
@mattdangerw the tests that fail are likely due to the missing dependency.
Adding to requirements-common.txt should do it, I think?
@mattdangerw all the tests pass 🥳
Thanks, Aritra! Looks great! Just left some minor comments!
        return load_gemma_backbone(cls, preset, load_weights)
    if cls.__name__ == "Llama3Backbone":
        return load_llama3_backbone(cls, preset, load_weights)
    raise ValueError(f"No conversion huggingface/transformers to {cls}")
If a user doesn't know that conversion is required to load a Transformers checkpoint in Keras and tries to load a Transformers checkpoint for which conversion isn't supported, they'll end up here, right? Similar to #1574
In that case, I think it'd be nice to have an error message telling the user that if conversion is not supported, they can switch to loading a Keras checkpoint if one is available.
I have modified the ValueError message. Let me know if that was what you wanted.
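A class-name dispatch with a more descriptive error might look like the following. This is a sketch under assumptions: the two loader functions are stand-in stubs mirroring the snippet above, and the error message wording is illustrative, not the PR's actual text.

```python
def load_gemma_backbone(cls, preset, load_weights):
    return f"gemma backbone from {preset}"  # stub for illustration


def load_llama3_backbone(cls, preset, load_weights):
    return f"llama3 backbone from {preset}"  # stub for illustration


def load_transformers_backbone(cls, preset, load_weights=True):
    """Route a Transformers-format checkpoint to the matching converter,
    failing with guidance when no converter exists for the class."""
    if cls.__name__ == "GemmaBackbone":
        return load_gemma_backbone(cls, preset, load_weights)
    if cls.__name__ == "Llama3Backbone":
        return load_llama3_backbone(cls, preset, load_weights)
    raise ValueError(
        f"Loading huggingface/transformers checkpoints is not supported "
        f"for {cls.__name__}. Try loading a Keras-format checkpoint for "
        f"this architecture if one is available."
    )
```

The key design point from the review: the error names the unsupported class and points the user toward Keras-format checkpoints instead of failing opaquely.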
    Returns:
        backbone: Initialized Keras model backbone.
    """
    transformers_config = load_config(preset, "config.json")
We have a constant for config.json here. We have a plan to change the name of this file in the future, so using the constant would make future changes easier.
Here the config.json comes from the Hugging Face repository. I have added another constant to support this file name and am now using the constant. Does the current implementation look good?
The PR title gives me the impression that any checkpoint in the safetensors format can be loaded. But from what I understand it is probably not doing that. It is adding support for loading Transformers (the library) formatted checkpoints that have safetensors weights. If so, it might be nice to modify the title accordingly. I am happy to stand corrected if my understanding is wrong.
This reverts commit c459519.
Problem Statement
With Keras NLP, you can download models from the Hugging Face Hub using the
from_preset()
method, provided the model is in a Keras-specific format. Using any model from the Hub requires the following workflow:
Why is this important?
This opens up a lot of opportunities. If we have a model architecture defined in both Keras NLP and Hugging Face, we are not tied to either platform. One user can fine-tune the model with Hugging Face and upload it to the Hub, while another can download the fine-tuned model and use it as a Keras NLP model with any backend (TensorFlow, PyTorch, JAX).
This PR will make it possible for us to load checkpoints from the Hugging Face Hub into Keras NLP in a format agnostic manner.
How to use
With the current state of the PR, one can use the following code to load any Gemma or Llama3 checkpoint as a Keras NLP model.
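A usage sketch along those lines is below. Assumptions to flag: the "hf://" preset scheme is what this PR introduces, but the exact repo handle ("google/gemma-2b") is illustrative, and actually running the loader requires keras_nlp plus network access, which is why the call is wrapped in a function rather than executed at import time.

```python
def load_gemma_from_hub(preset="hf://google/gemma-2b"):
    """Load a Transformers-format Gemma checkpoint from the Hugging Face
    Hub as a Keras NLP backbone. The hf:// scheme routes from_preset()
    through the conversion path added by this PR; the repo handle here
    is a hypothetical example. Requires keras_nlp and network access."""
    import keras_nlp  # imported lazily so the sketch parses without it

    return keras_nlp.models.GemmaBackbone.from_preset(preset)
```

Because from_preset() dispatches on the preset string, the same call shape works for Llama3 by swapping in Llama3Backbone and a Llama 3 repo handle.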
Acknowledgements
Thanks to Matthew Carrigan for his early feedback on the APIs and ideas.
The current PR is based on top of @mattdangerw's work. One can find his code here.
CC: @mattdangerw