diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index d35a480e2108..bbf229b5889b 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -81,7 +81,12 @@
     strtobool,
 )
 from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
-from .utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled, is_torch_fx_proxy
+from .utils.import_utils import (
+    ENV_VARS_TRUE_VALUES,
+    is_sagemaker_mp_enabled,
+    is_torch_fx_proxy,
+    is_torchdynamo_compiling,
+)
 from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig, QuantizationMethod
 from .utils.versions import require_version_core
 
@@ -3799,7 +3804,7 @@ def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):
         """
 
         # Skip the check during tracing.
-        if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing():
+        if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling():
             return
 
         if (attention_mask is not None) or (self.config.pad_token_id is None):
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index ae76a78ce217..6829ca9ad67e 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -463,6 +463,17 @@ def is_torch_compile_available():
     return hasattr(torch, "compile")
 
 
+def is_torchdynamo_compiling():
+    if not is_torch_available():
+        return False
+    try:
+        import torch._dynamo as dynamo  # noqa: F401
+
+        return dynamo.is_compiling()
+    except Exception:
+        return False
+
+
 def is_torch_tensorrt_fx_available():
     if importlib.util.find_spec("torch_tensorrt") is None:
         return False
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index bdadbe08005b..6d37f2247e8a 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -55,6 +55,7 @@
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
 )
+from transformers.utils.import_utils import is_torchdynamo_available
 
 sys.path.append(str(Path(__file__).parent.parent / "utils"))
 
@@ -1014,6 +1015,25 @@ def test_warn_if_padding_and_no_attention_mask(self):
             model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
         self.assertIn("You may ignore this warning if your `pad_token_id`", cl.out)
 
+        if not is_torchdynamo_available():
+            return
+        with self.subTest("Ensure that the warning code is skipped when compiling with torchdynamo."):
+            logger.warning_once.cache_clear()
+            from torch._dynamo import config, testing
+
+            config = PretrainedConfig()
+            config.pad_token_id = 0
+            model = ModelWithHead(config)
+            input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])
+
+            def f(input_ids):
+                model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
+
+            compile_counter = testing.CompileCounter()
+            opt_fn = torch.compile(f, dynamic=True, backend=compile_counter)
+            opt_fn(input_ids)
+            self.assertEqual(compile_counter.frame_count, 0)
+
     @require_torch_gpu
     @slow
     def test_pretrained_low_mem_new_config(self):