[P1] Loading REFT for RoBERTa Models #86
Comments
hey @hSterz thanks for your question! the reason is that to use
# get reft model
reft_config = pyreft.ReftConfig(representations={
    "layer": 15, "component": "block_output",
    # alternatively, you can specify as string component access,
    # "component": "model.layers[0].output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(embed_dim=model.config.hidden_size,
                                              low_rank_dimension=4)})
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device("cuda")
reft_model.print_trainable_parameters()
"""
trainable intervention params: 32,772 || trainable model params: 0
model params: 6,738,415,616 || trainable%: 0.00048634578018881287
"""
in our actual code, you can see how we did it as well here:
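The numbers printed above can be checked by hand. The sketch below is only an accounting exercise, not pyreft code. It assumes a hidden size of 4096 (the thread does not state it) and assumes the LoReFT intervention holds one rotation matrix of shape hidden size × rank plus one linear projection with a bias. Under those assumptions the count matches the 32,772 reported in the output.

```python
# Assumption: hidden size of the base model (not stated in the thread).
embed_dim = 4096
# Rank taken from the config in the comment above.
low_rank_dimension = 4

# Assumed parameter layout of one LoReFT intervention:
# a rotation matrix (embed_dim x rank) and a linear projection with bias.
rotation_params = embed_dim * low_rank_dimension
projection_params = embed_dim * low_rank_dimension + low_rank_dimension
intervention_params = rotation_params + projection_params

# Total model parameter count copied from the printed output.
model_params = 6_738_415_616

# Percentage of trainable parameters relative to the full model.
trainable_percent = 100 * intervention_params / model_params

print(f"trainable intervention params: {intervention_params:,}")
print(f"trainable%: {trainable_percent}")
```

With a single intervention at layer 15, the trainable share is roughly 0.0005% of the model, which is why the printed percentage is so small.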
Thank you for the reply, @frankaging. My question is: how can I load a REFT module that was added and trained as described in your example?
@hSterz got it! so, if the model is natively supported by pyvene (supported models can be found here), you can load the model as
reft_model = pyreft.ReftModel.load("<your_directory>", model)
if the model is not supported by pyvene, you have to either (1) add the support in pyvene and reinstall pyvene, or (2) reinitialize the pyreft model and load the weights manually yourself. All the interventions can be accessed as
let me know if these help.
Hi @hSterz, how did you finally go about this? I am facing the same error.
I was training and saving REFT modules for the RoBERTa model, but loading them does not seem to be possible with the current implementation. I get the following error:
It looks like type_to_dimension_mapping does not include RoBERTa. Or does RoBERTa fall under one of the existing models?
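To illustrate the kind of failure described here, below is a toy sketch of a lookup table keyed by model class. It is not pyvene's code. The classes, the table contents, and the helper names `lookup_dimension` and `register_model_type` are all hypothetical, and the sketch assumes type_to_dimension_mapping behaves like a plain dict keyed by the model's type, which the thread does not confirm.

```python
# Toy stand-ins for model classes; these are NOT real transformers classes.
class LlamaLikeModel:
    pass


class RobertaLikeModel:
    pass


# Hypothetical table: model class -> component name -> config field(s)
# that give the dimension of that component.
type_to_dimension_mapping = {
    LlamaLikeModel: {"block_output": ("hidden_size",)},
}


def lookup_dimension(model, component):
    """Return the dimension spec for a component, or raise if the type is unknown."""
    per_type = type_to_dimension_mapping.get(type(model))
    if per_type is None:
        raise KeyError(f"{type(model).__name__} has no dimension mapping registered")
    return per_type[component]


def register_model_type(model_cls, dimension_mapping):
    """Add a new model class to the table (hypothetical helper)."""
    type_to_dimension_mapping[model_cls] = dict(dimension_mapping)


roberta = RobertaLikeModel()
try:
    lookup_dimension(roberta, "block_output")
    lookup_failed = False
except KeyError:
    # The unregistered class is rejected before any weights are loaded.
    lookup_failed = True

# Registering the class makes the same lookup succeed.
register_model_type(RobertaLikeModel, {"block_output": ("hidden_size",)})
resolved = lookup_dimension(roberta, "block_output")
```

If the real table works this way, option (1) from the reply above would amount to adding a RoBERTa entry to the library's tables before loading.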