
layers.GRU returns wrong shaped output with GPU #20173

Open
Jonii opened this issue Aug 27, 2024 · 3 comments
Labels
stat:awaiting keras-eng Awaiting response from Keras engineer type:Bug

Comments

@Jonii

Jonii commented Aug 27, 2024

I opened this on the TensorFlow repo and was told to move it here: tensorflow/tensorflow#74475

In short: GRU, at least on Google Colab (Keras 3.4.1), returns wrongly shaped outputs when a GPU is available.

Minimal reproduction:

```python
import tensorflow as tf

class TestModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.gru = tf.keras.layers.GRU(10, return_sequences=True, return_state=True)

    def call(self, inputs):
        return self.gru(inputs)

# Create and test the model
model = TestModel()
test_input = tf.random.uniform((2, 3, 5))  # Batch size = 2, sequence length = 3, feature size = 5
output = model(test_input)
print("Output types and shapes:", [(type(o), o.shape) for o in output])
```

This prints:

With GPU:

```
Output types and shapes: [(<class 'tensorflow.python.framework.ops.EagerTensor'>, TensorShape([2, 3, 10])), (<class 'tensorflow.python.framework.ops.EagerTensor'>, TensorShape([10])), (<class 'tensorflow.python.framework.ops.EagerTensor'>, TensorShape([10]))]
```

With CPU:

```
Output types and shapes: [(<class 'tensorflow.python.framework.ops.EagerTensor'>, TensorShape([2, 3, 10])), (<class 'tensorflow.python.framework.ops.EagerTensor'>, TensorShape([2, 10]))]
```

CPU behavior seems correct.

Edited to add: I can't test GPU behavior outside of Google Colab, so this might be a bug that has already been fixed in the latest version, or the result of a Colab-specific misconfiguration.
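As a reference point, here is a minimal sketch of the shape contract the CPU run above satisfies: with `return_sequences=True` and `return_state=True`, `GRU(units)` should return exactly two tensors, the full sequence `(batch, timesteps, units)` and one final state `(batch, units)`. Both helper names are hypothetical; they only encode that contract in plain Python (no Keras call) so that observed shapes can be diffed against it.

```python
def expected_gru_output_shapes(batch, timesteps, units,
                               return_sequences=False, return_state=False):
    """Shapes a Keras GRU(units) layer should return for a (batch, timesteps, features) input.

    Hypothetical helper: encodes the documented contract, does not call Keras.
    Always returns a list, even when Keras itself would return a single tensor.
    """
    # Main output: every step's hidden state, or only the last one.
    main = (batch, timesteps, units) if return_sequences else (batch, units)
    if not return_state:
        return [main]
    # GRU carries a single state tensor h_T of shape (batch, units).
    return [main, (batch, units)]


def check_shapes(observed, expected):
    """Return human-readable mismatches (empty list if everything matches)."""
    problems = []
    if len(observed) != len(expected):
        problems.append(f"expected {len(expected)} outputs, got {len(observed)}")
    for i, (obs, exp) in enumerate(zip(observed, expected)):
        if tuple(obs) != tuple(exp):
            problems.append(f"output {i}: expected {exp}, got {tuple(obs)}")
    return problems


expected = expected_gru_output_shapes(2, 3, 10, return_sequences=True, return_state=True)
# Shapes copied from the GPU run reported above.
gpu_observed = [(2, 3, 10), (10,), (10,)]
print(check_shapes(gpu_observed, expected))
# ['expected 2 outputs, got 3', 'output 1: expected (2, 10), got (10,)']
```

Against real tensors, `observed` would be `[tuple(o.shape) for o in output]` from the reproduction script.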

@sachinprasadhs
Collaborator

sachinprasadhs commented Aug 28, 2024

I was able to reproduce the reported behavior; attaching the Gist here.

With the Torch backend, it produces the expected output:

```
Output types and shapes: [(<class 'torch.Tensor'>, torch.Size([2, 3, 10])), (<class 'torch.Tensor'>, torch.Size([2, 10]))]
```

@sachinprasadhs sachinprasadhs added the keras-team-review-pending Pending review by a Keras team member. label Aug 28, 2024
@mattdangerw mattdangerw removed the keras-team-review-pending Pending review by a Keras team member. label Aug 29, 2024
@mattdangerw
Member

Probably an issue with the cuDNN-specific implementation on the TF backend, which is pretty dense. I will take a look.

@sachinprasadhs sachinprasadhs added the stat:awaiting keras-eng Awaiting response from Keras engineer label Aug 29, 2024
@AdityaMayukhSom

A similar issue happens when running Keras with the TensorFlow backend on desktop. The hidden states of the individual elements in a batch are returned as a tuple inside the GRU output, rather than as a single tensor whose first dimension equals the batch size.
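If the per-example split described above is the only problem, the state can presumably be reassembled after the fact. The sketch below is a hypothetical workaround, not a Keras API: `repair_gru_outputs` converts the outputs to NumPy, stacks one `(units,)` row per batch element back into a `(batch, units)` array, and passes well-formed outputs through unchanged. It assumes the split rows arrive in batch order, which this thread does not confirm.

```python
import numpy as np


def repair_gru_outputs(outputs):
    """Hypothetical workaround: return (sequences, state) with state shaped (batch, units).

    Handles both the expected [sequences, state] layout and the layout reported
    in this issue, [sequences, row_0, ..., row_{batch-1}] with each row shaped (units,).
    """
    sequences, *state_parts = [np.asarray(o) for o in outputs]
    batch = sequences.shape[0]

    # Expected layout: a single (batch, units) state tensor.
    if len(state_parts) == 1 and state_parts[0].ndim == 2:
        return sequences, state_parts[0]

    # Reported layout: one (units,) row per batch element, assumed in batch order.
    if len(state_parts) == batch and all(p.ndim == 1 for p in state_parts):
        return sequences, np.stack(state_parts, axis=0)

    shapes = [o.shape for o in (sequences, *state_parts)]
    raise ValueError(f"unexpected GRU output shapes: {shapes}")


# Synthetic demo mimicking the GPU output from the report (batch=2, steps=3, units=10).
rng = np.random.default_rng(0)
seq = rng.random((2, 3, 10)).astype("float32")
split_outputs = [seq, seq[0, -1], seq[1, -1]]

sequences, state = repair_gru_outputs(split_outputs)
print(state.shape)                              # (2, 10)
# Without masking, the final GRU state equals the last step of the sequence output.
print(np.allclose(state, sequences[:, -1, :]))  # True
```

Assuming no masking, the final GRU state should equal the last timestep of the sequence output, so `np.allclose(state, sequences[:, -1, :])` is a cheap check that the reassembled values, not just the shapes, are right. The demo arrays are synthetic and satisfy it by construction.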


7 participants
@mattdangerw @Jonii @sachinprasadhs @AdityaMayukhSom and others