Convert T5 to Keras 3 #1274
Conversation
/gcbrun
Looks good to me! Just minor comments.
    from keras_nlp.backend import ops
    ...
    def shape_list(tensor):
we can probably just remove this whole function and use ops.shape wherever it was used.
Fixed, sorry, that was left over while debugging!
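(For illustration only: a minimal sketch of what the suggested replacement looks like, using the ops import shown in the diff above. The split_heads function and its arguments are hypothetical, not the actual T5 code.)

    from keras_nlp.backend import ops

    def split_heads(hidden_states, num_heads):
        # ops.shape replaces the removed shape_list helper and works on
        # any Keras 3 backend (TensorFlow, JAX, PyTorch).
        batch_size, seq_length, hidden_dim = ops.shape(hidden_states)
        return ops.reshape(
            hidden_states,
            (batch_size, seq_length, num_heads, hidden_dim // num_heads),
        )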
    -    mask_positions,
    -    mask_ids,
    -) = tf_text.mask_language_model(
    +(token_ids, mask_positions, mask_ids,) = tf_text.mask_language_model(
This looks like you are using an older version of black, maybe? If you remove the trailing comma, it would reformat to one line.
Fixed! Yeah, just needed to update black.
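(Side note on the black behavior mentioned above, hedged: this matches the "magic trailing comma" rule in recent black versions, where a trailing comma inside brackets tells black to keep the construct exploded, and removing it lets the call collapse. Roughly:)

    # Without a trailing comma, newer black collapses the call onto one
    # line (assuming it fits within the configured line length):
    (token_ids, mask_positions, mask_ids) = tf_text.mask_language_model(...)

    # With a trailing comma after the last element, newer black keeps the
    # assignment exploded across multiple lines:
    (
        token_ids,
        mask_positions,
        mask_ids,
    ) = tf_text.mask_language_model(...)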
/gcbrun
lgtm!
This failure here is XLM-RoBERTa causing an occasional OOM. Maybe we should just comment that test out for now; it's probably the big embedding eating all available RAM. Anyway, unrelated.
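(A hedged sketch of what temporarily disabling the flaky test could look like, assuming pytest is the test runner; the test name is hypothetical:)

    import pytest

    @pytest.mark.skip(
        reason="Occasional OOM on CI, likely the large XLM-RoBERTa embedding."
    )
    def test_xlm_roberta_backbone():
        ...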
This PR replaces all TF ops with Keras ops, making T5 work multi-backend.
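As a rough illustration of the kind of change involved (not the actual T5 source), a TensorFlow-only computation can be rewritten against the backend-agnostic ops module so it runs on TensorFlow, JAX, or PyTorch:

    from keras_nlp.backend import ops

    def attention_scores(query, key):
        # Before: tf.einsum / tf.shape / tf.math.sqrt (TensorFlow-only).
        # After: the equivalent keras ops, which dispatch to whichever
        # backend Keras 3 is configured to use.
        scores = ops.einsum("bqd,bkd->bqk", query, key)
        dim = ops.cast(ops.shape(key)[-1], scores.dtype)
        return scores / ops.sqrt(dim)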