diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py
index 5048f0c3eef8..25d33b1b6ca9 100644
--- a/src/transformers/models/paligemma/processing_paligemma.py
+++ b/src/transformers/models/paligemma/processing_paligemma.py
@@ -18,6 +18,8 @@
 
 from typing import List, Optional, Union
 
+import numpy as np
+
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
 from ...processing_utils import (
@@ -310,7 +312,8 @@ def __call__(
         return_data = {**inputs, "pixel_values": pixel_values}
 
         if return_token_type_ids:
-            labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
+            labels = np.array(inputs["input_ids"])
+            labels[np.array(inputs["token_type_ids"]) == 0] = -100
             return_data.update({"labels": labels})
 
         return BatchFeature(data=return_data, tensor_type=return_tensors)
diff --git a/tests/models/paligemma/test_processor_paligemma.py b/tests/models/paligemma/test_processor_paligemma.py
index 8ccae4588750..56e74928925d 100644
--- a/tests/models/paligemma/test_processor_paligemma.py
+++ b/tests/models/paligemma/test_processor_paligemma.py
@@ -62,6 +62,20 @@ def test_image_seq_length(self):
         )
         self.assertEqual(len(inputs["input_ids"][0]), 112)
 
+    @require_torch
+    def test_call_with_suffix(self):
+        input_str = "lower newer"
+        suffix = "upper older longer string"
+        image_input = self.prepare_image_inputs()
+        processor = self.get_processor()
+        inputs = processor(text=input_str, images=image_input, suffix=suffix)
+        self.assertTrue("labels" in inputs)
+        self.assertEqual(len(inputs["labels"][0]), len(inputs["input_ids"][0]))
+
+        inputs = processor(text=input_str, images=image_input, suffix=suffix, return_tensors="pt")
+        self.assertTrue("labels" in inputs)
+        self.assertEqual(len(inputs["labels"][0]), len(inputs["input_ids"][0]))
+
     def test_text_with_image_tokens(self):
         image_processor = self.get_component("image_processor")
         tokenizer = self.get_component("tokenizer")