diff --git a/examples/pytorch/test_pytorch_examples.py b/examples/pytorch/test_pytorch_examples.py index e7cc2d51c006..3fab6df81c1b 100644 --- a/examples/pytorch/test_pytorch_examples.py +++ b/examples/pytorch/test_pytorch_examples.py @@ -18,6 +18,7 @@ import logging import os import sys +import unittest from unittest.mock import patch from transformers import ViTMAEForPreTraining, Wav2Vec2ForPreTraining @@ -613,6 +614,10 @@ def test_run_semantic_segmentation(self): self.assertGreaterEqual(result["eval_overall_accuracy"], 0.1) @patch.dict(os.environ, {"WANDB_DISABLED": "true"}) + @unittest.skipIf( + backend_device_count(torch_device) > 1, + "TODO @qubvel, index out of bounds for bounding boxes when running on multi-accelerator", + ) def test_run_object_detection(self): tmp_dir = self.get_auto_remove_tmp_dir() testargs = f"""