
Possible JIT compilation bug with JAX #20165

Open
neo-alex opened this issue Aug 26, 2024 · 5 comments

@neo-alex
Contributor

I use the minimal code below to check that JIT-compiled model outputs match the non-JIT ones:

import os
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow"

import numpy as np
import keras

x = keras.ops.convert_to_tensor([
    [[1], [2], [3]],
    [[1], [2], [-99]],
    [[1], [-99], [-99]],
])

model = get_model()  # get_model() is one of the two variants defined below
model_output = model(x)  # this is NOT JIT-compiled

model.compile(jit_compile=True)
jit_model_output = model.predict_on_batch(x)  # this is JIT-compiled

assert np.allclose(model_output, jit_model_output)

For this example, assume that we want a model that averages each row of x above (i.e., the timesteps of each sample), ignoring the -99 values, which we mask.
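For reference, the intended averages can be computed directly in NumPy. This minimal sketch is not Keras code: it uses numpy.ma to hide every -99 entry and then averages each sample over the timestep axis. Because each timestep here has a single feature, hiding entries is equivalent to masking whole timesteps, which is what keras.layers.Masking does.

```python
import numpy as np

# Same data as the snippet above: (batch=3, timesteps=3, features=1)
x = np.array([
    [[1], [2], [3]],
    [[1], [2], [-99]],
    [[1], [-99], [-99]],
], dtype=np.float32)

# Hide every -99 entry, then average each sample over the timestep axis.
# With one feature per timestep, masking entries == masking timesteps.
expected = np.ma.masked_equal(x, -99).mean(axis=1)

print(expected.tolist())  # [[2.0], [1.5], [1.0]]
```

This is the output both the non-JIT and the JIT-compiled model should produce.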

  • When I test it with the get_model() function below, the test passes with both the "tensorflow" and "jax" backends:
def get_model():
    return keras.Sequential([
        keras.layers.Masking(-99),
        keras.layers.GlobalAveragePooling1D()
    ])
  • However, if I define the same sequence of operations inside a custom layer, the test passes with "tensorflow" but fails with "jax": the JIT-compiled output is wrong, as if each masked value had been replaced by 0 instead of being ignored:
class MaskedGlobalAveragePooling1D(keras.layers.Layer):
    def __init__(self, mask_value, **kwargs):
        super().__init__(**kwargs)
        self.masking = keras.layers.Masking(mask_value)
        self.pooling = keras.layers.GlobalAveragePooling1D()

    def call(self, inputs):
        x = self.masking(inputs)
        return self.pooling(x)


def get_model():
    return keras.Sequential([
        MaskedGlobalAveragePooling1D(mask_value=-99)
    ])
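The reported symptom can be reproduced in plain NumPy. The sketch below assumes, based on how keras.layers.Masking behaves, that masked timesteps leave the masking step as zeros; if the pooling step then never receives the mask, it takes an ordinary mean over all timesteps. It illustrates the wrong values described above and is not a trace of the actual JAX/XLA code path.

```python
import numpy as np

MASK_VALUE = -99

x = np.array([
    [[1], [2], [3]],
    [[1], [2], [-99]],
    [[1], [-99], [-99]],
], dtype=np.float32)

# Masking step: a timestep is kept if any feature differs from MASK_VALUE,
# and masked timesteps are multiplied by 0.
keep = np.any(x != MASK_VALUE, axis=-1, keepdims=True)
zeroed = x * keep.astype(x.dtype)

# Pooling step with the mask lost: zeros are averaged in with real values.
mask_lost = zeroed.mean(axis=1)

# Intended masked averages for this input.
expected = np.array([[2.0], [1.5], [1.0]])

print(mask_lost.ravel())                 # roughly 2.0, 1.0, 0.333
print(np.allclose(mask_lost, expected))  # False, like the failing assert
```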

Note: I know that using keras.layers.Masking inside a custom layer is not common (I actually need it for a more advanced use case), but I see no reason why it shouldn't work consistently across all backends.

I would appreciate any help fixing this bug, thank you!

@neo-alex
Contributor Author

In the meantime, I also tried the "torch" backend and everything works fine, as with "tensorflow", so the issue above seems specific to JAX with JIT compilation.

@sachinprasadhs
Collaborator

I was able to reproduce the reported behavior here

@sachinprasadhs added the keras-team-review-pending, type:Bug, and backend:jax labels Aug 28, 2024
@mattdangerw removed the keras-team-review-pending label Aug 29, 2024
@neo-alex
Contributor Author

My bad: I think the issue is solved if I change the call method of my MaskedGlobalAveragePooling1D to:

    def call(self, inputs):
        mask = self.masking.compute_mask(inputs)
        return self.pooling(inputs, mask=mask)
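This fix passes the mask to the pooling layer as an explicit argument instead of relying on it being propagated implicitly from self.masking. The NumPy sketch below mirrors that two-step structure with hypothetical helpers compute_mask and masked_global_average_pool_1d; their formulas are an assumption about what Masking.compute_mask and a mask-aware GlobalAveragePooling1D compute, not the Keras source.

```python
import numpy as np


def compute_mask(inputs, mask_value):
    """Hypothetical stand-in for Masking.compute_mask: (batch, timesteps) bool."""
    return np.any(inputs != mask_value, axis=-1)


def masked_global_average_pool_1d(inputs, mask=None):
    """Hypothetical stand-in for GlobalAveragePooling1D()(inputs, mask=mask)."""
    if mask is None:
        return inputs.mean(axis=1)  # plain mean over all timesteps
    weights = mask[..., None].astype(inputs.dtype)  # (batch, timesteps, 1)
    return (inputs * weights).sum(axis=1) / weights.sum(axis=1)


x = np.array([
    [[1], [2], [3]],
    [[1], [2], [-99]],
    [[1], [-99], [-99]],
], dtype=np.float32)

# Same shape as the fixed call(): compute the mask, then pass it explicitly.
mask = compute_mask(x, mask_value=-99)
pooled = masked_global_average_pool_1d(x, mask=mask)
print(pooled.tolist())  # [[2.0], [1.5], [1.0]]
```

Because the mask is an ordinary function argument here, nothing depends on metadata surviving between the two steps.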

Still, I would argue that the original issue is tricky and can happen silently: it is at least unexpected that the output differs across backends, and I don't know whether there is an easy way to warn users about it. By the way, it would be nice if the Masking & Padding guide found its way back into the documentation (it seems to have disappeared from the Developer guides). Thanks!

@sachinprasadhs added the stat:awaiting keras-eng label Aug 30, 2024
@mehtamansi29
Collaborator

mehtamansi29 commented Oct 25, 2024

Hi @neo-alex,

I tried to reproduce the issue with both the keras.layers.Masking get_model() function and the MaskedGlobalAveragePooling1D subclass in the latest Keras 3.6.0, and both cases work fine with the jax and tensorflow backends.

I attached a gist for your reference here.

@mehtamansi29 added the stat:awaiting response from contributor label and removed the stat:awaiting keras-eng label Oct 25, 2024

github-actions bot commented Nov 9, 2024

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions bot added the stale label Nov 9, 2024

6 participants
@mattdangerw @neo-alex @sachinprasadhs @mehtamansi29 and others