Skip to content

Conversation

@james77777778
Copy link
Collaborator

I encountered this issue while trying to create a float8 training/inference example for keras.io.

In Keras3 (I haven't verified this in Keras2), training argument isn't propagated when using subclasses like keras_nlp.layers.TransformerDecoder, unless we explicitly expose training=None in the signature of call.

I've added the tests to confirm that this issue has been resolved and they can only pass with this PR.

@james77777778
Copy link
Collaborator Author

Kindly ping @mattdangerw

This issue needs to be fixed for the float8 example on keras.io
keras-team/keras-io#1858

Copy link
Member

@mattdangerw mattdangerw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@mattdangerw mattdangerw merged commit b043a4f into keras-team:master May 17, 2024
@james77777778 james77777778 deleted the fix-training-args branch June 21, 2024 05:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants