Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def __call__(
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_last_hidden_state=encoder_hidden_states,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
Expand Down Expand Up @@ -363,8 +363,8 @@ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDic
encoder_input_shape, decoder_input_shape = input_shape

# init input DeviceArrays
inputs = jnp.zeros(encoder_input_shape, dtype="i4")
attention_mask = jnp.ones_like(inputs)
inputs = jnp.zeros(encoder_input_shape, dtype="f4")
attention_mask = jnp.ones_like(inputs, dtype="i4")
decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)

Expand Down Expand Up @@ -472,7 +472,7 @@ def encode(
return_dict = return_dict if return_dict is not None else self.config.return_dict

if attention_mask is None:
attention_mask = jnp.ones_like(inputs)
attention_mask = jnp.ones_like(inputs, dtype="i4")

# Handle any PRNG if needed
rngs = {}
Expand All @@ -485,7 +485,7 @@ def _encoder_forward(module, inputs, attention_mask, **kwargs):

outputs = self.module.apply(
{"params": params or self.params},
inputs=jnp.array(inputs, dtype="i4"),
inputs=jnp.array(inputs, dtype="f4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
Expand Down Expand Up @@ -680,7 +680,7 @@ def __call__(

# prepare encoder inputs
if attention_mask is None:
attention_mask = jnp.ones_like(inputs)
attention_mask = jnp.ones_like(inputs, dtype="i4")

# prepare decoder inputs
if decoder_input_ids is None:
Expand All @@ -700,7 +700,7 @@ def __call__(

return self.module.apply(
{"params": params or self.params},
inputs=jnp.array(inputs, dtype="i4"),
inputs=jnp.array(inputs, dtype="f4"),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sanchit-gandhi - this is the reason for why the script fails. We should not convert the input to int here, we're dealing with float numbers

attention_mask=jnp.array(attention_mask, dtype="i4"),
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
Expand Down