Possible JIT compilation bug with JAX #20165
In the meantime, I also tried the "torch" backend and everything works fine, as it does with "tensorflow", so the issue mentioned above seems specific to JAX with JIT compilation.
My bad, I think the issue is solved if I change the call function of my MaskedGlobalAveragePooling1D to:

    def call(self, inputs):
        mask = self.masking.compute_mask(inputs)
        return self.pooling(inputs, mask=mask)

Still, I would argue that the original issue is tricky because it happens silently: it is unexpected that the output can differ across backends, and I don't know whether there is an easy way to warn users about it.

By the way, it would be nice if the Masking & Padding guide returned to the documentation; it seems to have disappeared from the Developer guides. Thanks!
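To illustrate why a dropped mask fails silently, here is a small NumPy sketch of what the fixed call() computes. It assumes the input has shape (batch, timesteps, features) and loosely mirrors keras.layers.Masking (a timestep is masked only when all of its features equal the mask value) followed by masked global average pooling. The helper names compute_mask and global_average_pool are hypothetical stand-ins, not the Keras implementation.

```python
import numpy as np

MASK_VALUE = -99.0  # sentinel value used in the issue

def compute_mask(x, mask_value=MASK_VALUE):
    # Hypothetical stand-in for Masking.compute_mask: a timestep is kept
    # unless *all* of its features equal mask_value. Shape: (batch, timesteps).
    return np.any(x != mask_value, axis=-1)

def global_average_pool(x, mask=None):
    # Average over the timestep axis. Without a mask, padding is averaged in.
    if mask is None:
        return x.mean(axis=1)
    w = mask[..., None].astype(x.dtype)  # (batch, timesteps, 1)
    return (x * w).sum(axis=1) / w.sum(axis=1)

# One sample, three timesteps, one feature; the last timestep is padding.
x = np.array([[[1.0], [3.0], [MASK_VALUE]]])

masked = global_average_pool(x, mask=compute_mask(x))  # (1 + 3) / 2
unmasked = global_average_pool(x)                      # (1 + 3 - 99) / 3
print(masked, unmasked)
```

Both calls return an array of the same shape and raise no error, which is why a mask that is silently lost on one backend only shows up as different numbers.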
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
Below is a minimal example that checks whether JIT-compiled model outputs match non-JIT ones.
For this example, assume that we want a model that averages x above "line by line", ignoring the -99 values, which we mask.
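For a check like this, it can help to compare each backend's predictions against a framework-independent ground truth rather than only against each other. The following is a minimal sketch, assuming x has shape (batch, timesteps, features) and that a timestep counts as padding when all of its features equal -99 (the rule keras.layers.Masking uses). The names expected_line_averages and check_against_reference are hypothetical, and the toy x here is not the x from the issue.

```python
import numpy as np

def expected_line_averages(x, mask_value=-99.0):
    """Hypothetical reference: average each sample ("line") over its
    non-padding timesteps, one sample at a time. Assumes every sample
    has at least one non-padding timestep."""
    out = []
    for sample in x:  # sample: (timesteps, features)
        kept = [step for step in sample if not np.all(step == mask_value)]
        out.append(np.mean(kept, axis=0))  # (features,)
    return np.stack(out)  # (batch, features)

def check_against_reference(predictions, x, atol=1e-5):
    # Raises AssertionError with an element-wise report on mismatch.
    np.testing.assert_allclose(predictions, expected_line_averages(x), atol=atol)

# Toy input: two samples, three timesteps, one feature, padded with -99.
x = np.array([
    [[1.0], [3.0], [-99.0]],
    [[2.0], [-99.0], [-99.0]],
])
print(expected_line_averages(x))
```

The same check could then be run once on the JIT-compiled model's predictions and once on the non-JIT ones, so a backend that drops the mask fails loudly instead of silently.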
get_model() function below, the test is successful with both the "tensorflow" and "jax" backends:

Note: I know that using keras.layers.Masking inside a custom layer is not common (I actually need it for a more advanced use case), but I see no reason why it shouldn't work consistently across all backends.
I would appreciate any help fixing this bug, thank you!