diff --git a/tests/pipelines/test_pipelines_summarization.py b/tests/pipelines/test_pipelines_summarization.py
index eb688d69c700..781716b5ba37 100644
--- a/tests/pipelines/test_pipelines_summarization.py
+++ b/tests/pipelines/test_pipelines_summarization.py
@@ -18,9 +18,10 @@
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     SummarizationPipeline,
+    TFPreTrainedModel,
     pipeline,
 )
-from transformers.testing_utils import require_tf, require_torch, slow, torch_device
+from transformers.testing_utils import get_gpu_count, require_tf, require_torch, slow, torch_device
 from transformers.tokenization_utils import TruncationStrategy
 
 from .test_pipelines_common import ANY, PipelineTestCaseMeta
@@ -51,6 +52,7 @@ def run_pipeline_test(self, summarizer, _):
         )
         self.assertEqual(outputs, [{"summary_text": ANY(str)}])
 
+        # Some models (Switch Transformers, LED, T5, LongT5, etc.) can handle long sequences.
         model_can_handle_longer_seq = [
             "SwitchTransformersConfig",
             "T5Config",
@@ -62,10 +64,16 @@ def run_pipeline_test(self, summarizer, _):
             "ProphetNetConfig",  # positional embeddings up to a fixed maximum size (otherwise clamping the values)
         ]
         if model.config.__class__.__name__ not in model_can_handle_longer_seq:
-            # Switch Transformers, LED, T5, LongT5 can handle it.
-            # Too long.
-            with self.assertRaises(Exception):
-                outputs = summarizer("This " * 1000)
+            # Too long: an exception is expected.
+            # For TF models, if the weights are initialized in a GPU context, we won't get the expected index
+            # error from the embedding layer.
+            if not (
+                isinstance(model, TFPreTrainedModel)
+                and get_gpu_count() > 0
+                and len(summarizer.model.trainable_weights) > 0
+            ):
+                with self.assertRaises(Exception):
+                    outputs = summarizer("This " * 1000)
         outputs = summarizer("This " * 1000, truncation=TruncationStrategy.ONLY_FIRST)
 
     @require_torch
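
For context, here is a minimal standalone sketch (not part of the diff) of the TensorFlow behavior the new guard works around: on CPU, an out-of-range lookup via `tf.gather` raises `InvalidArgumentError`, while on GPU the same lookup silently returns zeros for out-of-bound indices, so `assertRaises(Exception)` would fail there. The table size and index below are illustrative only.

```python
import tensorflow as tf

# Tiny embedding table with 10 ids; index 42 is deliberately out of range.
embeddings = tf.random.normal((10, 4))

with tf.device("/CPU:0"):
    try:
        tf.gather(embeddings, tf.constant([42]))
    except tf.errors.InvalidArgumentError as err:
        print("CPU raises:", type(err).__name__)

if tf.config.list_physical_devices("GPU"):
    with tf.device("/GPU:0"):
        # On GPU, out-of-bound indices yield zeros instead of an error,
        # so the summarization pipeline never raises for over-long inputs.
        print("GPU returns:", tf.gather(embeddings, tf.constant([42])).numpy())
```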