9 changes: 7 additions & 2 deletions src/transformers/modeling_utils.py
@@ -81,7 +81,12 @@
    strtobool,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
-from .utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled, is_torch_fx_proxy
+from .utils.import_utils import (
+    ENV_VARS_TRUE_VALUES,
+    is_sagemaker_mp_enabled,
+    is_torch_fx_proxy,
+    is_torchdynamo_compiling,
+)
from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig, QuantizationMethod
from .utils.versions import require_version_core

@@ -3799,7 +3804,7 @@ def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):
"""

        # Skip the check during tracing.
-        if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing():
+        if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling():
            return

        if (attention_mask is not None) or (self.config.pad_token_id is None):
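For context, the guard above now returns early whenever the model is being FX-traced, TorchScript-traced, or compiled by TorchDynamo. A minimal standalone sketch of the same pattern (not the actual method body; the real check also inspects `self.config.pad_token_id` and the attention mask, and the pad id of 0 below is just an illustrative assumption):

```python
import torch

from transformers.utils.import_utils import is_torch_fx_proxy, is_torchdynamo_compiling


def maybe_warn_about_padding(input_ids: torch.Tensor, attention_mask=None, pad_token_id: int = 0) -> None:
    # Skip the data-dependent check while tracing or compiling: reading tensor
    # values here would either fail or force a graph break.
    if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling():
        return
    if attention_mask is None and (input_ids == pad_token_id).any():
        print("input_ids appears to contain padding, but no attention_mask was passed.")


maybe_warn_about_padding(torch.tensor([[0, 345, 232, 0]]))  # warns in eager mode only
```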
11 changes: 11 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -463,6 +463,17 @@ def is_torch_compile_available():
    return hasattr(torch, "compile")


def is_torchdynamo_compiling():
    if not is_torch_available():
        return False
    try:
        import torch._dynamo as dynamo  # noqa: F401

        return dynamo.is_compiling()
    except Exception:
        return False


def is_torch_tensorrt_fx_available():
    if importlib.util.find_spec("torch_tensorrt") is None:
        return False
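As a rough illustration of what the new helper reports (this assumes a PyTorch build that ships `torch._dynamo`; on older builds the `except` branch simply returns `False`):

```python
import torch

from transformers.utils.import_utils import is_torchdynamo_compiling

# In plain eager execution the helper reports False.
assert not is_torchdynamo_compiling()


def scale(x):
    # While this frame is traced under torch.compile, the helper reports True,
    # so callers can skip eager-only side effects such as warnings.
    if not is_torchdynamo_compiling():
        print("running eagerly")
    return x * 2


scale(torch.ones(3))  # prints "running eagerly" and returns tensor([2., 2., 2.])
```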
20 changes: 20 additions & 0 deletions tests/test_modeling_utils.py
@@ -55,6 +55,7 @@
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
)
from transformers.utils.import_utils import is_torchdynamo_available


sys.path.append(str(Path(__file__).parent.parent / "utils"))
@@ -1014,6 +1015,25 @@ def test_warn_if_padding_and_no_attention_mask(self):
                model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
            self.assertIn("You may ignore this warning if your `pad_token_id`", cl.out)

        if not is_torchdynamo_available():
            return
        with self.subTest("Ensure that the warning code is skipped when compiling with torchdynamo."):
            logger.warning_once.cache_clear()
            from torch._dynamo import testing

            config = PretrainedConfig()
            config.pad_token_id = 0
            model = ModelWithHead(config)
            input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])

            def f(input_ids):
                model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)

            compile_counter = testing.CompileCounter()
            opt_fn = torch.compile(f, dynamic=True, backend=compile_counter)
            opt_fn(input_ids)
            self.assertEqual(compile_counter.frame_count, 0)

    @require_torch_gpu
    @slow
    def test_pretrained_low_mem_new_config(self):
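The final assertion leans on `torch._dynamo.testing.CompileCounter`, a debug backend that counts the graphs Dynamo hands it: because the guarded warning code returns immediately while compiling and `f` performs no tensor work, no graph should be produced and `frame_count` should stay at 0. For contrast, a small sketch (independent of this PR) of the counter on a function that does produce a graph:

```python
import torch
from torch._dynamo import testing


def double(x):
    return x * 2  # one tensor op, so Dynamo emits a graph for this frame


counter = testing.CompileCounter()
compiled = torch.compile(double, backend=counter)
compiled(torch.ones(4))
print(counter.frame_count)  # expected: 1
```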