Skip to content
20 changes: 20 additions & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2261,17 +2261,31 @@ def test_forward_with_logits_to_keep(self):
torch.testing.assert_close(all_logits[:, -1:, :], last_token_logits, rtol=1e-5, atol=1e-5)

def test_generate_with_and_without_position_ids(self):
ran_any_model = False
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
model = model_class(config).to(torch_device).eval()
model_forward_args = inspect.signature(model.forward).parameters

has_3d_rope_positions = any(
hasattr(module, "get_rope_index")
for module in (
model,
getattr(model, "model", None),
getattr(model, "language_model", None),
getattr(model, "text_model", None),
)
)
if has_3d_rope_positions:
continue
Comment on lines +2279 to +2280

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can skip right away because usually all models in the set will be the same arch


if "position_ids" not in model_forward_args or "input_ids" not in inputs_dict:
self.skipTest("This model doesn't use `position_ids`")

if config.is_encoder_decoder:
self.skipTest("This model doesn't prepare `position_ids` in generate")

ran_any_model = True
input_ids = inputs_dict["input_ids"]
seq_length = input_ids.shape[1]
# ensure left padding
Expand Down Expand Up @@ -2304,6 +2318,12 @@ def test_generate_with_and_without_position_ids(self):
# and can continue adding new ids to the already passed position ids
self.assertListEqual(out_wo_positions.tolist(), out_w_positions.tolist())

if not ran_any_model:
self.skipTest(
"All model classes in this test use 3D RoPE positions (`get_rope_index`), for which 2D custom "
"`position_ids` may be accepted but are expected to produce invalid outputs."
)

def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
internal_batch_size = (
Expand Down