Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a warning for mismatched inputs structure in Functional #20170

Merged
merged 1 commit into from
Aug 27, 2024

Conversation

james77777778
Copy link
Contributor

There are several issues related to mismatched inputs structure:

As a result, it would be beneficial to add a warning for users.
Ultimately, we might want to raise an error when a mismatch occurs. Otherwise, it could lead to subtle issues if the inputs have the same shape and dtype, as the computation could be incorrect even though the code runs.

@codecov-commenter
Copy link

codecov-commenter commented Aug 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 79.34%. Comparing base (3cc4d44) to head (dddfbfe).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #20170   +/-   ##
=======================================
  Coverage   79.34%   79.34%           
=======================================
  Files         501      501           
  Lines       47319    47325    +6     
  Branches     8692     8694    +2     
=======================================
+ Hits        37544    37550    +6     
  Misses       8017     8017           
  Partials     1758     1758           
Flag Coverage Δ
keras 79.19% <100.00%> (+<0.01%) ⬆️
keras-jax 62.47% <100.00%> (+<0.01%) ⬆️
keras-numpy 57.65% <100.00%> (+0.07%) ⬆️
keras-tensorflow 63.85% <100.00%> (+<0.01%) ⬆️
keras-torch 62.50% <100.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the improvement -- LGTM

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Aug 27, 2024
@fchollet fchollet merged commit cb233fa into keras-team:master Aug 27, 2024
7 checks passed
@google-ml-butler google-ml-butler bot removed ready to pull Ready to be merged into the codebase kokoro:force-run labels Aug 27, 2024
@james77777778 james77777778 deleted the add-warning branch August 28, 2024 01:28
@ThorvaldAagaard
Copy link

ThorvaldAagaard commented Oct 30, 2024

After upgrading Keras to latest version I am getting the warning, but can't figure out, what is wrong with the input

This is my code
image

and this is the output:

[<KerasTensor shape=(None, 42), dtype=float16, sparse=False, name=X_input>, <KerasTensor shape=(None, 15), dtype=float16, sparse=False, name=B_input>]                                                                                                                                            
x dtype: float16 shape: (1, 42)                                                                                                                                                                                                                                                                   
b dtype: float16 shape: (1, 15)                                                                                                                                                                                                                                                                   
UserWarning: The structure of `inputs` doesn't match the expected structure: ['X_input', 'B_input']. Received: the structure of inputs=('*', '*')                                             

Hard to see, what is wrong

I have tried with: result = model.predict({'X_input': x, 'B_input': b},verbose=0)
but got the same warning

Code is working fine

@LuigiCerone
Copy link

@ThorvaldAagaard Same here. Did you figure it out?

@ThorvaldAagaard
Copy link

ThorvaldAagaard commented Nov 6, 2024

I did not find a fix, except downgrading Tensorflow to 2.17
I assume it is a bug in the implemenation, but the fix is merged, so perhaps noone else but us are reading this.

There should be a way to disable these warnings.

The change is this

    def _maybe_warn_inputs_struct_mismatch(self, inputs):
        try:
            tree.assert_same_structure(
                inputs, self._inputs_struct, check_types=False
            )
        except:
            model_inputs_struct = tree.map_structure(
                lambda x: x.name, self._inputs_struct
            )
            inputs_struct = tree.map_structure(lambda x: "*", inputs)
            warnings.warn(
                "The structure of `inputs` doesn't match the expected "
                f"structure: {model_inputs_struct}. "
                f"Received: the structure of inputs={inputs_struct}"
            )

so probably
assert_same_structure
is the problem

@ThorvaldAagaard
Copy link

I found this, that can suppress all the warnings from Python

import warnings
warnings.filterwarnings("ignore")

It is working

@LuigiCerone
Copy link

LuigiCerone commented Nov 6, 2024

Simple script to reproduce the error:

keras==3.6.0
tensorflow==2.17.0
import numpy as np
from tensorflow.keras import layers, models

# Generate model
input_tensor_1 = layers.Input(shape=(10,), name='input_1')
input_tensor_2 = layers.Input(shape=(5,), name='input_2')

combined = layers.concatenate([input_tensor_1, input_tensor_2])

output = layers.Dense(1, activation='sigmoid')(combined)

model = models.Model(inputs=[input_tensor_1, input_tensor_2], outputs=output)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Train
num_samples = 1000

X1 = np.random.rand(num_samples, 10)
X2 = np.random.rand(num_samples, 5)
y = np.random.randint(2, size=num_samples)

model.fit([X1, X2], y, epochs=5, batch_size=32, verbose=0)

# Prediction
X1_new = np.random.rand(5, 10)
X2_new = np.random.rand(5, 5)

model.predict([X1_new, X2_new], verbose=0)

which gives the mentioned error:

UserWarning: The structure of `inputs` doesn't match the expected structure: ['input_1', 'input_2']. Received: the structure of inputs=('*', '*')
  warnings.warn(

By debugging, it seems a simple list/tuple mismatch in my case, given by:

            tree.assert_same_structure(
                inputs, self._inputs_struct, check_types=False
            )

Fixed by using tuple as input for model:

model = models.Model(inputs=(input_tensor_1, input_tensor_2), outputs=output)

@james77777778
Copy link
Contributor Author

Hey @ThorvaldAagaard , @LuigiCerone
Thanks for the report. I have submitted a fix to suppress the warning for mismatched tuples and lists.

@ThorvaldAagaard
Copy link

Looking forward to the fix, I did not find a workaround for the predict method

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants