From 185c61985430ae1a4cbad3aba324b10cd4d50727 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Thu, 30 Jun 2022 09:36:49 +0200 Subject: [PATCH 1/9] First draft --- .../models/tapas/modeling_tapas.py | 42 +++++++++---------- tests/models/tapas/test_modeling_tapas.py | 11 +---- 2 files changed, 21 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 0b65e84ca7ac..23cb12a21a6f 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -22,6 +22,8 @@ from typing import Optional, Tuple import torch + +# torch.autograd.set_detect_anomaly(True) import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -34,27 +36,14 @@ ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, - is_scatter_available, logging, replace_return_docstrings, - requires_backends, ) from .configuration_tapas import TapasConfig logger = logging.get_logger(__name__) -# soft dependency -if is_scatter_available(): - try: - from torch_scatter import scatter - except OSError: - logger.error( - "TAPAS models are not usable since `torch_scatter` can't be loaded. " - "It seems you have `torch_scatter` installed with the wrong CUDA version. " - "Please try to reinstall it following the instructions here: https://github.com/rusty1s/pytorch_scatter." - ) - _CONFIG_FOR_DOC = "TapasConfig" _TOKENIZER_FOR_DOC = "TapasTokenizer" _TOKENIZER_FOR_DOC = "google/tapas-base" @@ -862,7 +851,6 @@ class TapasModel(TapasPreTrainedModel): """ def __init__(self, config, add_pooling_layer=True): - requires_backends(self, "scatter") super().__init__(config) self.config = config @@ -1797,12 +1785,22 @@ def _segment_reduce(values, index, segment_reduce_fn, name): # changed "view" by "reshape" in the following line flat_values = values.reshape(flattened_shape.tolist()) - segment_means = scatter( - src=flat_values, - index=flat_index.indices.long(), - dim=0, - dim_size=int(flat_index.num_segments), - reduce=segment_reduce_fn, + src = flat_values + index = flat_index.indices.long() + dim = 0 + dim_size = int(flat_index.num_segments) + + if segment_reduce_fn == "sum": + size = list(src.size()) + if dim_size is not None: + size[dim] = dim_size + out = torch.zeros(size, dtype=src.dtype, device=src.device) + elif segment_reduce_fn: + pass + + out = torch.zeros(int(flat_index.num_segments), dtype=flat_values.dtype) + segment_means = out.scatter_reduce( + dim=0, index=flat_index.indices.long(), src=flat_values, reduce=segment_reduce_fn, include_self=False ) # Unflatten the values. @@ -1900,7 +1898,7 @@ def reduce_max(values, index, name="segmented_reduce_max"): output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments]. """ - return _segment_reduce(values, index, "max", name) + return _segment_reduce(values, index, "amax", name) def reduce_min(values, index, name="segmented_reduce_min"): @@ -1927,7 +1925,7 @@ def reduce_min(values, index, name="segmented_reduce_min"): output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments]. """ - return _segment_reduce(values, index, "min", name) + return _segment_reduce(values, index, "amin", name) # End of everything related to segmented tensors diff --git a/tests/models/tapas/test_modeling_tapas.py b/tests/models/tapas/test_modeling_tapas.py index 271a5efc9616..d6726f44490f 100644 --- a/tests/models/tapas/test_modeling_tapas.py +++ b/tests/models/tapas/test_modeling_tapas.py @@ -32,13 +32,7 @@ is_torch_available, ) from transformers.models.auto import get_values -from transformers.testing_utils import ( - require_scatter, - require_tensorflow_probability, - require_torch, - slow, - torch_device, -) +from transformers.testing_utils import require_tensorflow_probability, require_torch, slow, torch_device from transformers.utils import cached_property from ...test_configuration_common import ConfigTester @@ -414,7 +408,6 @@ def prepare_config_and_inputs_for_common(self): @require_torch -@require_scatter class TapasModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( @@ -553,7 +546,6 @@ def prepare_tapas_batch_inputs_for_training(): @require_torch -@require_scatter class TapasModelIntegrationTest(unittest.TestCase): @cached_property def default_tokenizer(self): @@ -907,7 +899,6 @@ def test_inference_classification_head(self): # Below: tests for Tapas utilities which are defined in modeling_tapas.py. # These are based on segmented_tensor_test.py of the original implementation. # URL: https://github.com/google-research/tapas/blob/master/tapas/models/segmented_tensor_test.py -@require_scatter class TapasUtilitiesTest(unittest.TestCase): def _prepare_tables(self): """Prepares two tables, both with three distinct rows. From a9c6bedcc13f49133bbce5b3cd155190c9e8d3ff Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Thu, 30 Jun 2022 14:39:28 +0200 Subject: [PATCH 2/9] Remove scatter dependency --- src/transformers/__init__.py | 58 +++++++------------ .../models/tapas/modeling_tapas.py | 13 ----- src/transformers/utils/dummy_pt_objects.py | 42 ++++++++++++++ 3 files changed, 62 insertions(+), 51 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a3ce3fd1eb2e..59762e678098 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -757,28 +757,6 @@ ] ) -try: - if not is_scatter_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils import dummy_scatter_objects - - _import_structure["utils.dummy_scatter_objects"] = [ - name for name in dir(dummy_scatter_objects) if not name.startswith("_") - ] -else: - _import_structure["models.tapas"].extend( - [ - "TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST", - "TapasForMaskedLM", - "TapasForQuestionAnswering", - "TapasForSequenceClassification", - "TapasModel", - "TapasPreTrainedModel", - "load_tf_weights_in_tapas", - ] - ) - # PyTorch-backed objects try: @@ -1952,6 +1930,17 @@ "Swinv2PreTrainedModel", ] ) + _import_structure["models.tapas"].extend( + [ + "TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST", + "TapasForMaskedLM", + "TapasForQuestionAnswering", + "TapasForSequenceClassification", + "TapasModel", + "TapasPreTrainedModel", + "load_tf_weights_in_tapas", + ] + ) _import_structure["models.t5"].extend( [ "T5_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -3736,22 +3725,6 @@ TableTransformerPreTrainedModel, ) - try: - if not is_scatter_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from .utils.dummy_scatter_objects import * - else: - from .models.tapas import ( - TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST, - TapasForMaskedLM, - TapasForQuestionAnswering, - TapasForSequenceClassification, - TapasModel, - TapasPreTrainedModel, - load_tf_weights_in_tapas, - ) - try: if not is_torch_available(): raise OptionalDependencyNotAvailable() @@ -4709,6 +4682,15 @@ T5PreTrainedModel, load_tf_weights_in_t5, ) + from .models.tapas import ( + TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST, + TapasForMaskedLM, + TapasForQuestionAnswering, + TapasForSequenceClassification, + TapasModel, + TapasPreTrainedModel, + load_tf_weights_in_tapas, + ) from .models.time_series_transformer import ( TIME_SERIES_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, TimeSeriesTransformerForPrediction, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 23cb12a21a6f..b133943f01e5 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -1785,19 +1785,6 @@ def _segment_reduce(values, index, segment_reduce_fn, name): # changed "view" by "reshape" in the following line flat_values = values.reshape(flattened_shape.tolist()) - src = flat_values - index = flat_index.indices.long() - dim = 0 - dim_size = int(flat_index.num_segments) - - if segment_reduce_fn == "sum": - size = list(src.size()) - if dim_size is not None: - size[dim] = dim_size - out = torch.zeros(size, dtype=src.dtype, device=src.device) - elif segment_reduce_fn: - pass - out = torch.zeros(int(flat_index.num_segments), dtype=flat_values.dtype) segment_means = out.scatter_reduce( dim=0, index=flat_index.indices.long(), src=flat_values, reduce=segment_reduce_fn, include_self=False diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index cb2f93be0fc9..af81387370d8 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4911,6 +4911,48 @@ def load_tf_weights_in_t5(*args, **kwargs): requires_backends(load_tf_weights_in_t5, ["torch"]) +TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class TapasForMaskedLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TapasForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TapasForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TapasModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class TapasPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +def load_tf_weights_in_tapas(*args, **kwargs): + requires_backends(load_tf_weights_in_tapas, ["torch"]) + + TIME_SERIES_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None From 75047a3b78c7760c81a9c07f3034011c7e8ae711 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Mon, 7 Nov 2022 17:17:39 +0100 Subject: [PATCH 3/9] Add require_torch --- tests/models/tapas/test_modeling_tapas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/tapas/test_modeling_tapas.py b/tests/models/tapas/test_modeling_tapas.py index d6726f44490f..c6a7bdd7e99a 100644 --- a/tests/models/tapas/test_modeling_tapas.py +++ b/tests/models/tapas/test_modeling_tapas.py @@ -899,6 +899,7 @@ def test_inference_classification_head(self): # Below: tests for Tapas utilities which are defined in modeling_tapas.py. # These are based on segmented_tensor_test.py of the original implementation. # URL: https://github.com/google-research/tapas/blob/master/tapas/models/segmented_tensor_test.py +@require_torch class TapasUtilitiesTest(unittest.TestCase): def _prepare_tables(self): """Prepares two tables, both with three distinct rows. From e407d4bd157964d11f58901f6b44a0c7be07fd5f Mon Sep 17 00:00:00 2001 From: bartekkz Date: Wed, 9 Nov 2022 22:39:27 +0100 Subject: [PATCH 4/9] update vectorized sum test, add clone call --- src/transformers/models/tapas/modeling_tapas.py | 5 +++-- tests/models/tapas/test_modeling_tapas.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index b133943f01e5..45c4017390f7 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -1800,7 +1800,8 @@ def _segment_reduce(values, index, segment_reduce_fn, name): dim=0, ) - output_values = segment_means.view(new_shape.tolist()) + output_values = segment_means.clone().view(new_shape.tolist()) + # output_values = segment_means.view(new_shape.tolist()) output_index = range_index_map(index.batch_shape(), index.num_segments) return output_values, output_index @@ -2348,7 +2349,7 @@ def _calculate_expected_result( # PyTorch does not currently support Huber loss with custom delta so we define it ourself def huber_loss(input, target, delta: float = 1.0): errors = torch.abs(input - target) # shape (batch_size,) - return torch.where(errors < delta, 0.5 * errors**2, errors * delta - (0.5 * delta**2)) + return torch.where(errors < delta, 0.5 * errors ** 2, errors * delta - (0.5 * delta ** 2)) def _calculate_regression_loss( diff --git a/tests/models/tapas/test_modeling_tapas.py b/tests/models/tapas/test_modeling_tapas.py index c6a7bdd7e99a..61c5972b19f1 100644 --- a/tests/models/tapas/test_modeling_tapas.py +++ b/tests/models/tapas/test_modeling_tapas.py @@ -1056,11 +1056,11 @@ def test_reduce_max(self): def test_reduce_sum_vectorized(self): values = torch.as_tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]) - index = IndexMap(indices=torch.as_tensor([0, 0, 1]), num_segments=2, batch_dims=0) + index = IndexMap(indices=torch.as_tensor([[0, 0, 1]]), num_segments=2, batch_dims=0) sums, new_index = reduce_sum(values, index) # We use np.testing.assert_allclose rather than Tensorflow's assertAllClose - np.testing.assert_allclose(sums.numpy(), [[3.0, 5.0, 7.0], [3.0, 4.0, 5.0]]) + np.testing.assert_allclose(sums.numpy(), [3., 3,]) # We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual np.testing.assert_array_equal(new_index.indices.numpy(), [0, 1]) np.testing.assert_array_equal(new_index.num_segments.numpy(), 2) From e842fbcdd6c058a1d70c1e387e14142156c36f35 Mon Sep 17 00:00:00 2001 From: bartekkz Date: Wed, 9 Nov 2022 22:41:09 +0100 Subject: [PATCH 5/9] remove artifacts --- src/transformers/models/tapas/modeling_tapas.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 45c4017390f7..14480c076f21 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -1801,7 +1801,6 @@ def _segment_reduce(values, index, segment_reduce_fn, name): ) output_values = segment_means.clone().view(new_shape.tolist()) - # output_values = segment_means.view(new_shape.tolist()) output_index = range_index_map(index.batch_shape(), index.num_segments) return output_values, output_index From 47b1f9455a703ddce6bfbb466731f22238ed1120 Mon Sep 17 00:00:00 2001 From: bartekkz Date: Wed, 9 Nov 2022 23:24:10 +0100 Subject: [PATCH 6/9] fix style --- src/transformers/models/tapas/modeling_tapas.py | 2 +- tests/models/tapas/test_modeling_tapas.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 14480c076f21..5989624ace18 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -2348,7 +2348,7 @@ def _calculate_expected_result( # PyTorch does not currently support Huber loss with custom delta so we define it ourself def huber_loss(input, target, delta: float = 1.0): errors = torch.abs(input - target) # shape (batch_size,) - return torch.where(errors < delta, 0.5 * errors ** 2, errors * delta - (0.5 * delta ** 2)) + return torch.where(errors < delta, 0.5 * errors**2, errors * delta - (0.5 * delta**2)) def _calculate_regression_loss( diff --git a/tests/models/tapas/test_modeling_tapas.py b/tests/models/tapas/test_modeling_tapas.py index 61c5972b19f1..715c675e6d4c 100644 --- a/tests/models/tapas/test_modeling_tapas.py +++ b/tests/models/tapas/test_modeling_tapas.py @@ -1060,7 +1060,7 @@ def test_reduce_sum_vectorized(self): sums, new_index = reduce_sum(values, index) # We use np.testing.assert_allclose rather than Tensorflow's assertAllClose - np.testing.assert_allclose(sums.numpy(), [3., 3,]) + np.testing.assert_allclose(sums.numpy(), [3., 3.]) # We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual np.testing.assert_array_equal(new_index.indices.numpy(), [0, 1]) np.testing.assert_array_equal(new_index.num_segments.numpy(), 2) From 7a7b68d73e773a0c8b6cfaa0efc002de656664b3 Mon Sep 17 00:00:00 2001 From: bartekkz Date: Thu, 10 Nov 2022 00:43:38 +0100 Subject: [PATCH 7/9] fix style v2 --- tests/models/tapas/test_modeling_tapas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/tapas/test_modeling_tapas.py b/tests/models/tapas/test_modeling_tapas.py index 715c675e6d4c..504f3e278ea8 100644 --- a/tests/models/tapas/test_modeling_tapas.py +++ b/tests/models/tapas/test_modeling_tapas.py @@ -1060,7 +1060,7 @@ def test_reduce_sum_vectorized(self): sums, new_index = reduce_sum(values, index) # We use np.testing.assert_allclose rather than Tensorflow's assertAllClose - np.testing.assert_allclose(sums.numpy(), [3., 3.]) + np.testing.assert_allclose(sums.numpy(), [3.0, 3.0]) # We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual np.testing.assert_array_equal(new_index.indices.numpy(), [0, 1]) np.testing.assert_array_equal(new_index.num_segments.numpy(), 2) From bd074cfa044d5bd7512aa9d4763495d672648f94 Mon Sep 17 00:00:00 2001 From: bartekkz Date: Fri, 11 Nov 2022 13:56:42 +0100 Subject: [PATCH 8/9] remove "scatter" mentions from the code base --- docker/transformers-all-latest-gpu/Dockerfile | 3 -- docker/transformers-doc-builder/Dockerfile | 1 - docker/transformers-past-gpu/Dockerfile | 6 --- docs/source/en/model_doc/tapas.mdx | 3 +- src/transformers/__init__.py | 1 - src/transformers/file_utils.py | 1 - .../models/tapas/modeling_tapas.py | 1 - src/transformers/testing_utils.py | 19 -------- src/transformers/utils/__init__.py | 1 - .../utils/dummy_scatter_objects.py | 45 ------------------- src/transformers/utils/import_utils.py | 22 +-------- tests/deepspeed/test_model_zoo.py | 3 +- tests/models/auto/test_modeling_auto.py | 2 - .../layoutxlm/test_tokenization_layoutxlm.py | 2 - tests/models/tapas/test_tokenization_tapas.py | 2 - tests/pipelines/test_pipelines_common.py | 2 - ...test_pipelines_table_question_answering.py | 5 --- 17 files changed, 4 insertions(+), 115 deletions(-) delete mode 100644 src/transformers/utils/dummy_scatter_objects.py diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index 10ee71890acc..711ecf7aeb3b 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -41,9 +41,6 @@ RUN python3 -m pip uninstall -y flax jax # TODO: remove this line once the conflict is resolved in these libraries. RUN python3 -m pip install --no-cache-dir git+https://github.com/onnx/tensorflow-onnx.git@ddca3a5eb2d912f20fe7e0568dd1a3013aee9fa3 -# Use installed torch version for `torch-scatter` to avid to deal with PYTORCH='pre'. -# If torch is nightly version, the link is likely to be invalid, but the installation falls back to the latest torch-scatter -RUN python3 -m pip install --no-cache-dir torch-scatter -f https://data.pyg.org/whl/torch-$(python3 -c "from torch import version; print(version.__version__.split('+')[0])")+$CUDA.html RUN python3 -m pip install --no-cache-dir intel_extension_for_pytorch==$INTEL_TORCH_EXT+cpu -f https://software.intel.com/ipex-whl-stable RUN python3 -m pip install --no-cache-dir git+https://github.com/facebookresearch/detectron2.git pytesseract diff --git a/docker/transformers-doc-builder/Dockerfile b/docker/transformers-doc-builder/Dockerfile index c693f2843cde..0e5b072d4889 100644 --- a/docker/transformers-doc-builder/Dockerfile +++ b/docker/transformers-doc-builder/Dockerfile @@ -10,7 +10,6 @@ RUN apt-get -y update && apt-get install -y libsndfile1-dev && apt install -y te # Torch needs to be installed before deepspeed RUN python3 -m pip install --no-cache-dir ./transformers[deepspeed] -RUN python3 -m pip install --no-cache-dir torch-scatter -f https://data.pyg.org/whl/torch-$(python -c "from torch import version; print(version.__version__.split('+')[0])")+cpu.html RUN python3 -m pip install --no-cache-dir torchvision git+https://github.com/facebookresearch/detectron2.git pytesseract RUN python3 -m pip install --no-cache-dir pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com RUN python3 -m pip install -U "itsdangerous<2.1.0" diff --git a/docker/transformers-past-gpu/Dockerfile b/docker/transformers-past-gpu/Dockerfile index 826a8f12c2e1..99fb550c6a35 100644 --- a/docker/transformers-past-gpu/Dockerfile +++ b/docker/transformers-past-gpu/Dockerfile @@ -34,10 +34,4 @@ RUN python3 ./transformers/utils/past_ci_versions.py --framework $FRAMEWORK --ve RUN echo "INSTALL_CMD = $INSTALL_CMD" RUN $INSTALL_CMD -# Having installation problems for torch-scatter with torch <= 1.6. Disable so we have the same set of tests. -# (This part will be removed once the logic of using `past_ci_versions.py` is used in other Dockerfile files.) -# # Use installed torch version for `torch-scatter`. -# # (The env. variable $CUDA is defined in `past_ci_versions.py`) -# RUN [ "$FRAMEWORK" = "pytorch" ] && python3 -m pip install --no-cache-dir torch-scatter -f https://data.pyg.org/whl/torch-$(python3 -c "from torch import version; print(version.__version__.split('+')[0])")+$CUDA.html || echo "torch-scatter not to be installed" - RUN python3 -m pip install -U "itsdangerous<2.1.0" diff --git a/docs/source/en/model_doc/tapas.mdx b/docs/source/en/model_doc/tapas.mdx index 172800004fbb..5a2b54e8c32c 100644 --- a/docs/source/en/model_doc/tapas.mdx +++ b/docs/source/en/model_doc/tapas.mdx @@ -69,8 +69,7 @@ To summarize: -Initializing a model with a pre-trained base and randomly initialized classification heads from the hub can be done as shown below. Be sure to have installed the -[torch-scatter](https://github.com/rusty1s/pytorch_scatter) dependency: +Initializing a model with a pre-trained base and randomly initialized classification heads from the hub can be done as shown below. ```py >>> from transformers import TapasConfig, TapasForQuestionAnswering diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 59762e678098..84b0c7a9c371 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -32,7 +32,6 @@ OptionalDependencyNotAvailable, _LazyModule, is_flax_available, - is_scatter_available, is_sentencepiece_available, is_speech_available, is_tensorflow_text_available, diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 87cd9a469187..23219a328b6b 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -102,7 +102,6 @@ is_rjieba_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, - is_scatter_available, is_scipy_available, is_sentencepiece_available, is_sklearn_available, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 5989624ace18..a15400eea836 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -23,7 +23,6 @@ import torch -# torch.autograd.set_detect_anomaly(True) import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 7bbecd332072..eb69e7d241a6 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -65,7 +65,6 @@ is_pytorch_quantization_available, is_rjieba_available, is_safetensors_available, - is_scatter_available, is_scipy_available, is_sentencepiece_available, is_soundfile_availble, @@ -319,16 +318,6 @@ def require_intel_extension_for_pytorch(test_case): )(test_case) -def require_torch_scatter(test_case): - """ - Decorator marking a test that requires PyTorch scatter. - - These tests are skipped when PyTorch scatter isn't installed. - - """ - return unittest.skipUnless(is_scatter_available(), "test requires PyTorch scatter")(test_case) - - def require_tensorflow_probability(test_case): """ Decorator marking a test that requires TensorFlow probability. @@ -405,14 +394,6 @@ def require_pytesseract(test_case): return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case) -def require_scatter(test_case): - """ - Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't - installed. - """ - return unittest.skipUnless(is_scatter_available(), "test requires PyTorch Scatter")(test_case) - - def require_pytorch_quantization(test_case): """ Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 7ea8cc558510..f6a5b8d49499 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -127,7 +127,6 @@ is_safetensors_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, - is_scatter_available, is_scipy_available, is_sentencepiece_available, is_sklearn_available, diff --git a/src/transformers/utils/dummy_scatter_objects.py b/src/transformers/utils/dummy_scatter_objects.py deleted file mode 100644 index 3f25018b5372..000000000000 --- a/src/transformers/utils/dummy_scatter_objects.py +++ /dev/null @@ -1,45 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -# flake8: noqa -from ..utils import DummyObject, requires_backends - - -TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST = None - - -class TapasForMaskedLM(metaclass=DummyObject): - _backends = ["scatter"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["scatter"]) - - -class TapasForQuestionAnswering(metaclass=DummyObject): - _backends = ["scatter"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["scatter"]) - - -class TapasForSequenceClassification(metaclass=DummyObject): - _backends = ["scatter"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["scatter"]) - - -class TapasModel(metaclass=DummyObject): - _backends = ["scatter"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["scatter"]) - - -class TapasPreTrainedModel(metaclass=DummyObject): - _backends = ["scatter"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["scatter"]) - - -def load_tf_weights_in_tapas(*args, **kwargs): - requires_backends(load_tf_weights_in_tapas, ["scatter"]) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 9d7ffc42972f..58b96f962209 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -145,6 +145,7 @@ except importlib_metadata.PackageNotFoundError: _faiss_available = False + _ftfy_available = importlib.util.find_spec("ftfy") is not None try: _ftfy_version = importlib_metadata.version("ftfy") @@ -176,6 +177,7 @@ except importlib_metadata.PackageNotFoundError: _tf2onnx_available = False + _onnx_available = importlib.util.find_spec("onnxruntime") is not None try: _onxx_version = importlib_metadata.version("onnx") @@ -184,14 +186,6 @@ _onnx_available = False -_scatter_available = importlib.util.find_spec("torch_scatter") is not None -try: - _scatter_version = importlib_metadata.version("torch_scatter") - logger.debug(f"Successfully imported torch-scatter version {_scatter_version}") -except importlib_metadata.PackageNotFoundError: - _scatter_available = False - - _pytorch_quantization_available = importlib.util.find_spec("pytorch_quantization") is not None try: _pytorch_quantization_version = importlib_metadata.version("pytorch_quantization") @@ -584,10 +578,6 @@ def is_in_notebook(): return False -def is_scatter_available(): - return _scatter_available - - def is_pytorch_quantization_available(): return _pytorch_quantization_available @@ -826,13 +816,6 @@ def is_jumanpp_available(): that match your environment. Please note that you may need to restart your runtime after installation. """ - -# docstyle-ignore -SCATTER_IMPORT_ERROR = """ -{0} requires the torch-scatter library but it was not found in your environment. You can install it with pip as -explained here: https://github.com/rusty1s/pytorch_scatter. Please note that you may need to restart your runtime after installation. -""" - # docstyle-ignore PYTORCH_QUANTIZATION_IMPORT_ERROR = """ {0} requires the pytorch-quantization library but it was not found in your environment. You can install it with pip: @@ -941,7 +924,6 @@ def is_jumanpp_available(): ("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)), ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)), ("sacremoses", (is_sacremoses_available, SACREMOSES_IMPORT_ERROR)), - ("scatter", (is_scatter_available, SCATTER_IMPORT_ERROR)), ("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)), diff --git a/tests/deepspeed/test_model_zoo.py b/tests/deepspeed/test_model_zoo.py index ac33b7f5a279..cd2c6b9e254f 100644 --- a/tests/deepspeed/test_model_zoo.py +++ b/tests/deepspeed/test_model_zoo.py @@ -130,8 +130,7 @@ # models with low usage, unstable API, things about to change - do nothing about the following until someone runs into a problem TAPAS_TINY = "hf-internal-testing/tiny-random-tapas" # additional notes on tapas -# 1. requires torch_scatter - skip if it's not installed? -# 2. "Table must be of type pd.DataFrame" failure +# 1. "Table must be of type pd.DataFrame" failure # TODO: new models to add: diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 95df9365c655..9745ecd94b74 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -28,7 +28,6 @@ DUMMY_UNKNOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, RequestCounter, - require_scatter, require_torch, slow, ) @@ -199,7 +198,6 @@ def test_question_answering_model_from_pretrained(self): self.assertIsInstance(model, BertForQuestionAnswering) @slow - @require_scatter def test_table_question_answering_model_from_pretrained(self): for model_name in TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST[5:6]: config = AutoConfig.from_pretrained(model_name) diff --git a/tests/models/layoutxlm/test_tokenization_layoutxlm.py b/tests/models/layoutxlm/test_tokenization_layoutxlm.py index ffc58fe0abee..e74dfe496c1c 100644 --- a/tests/models/layoutxlm/test_tokenization_layoutxlm.py +++ b/tests/models/layoutxlm/test_tokenization_layoutxlm.py @@ -32,7 +32,6 @@ get_tests_dir, is_pt_tf_cross_test, require_pandas, - require_scatter, require_sentencepiece, require_tokenizers, require_torch, @@ -1176,7 +1175,6 @@ def test_offsets_mapping(self): @require_torch @slow - @require_scatter def test_torch_encode_plus_sent_to_model(self): import torch diff --git a/tests/models/tapas/test_tokenization_tapas.py b/tests/models/tapas/test_tokenization_tapas.py index ff873c76cd7d..89865a78e733 100644 --- a/tests/models/tapas/test_tokenization_tapas.py +++ b/tests/models/tapas/test_tokenization_tapas.py @@ -35,7 +35,6 @@ from transformers.testing_utils import ( is_pt_tf_cross_test, require_pandas, - require_scatter, require_tensorflow_probability, require_tokenizers, require_torch, @@ -1031,7 +1030,6 @@ def test_token_type_ids(self): @require_torch @slow - @require_scatter def test_torch_encode_plus_sent_to_model(self): import torch diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 626c7c4d31bc..492a63a4ccd0 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -50,7 +50,6 @@ RequestCounter, is_staging_test, nested_simplify, - require_scatter, require_tensorflow_probability, require_tf, require_torch, @@ -749,7 +748,6 @@ def test_load_default_pipelines_tf(self): @slow @require_torch - @require_scatter def test_load_default_pipelines_pt_table_qa(self): import torch diff --git a/tests/pipelines/test_pipelines_table_question_answering.py b/tests/pipelines/test_pipelines_table_question_answering.py index 089186a4672c..390b764abc83 100644 --- a/tests/pipelines/test_pipelines_table_question_answering.py +++ b/tests/pipelines/test_pipelines_table_question_answering.py @@ -27,7 +27,6 @@ require_tensorflow_probability, require_tf, require_torch, - require_torch_scatter, slow, ) @@ -145,7 +144,6 @@ def test_small_model_tf(self): ) @require_torch - @require_torch_scatter def test_small_model_pt(self): model_id = "lysandre/tiny-tapas-random-wtq" model = AutoModelForTableQuestionAnswering.from_pretrained(model_id) @@ -248,7 +246,6 @@ def test_small_model_pt(self): ) @require_torch - @require_torch_scatter def test_slow_tokenizer_sqa_pt(self): model_id = "lysandre/tiny-tapas-random-sqa" model = AutoModelForTableQuestionAnswering.from_pretrained(model_id) @@ -490,7 +487,6 @@ def test_slow_tokenizer_sqa_tf(self): ) @slow - @require_torch_scatter def test_integration_wtq_pt(self): table_querier = pipeline("table-question-answering") @@ -584,7 +580,6 @@ def test_integration_wtq_tf(self): self.assertListEqual(results, expected_results) @slow - @require_torch_scatter def test_integration_sqa_pt(self): table_querier = pipeline( "table-question-answering", From 85df5d7aff8e53ee1d314f9b1a004e0b29ecad39 Mon Sep 17 00:00:00 2001 From: bartekkz Date: Fri, 11 Nov 2022 14:52:08 +0100 Subject: [PATCH 9/9] fix isort error --- src/transformers/models/tapas/modeling_tapas.py | 1 - .../pipelines/test_pipelines_table_question_answering.py | 8 +------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index a15400eea836..706d676303be 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -22,7 +22,6 @@ from typing import Optional, Tuple import torch - import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss diff --git a/tests/pipelines/test_pipelines_table_question_answering.py b/tests/pipelines/test_pipelines_table_question_answering.py index 390b764abc83..9e2e7e531753 100644 --- a/tests/pipelines/test_pipelines_table_question_answering.py +++ b/tests/pipelines/test_pipelines_table_question_answering.py @@ -22,13 +22,7 @@ TFAutoModelForTableQuestionAnswering, pipeline, ) -from transformers.testing_utils import ( - require_pandas, - require_tensorflow_probability, - require_tf, - require_torch, - slow, -) +from transformers.testing_utils import require_pandas, require_tensorflow_probability, require_tf, require_torch, slow from .test_pipelines_common import PipelineTestCaseMeta