diff --git a/tests/models/xglm/test_modeling_xglm.py b/tests/models/xglm/test_modeling_xglm.py index 010fef113a28..9fcc25b6d24f 100644 --- a/tests/models/xglm/test_modeling_xglm.py +++ b/tests/models/xglm/test_modeling_xglm.py @@ -428,8 +428,14 @@ def test_xglm_sample(self): output_ids = model.generate(input_ids, do_sample=True, num_beams=1) output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) - EXPECTED_OUTPUT_STR = "Today is a nice day and the sun is shining. A nice day with warm rainy" - self.assertEqual(output_str, EXPECTED_OUTPUT_STR) + EXPECTED_OUTPUT_STRS = [ + # TODO: remove this once we move to torch 2.0 + # torch 1.13.1 + cu116 + "Today is a nice day and the sun is shining. A nice day with warm rainy", + # torch 2.0 + cu117 + "Today is a nice day and the water is still cold. We just stopped off for some fresh", + ] + self.assertIn(output_str, EXPECTED_OUTPUT_STRS) @slow def test_xglm_sample_max_time(self):