diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index c215f0e0d3a5..10ecf44be84a 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -32,6 +32,7 @@ from ...pytorch_utils import ( apply_chunking_to_forward, find_pruneable_heads_and_indices, + is_torch_greater_or_equal_than_1_12, prune_linear_layer, ) from ...utils import ( @@ -46,6 +47,12 @@ logger = logging.get_logger(__name__) +if not is_torch_greater_or_equal_than_1_12: + logger.warning( + f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use " + "TapasModel. Please upgrade torch." + ) + _CONFIG_FOR_DOC = "TapasConfig" _CHECKPOINT_FOR_DOC = "google/tapas-base" diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index f076eb445cb6..839261663457 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -29,6 +29,7 @@ parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version) is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0") +is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12") is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10") is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11") diff --git a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py index 01e6ceef9e90..8beddc0abad1 100644 --- a/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py +++ b/tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py @@ -37,6 +37,9 @@ GPTBigCodeModel, ) from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention + from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12 +else: + is_torch_greater_or_equal_than_1_12 = False class GPTBigCodeModelTester: @@ -530,6 +533,10 @@ class GPTBigCodeMHAModelTest(GPTBigCodeModelTest): multi_query = False +@unittest.skipIf( + not is_torch_greater_or_equal_than_1_12, + reason="`GPTBigCode` checkpoints use `PytorchGELUTanh` which requires `torch>=1.12.0`.", +) @slow @require_torch class GPTBigCodeModelLanguageGenerationTest(unittest.TestCase): diff --git a/tests/models/tapas/test_modeling_tapas.py b/tests/models/tapas/test_modeling_tapas.py index 644307e3f917..619a5d26128f 100644 --- a/tests/models/tapas/test_modeling_tapas.py +++ b/tests/models/tapas/test_modeling_tapas.py @@ -60,6 +60,9 @@ reduce_mean, reduce_sum, ) + from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12 +else: + is_torch_greater_or_equal_than_1_12 = False class TapasModelTester: @@ -408,6 +411,7 @@ def prepare_config_and_inputs_for_common(self): return config, inputs_dict +@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @require_torch class TapasModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( @@ -562,6 +566,7 @@ def prepare_tapas_batch_inputs_for_training(): return table, queries, answer_coordinates, answer_text, float_answer +@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @require_torch class TapasModelIntegrationTest(unittest.TestCase): @cached_property @@ -916,6 +921,7 @@ def test_inference_classification_head(self): # Below: tests for Tapas utilities which are defined in modeling_tapas.py. # These are based on segmented_tensor_test.py of the original implementation. # URL: https://github.com/google-research/tapas/blob/master/tapas/models/segmented_tensor_test.py +@unittest.skipIf(not is_torch_greater_or_equal_than_1_12, reason="Tapas is only available in torch v1.12+") @require_torch class TapasUtilitiesTest(unittest.TestCase): def _prepare_tables(self):