diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 34617810cfc0..682cd9e171f7 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2261,17 +2261,31 @@ def test_forward_with_logits_to_keep(self): torch.testing.assert_close(all_logits[:, -1:, :], last_token_logits, rtol=1e-5, atol=1e-5) def test_generate_with_and_without_position_ids(self): + ran_any_model = False for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() model = model_class(config).to(torch_device).eval() model_forward_args = inspect.signature(model.forward).parameters + has_3d_rope_positions = any( + hasattr(module, "get_rope_index") + for module in ( + model, + getattr(model, "model", None), + getattr(model, "language_model", None), + getattr(model, "text_model", None), + ) + ) + if has_3d_rope_positions: + continue + if "position_ids" not in model_forward_args or "input_ids" not in inputs_dict: self.skipTest("This model doesn't use `position_ids`") if config.is_encoder_decoder: self.skipTest("This model doesn't prepare `position_ids` in generate") + ran_any_model = True input_ids = inputs_dict["input_ids"] seq_length = input_ids.shape[1] # ensure left padding @@ -2304,6 +2318,12 @@ def test_generate_with_and_without_position_ids(self): # and can continue adding new ids to the already passed position ids self.assertListEqual(out_wo_positions.tolist(), out_w_positions.tolist()) + if not ran_any_model: + self.skipTest( + "All model classes in this test use 3D RoPE positions (`get_rope_index`), for which 2D custom " + "`position_ids` may be accepted but are expected to produce invalid outputs." + ) + def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): input_batch_size = int(output.sequences.shape[0] / num_return_sequences) internal_batch_size = (