diff --git a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py index aff3953b8407..6e36703cf9d7 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py @@ -310,7 +310,7 @@ def __call__( decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_last_hidden_state=encoder_hidden_states, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, ) @@ -363,8 +363,8 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDic encoder_input_shape, decoder_input_shape = input_shape # init input DeviceArrays - inputs = jnp.zeros(encoder_input_shape, dtype="i4") - attention_mask = jnp.ones_like(inputs) + inputs = jnp.zeros(encoder_input_shape, dtype="f4") + attention_mask = jnp.ones_like(inputs, dtype="i4") decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4") decoder_attention_mask = jnp.ones_like(decoder_input_ids) @@ -472,7 +472,7 @@ def encode( return_dict = return_dict if return_dict is not None else self.config.return_dict if attention_mask is None: - attention_mask = jnp.ones_like(inputs) + attention_mask = jnp.ones_like(inputs, dtype="i4") # Handle any PRNG if needed rngs = {} @@ -485,7 +485,7 @@ def _encoder_forward(module, inputs, attention_mask, **kwargs): outputs = self.module.apply( {"params": params or self.params}, - inputs=jnp.array(inputs, dtype="i4"), + inputs=jnp.array(inputs, dtype="f4"), attention_mask=jnp.array(attention_mask, dtype="i4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -680,7 +680,7 @@ def __call__( # prepare encoder inputs if attention_mask is None: - attention_mask = jnp.ones_like(inputs) + attention_mask = jnp.ones_like(inputs, dtype="i4") # prepare decoder inputs if decoder_input_ids is None: @@ -700,7 +700,7 @@ def __call__( return self.module.apply( {"params": params or self.params}, - inputs=jnp.array(inputs, dtype="i4"), + inputs=jnp.array(inputs, dtype="f4"), attention_mask=jnp.array(attention_mask, dtype="i4"), decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),