diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py index 0d44ac3ce449..7a1cceefe454 100644 --- a/src/transformers/models/bart/modeling_flax_bart.py +++ b/src/transformers/models/bart/modeling_flax_bart.py @@ -19,6 +19,8 @@ from functools import partial from typing import Callable, Optional, Tuple +import numpy as np + import flax.linen as nn import jax import jax.numpy as jnp @@ -212,15 +214,15 @@ """ -def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray: """ Shift input ids one token to the right. """ - shifted_input_ids = jnp.roll(input_ids, 1, axis=-1) - shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id) - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + shifted_input_ids = np.zeros_like(input_ids) + shifted_input_ids[:, 1:] = input_ids[:, :-1] + shifted_input_ids[:, 0] = decoder_start_token_id + shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) return shifted_input_ids diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py index 8435b2dc6892..ccbdcab8aab5 100644 --- a/src/transformers/models/marian/modeling_flax_marian.py +++ b/src/transformers/models/marian/modeling_flax_marian.py @@ -221,11 +221,11 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_ """ Shift input ids one token to the right. 
""" - shifted_input_ids = jnp.roll(input_ids, 1, axis=-1) - shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id) - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + shifted_input_ids = np.zeros_like(input_ids) + shifted_input_ids[:, 1:] = input_ids[:, :-1] + shifted_input_ids[:, 0] = decoder_start_token_id + shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) return shifted_input_ids diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py index fd8e64ca0a76..5673e06afe98 100644 --- a/src/transformers/models/mbart/modeling_flax_mbart.py +++ b/src/transformers/models/mbart/modeling_flax_mbart.py @@ -19,6 +19,8 @@ from functools import partial from typing import Callable, Optional, Tuple +import numpy as np + import flax.linen as nn import jax import jax.numpy as jnp @@ -217,20 +219,19 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int) -> jnp.ndarray Shift input ids one token to the right, and wrap the last non pad token (the token) Note that MBart does not have a single `decoder_start_token_id` in contrast to other Bart-like models. """ - prev_output_tokens = jnp.array(input_ids).clone() + prev_output_tokens = np.array(input_ids).copy() assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 
# replace possible -100 values in labels by `pad_token_id` - prev_output_tokens = jnp.where(prev_output_tokens == -100, pad_token_id, input_ids) - index_of_eos = (jnp.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1) - decoder_start_tokens = jnp.array( - [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)] + prev_output_tokens = np.where(prev_output_tokens == -100, pad_token_id, input_ids) + index_of_eos = (np.where(prev_output_tokens != pad_token_id, 1, 0).sum(axis=-1) - 1).reshape(-1, 1) + decoder_start_tokens = np.array( + [prev_output_tokens[i, eos_idx] for i, eos_idx in enumerate(index_of_eos)], dtype=np.int32 ).squeeze() - # for loop basically does jax-compatible version of prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone() - for i in range(prev_output_tokens.shape[1], 0, -1): - prev_output_tokens = jax.ops.index_update(prev_output_tokens, (..., i), prev_output_tokens[:, i - 1]) - prev_output_tokens = jax.ops.index_update(prev_output_tokens, (..., 0), decoder_start_tokens) + + prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].copy() + prev_output_tokens[:, 0] = decoder_start_tokens return prev_output_tokens diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index 7e52dca9f733..dfaf0976e343 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -47,15 +47,16 @@ _TOKENIZER_FOR_DOC = "T5Tokenizer" -def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right +def shift_tokens_right(input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray: """ Shift input ids one token to the right. 
""" - shifted_input_ids = jnp.roll(input_ids, 1, axis=-1) - shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id) - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) + shifted_input_ids = np.zeros_like(input_ids) + shifted_input_ids[:, 1:] = input_ids[:, :-1] + shifted_input_ids[:, 0] = decoder_start_token_id + shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) return shifted_input_ids