diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index d44bc80363d0..2d56d6d1c9f7 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -372,6 +372,10 @@ def _check_shapes(shape_1, shape2): def shift_tokens_right(input_ids, pad_token_id): """Shift input ids one token to the right, and wrap the last non pad token (usually ).""" + + # replace possible -100 values in labels by `pad_token_id` + input_ids.masked_fill_(input_ids == -100, pad_token_id) + prev_output_tokens = input_ids.clone() index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()