Failed to load GRU layer weights into GRUCell: Layer 'gru_cell' expected 3 variables, but received 0 variables during loading #20407

Open
victorVoice opened this issue Oct 25, 2024 · 4 comments

Comments

@victorVoice

Using TensorFlow 2.16.1 with Keras 3.5.0, loading pretrained GRU layer weights into GRUCell layers fails.
The two variants are defined as below:

For GRU layers:

```python
t_rnn_1 = keras.layers.GRU(units=64, return_sequences=True)(t_in_1)
t_rnn_2 = keras.layers.GRU(units=64, return_sequences=True)(t_rnn_1)
t_dense_c = keras.layers.Dense(80)(t_rnn_2)
t_dense_c = tf.keras.layers.ReLU(max_value=6.)(t_dense_c)
```

For GRU cells:

```python
t_rnn_1, cell_out1 = keras.layers.GRUCell(units=64)(t_in_1, states=cell_in1)
t_rnn_2, cell_out2 = keras.layers.GRUCell(units=64)(t_rnn_1, states=cell_in2)
t_dense_2 = keras.layers.Dense(80)(t_rnn_2)
t_dense_2 = tf.keras.layers.ReLU(max_value=6.)(t_dense_2)
```

When loading the weights, I got the following error message:

```
Traceback (most recent call last):
  File "/home/victoryu/project/se_tf/subband_model_streaming.py", line 319, in
    tf_model.create_tf_lite_model(weights_file=args.ckpt, target_name='./crn_cplx')
  File "/home/victoryu/project/se_tf/subband_model_streaming.py", line 128, in create_tf_lite_model
    self.model.load_weights(weights_file)
  File "/home/victoryu/miniconda3/envs/tf2.16/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/victoryu/miniconda3/envs/tf2.16/lib/python3.11/site-packages/keras/src/saving/saving_lib.py", line 593, in _raise_loading_failure
    raise ValueError(msg)
ValueError: A total of 2 objects could not be loaded. Example error message for object :

Layer 'gru_cell' expected 3 variables, but received 0 variables during loading. Expected: ['kernel', 'recurrent_kernel', 'bias']
```

It works fine with TensorFlow 2.13.0 + Keras 2.13.1.

Visualizing the weight.h5 file shows the following difference between the two variants:

image

image

I wonder whether this is the cause of the problem, and how to fix it.
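The difference in the screenshots can be reproduced without a viewer. This is a minimal sketch, assuming Keras 3 and `h5py` (a Keras dependency); the one-layer models and the helper `list_h5_datasets` are hypothetical stand-ins for the two variants. It saves each model with `save_weights` and lists every dataset path in the resulting `.weights.h5` file. If the file groups variables under a key derived from each layer's type (the error message refers to the missing entry as `'gru_cell'`), the two files will use different group paths even though they hold the same three arrays, which would explain why each `gru_cell` receives 0 variables.

```python
import os
import tempfile

import h5py
import keras


def list_h5_datasets(path):
    """Return the paths of all datasets (stored weight arrays) in an HDF5 file."""
    names = []

    def visitor(name, obj):
        if isinstance(obj, h5py.Dataset):
            names.append(name)

    with h5py.File(path, "r") as f:
        f.visititems(visitor)
    return names


# Hypothetical one-layer models mirroring the two variants in this issue.
seq_in = keras.Input(shape=(63, 80))
seq_model = keras.Model(seq_in, keras.layers.GRU(64, return_sequences=True)(seq_in))

step_in = keras.Input(shape=(80,))
state_in = keras.Input(shape=(64,))
step_out, _ = keras.layers.GRUCell(64)(step_in, states=state_in)
step_model = keras.Model([step_in, state_in], step_out)

with tempfile.TemporaryDirectory() as tmp:
    seq_path = os.path.join(tmp, "seq.weights.h5")    # Keras 3 requires this suffix
    step_path = os.path.join(tmp, "step.weights.h5")
    seq_model.save_weights(seq_path)
    step_model.save_weights(step_path)
    seq_names = list_h5_datasets(seq_path)
    step_names = list_h5_datasets(step_path)

print("GRU-layer model:", seq_names)
print("GRUCell model:  ", step_names)
```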

@mehtamansi29
Collaborator

Hi @victorVoice -

Thanks for reporting the issue. Could you tell me how you defined t_in_1 here, or share full sample code for both the GRU and GRUCell versions?

@victorVoice
Author

victorVoice commented Oct 25, 2024

@mehtamansi29 Sure,
For the GRU layers, t_in_1 is a tensor of shape [batch_size, time_steps, feature_dims]; in the actual model it is [32, 63, 80].
For the GRUCell layers, each call processes one time step, so t_in_1 is [batch_size, feature_dims], e.g. [32, 80].

Here is some sample code for the GRUCell version:

```python
import keras
import tensorflow as tf

# Streaming inputs: one frame of 5 x 16 features plus one state per GRUCell.
inp = keras.Input(batch_shape=(1, 5, 16))
cell_in1 = keras.Input(batch_shape=(1, 64))
cell_in2 = keras.Input(batch_shape=(1, 64))

t_in_1 = keras.layers.Reshape([5 * 16])(inp)
t_rnn_1, cell_out1 = keras.layers.GRUCell(units=64)(t_in_1, states=cell_in1)
t_rnn_2, cell_out2 = keras.layers.GRUCell(units=64)(t_rnn_1, states=cell_in2)
t_dense_2 = keras.layers.Dense(80)(t_rnn_2)
t_dense_2 = tf.keras.layers.ReLU(max_value=6.)(t_dense_2)
s2_out = keras.layers.Reshape([1, 5, 16])(t_dense_2)
```

Here is the code for the GRU-layer version:

```python
import keras
import tensorflow as tf

inp = keras.Input(batch_shape=(32, 63, 5, 16))

# Reshape must preserve the element count: 63 time steps x 80 features.
t_in_1 = keras.layers.Reshape([63, 5 * 16])(inp)
t_rnn_1 = keras.layers.GRU(units=64, return_sequences=True)(t_in_1)
t_rnn_2 = keras.layers.GRU(units=64, return_sequences=True)(t_rnn_1)
t_dense_2 = keras.layers.Dense(80)(t_rnn_2)
t_dense_2 = tf.keras.layers.ReLU(max_value=6.)(t_dense_2)
s2_out = keras.layers.Reshape([1, 63, 5, 16])(t_dense_2)
```

@mehtamansi29
Collaborator

Hi @victorVoice -

Thanks for the sample code. I replicated it with both the GRU layer and the GRUCell versions in the latest Keras (3.6.0), and it works fine for me.
I've attached a gist for reference.

@victorVoice
Author

@mehtamansi29 Thanks, I will try Keras 3.6.0 first. I appreciate the help.
