From 0cc4cc68bcaf96d150c3f200d8fcdfc3c05d0736 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 26 Sep 2024 10:51:05 +0000 Subject: [PATCH 1/3] add recursive check and test warnings --- src/transformers/modeling_utils.py | 6 +++++- tests/utils/test_modeling_utils.py | 32 +++++++++++++++++++++++++----- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3e3d789087d2..30cffca4df87 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1639,12 +1639,16 @@ def can_generate(cls) -> bool: Returns: `bool`: Whether this model can generate sequences with `.generate()`. """ - # Directly inherits `GenerationMixin` -> can generate + # Directly inherits `GenerationMixin`-> can generate if "GenerationMixin" in str(cls.__bases__): return True # Model class overwrites `generate` (e.g. time series models) -> can generate if str(cls.__name__) in str(cls.generate): return True + # The class inherits from a class that can generate (recursive check) -> can generate + for base in cls.__bases__: + if "PreTrainedModel" not in str(base) and base.can_generate(): + return True # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this # was how we detected whether a model could generate. if "GenerationMixin" not in str(cls.prepare_inputs_for_generation): diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 5155647059f1..3317a47d7560 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1718,29 +1718,51 @@ def test_isin_mps_friendly(self): def test_can_generate(self): """Tests the behavior of `PreTrainedModel.can_generate` method.""" + logger = logging.get_logger("transformers.modeling_utils") + logger.warning_once.cache_clear() + # 1 - By default, a model CAN'T generate - self.assertFalse(BertModel.can_generate()) + can_generate = BertModel.can_generate() + self.assertFalse(can_generate) # 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly class DummyBertWithMixin(BertModel, GenerationMixin): pass - self.assertTrue(DummyBertWithMixin.can_generate()) + with CaptureLogger(logger) as cl: + can_generate = DummyBertWithMixin.can_generate() + self.assertTrue("" == cl.out) + self.assertTrue(can_generate) # 3 - Alternatively, a model can implement a `generate` method class DummyBertWithGenerate(BertModel): def generate(self): pass - self.assertTrue(DummyBertWithGenerate.can_generate()) + with CaptureLogger(logger) as cl: + can_generate = DummyBertWithGenerate.can_generate() + self.assertTrue("" == cl.out) + self.assertTrue(can_generate) + + # 4 - Finally, it can inherit from a model that can generate + class DummyBertWithParent(DummyBertWithMixin): + pass + + with CaptureLogger(logger) as cl: + can_generate = DummyBertWithParent.can_generate() + self.assertTrue("" == cl.out) + self.assertTrue(can_generate) - # 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited + # 5 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited # `GenerationMixin`) class DummyBertWithPrepareInputs(BertModel): def prepare_inputs_for_generation(self): pass - self.assertTrue(DummyBertWithPrepareInputs.can_generate()) + with CaptureLogger(logger) as cl: + can_generate = DummyBertWithPrepareInputs.can_generate() + self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out) + self.assertTrue(can_generate) def test_save_and_load_config_with_custom_generation(self): """ From 9ba46006f1f70ebe0259d51ca7798f808e14dbd4 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 26 Sep 2024 10:55:17 +0000 Subject: [PATCH 2/3] missing space --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 30cffca4df87..4aab9c467a28 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1639,7 +1639,7 @@ def can_generate(cls) -> bool: Returns: `bool`: Whether this model can generate sequences with `.generate()`. """ - # Directly inherits `GenerationMixin`-> can generate + # Directly inherits `GenerationMixin` -> can generate if "GenerationMixin" in str(cls.__bases__): return True # Model class overwrites `generate` (e.g. time series models) -> can generate From 9029753a920237317390f2a46e7ff21a6065ef4c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 26 Sep 2024 11:12:10 +0000 Subject: [PATCH 3/3] models without can_generate --- src/transformers/modeling_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4aab9c467a28..d0f4239c38cc 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1647,6 +1647,8 @@ def can_generate(cls) -> bool: return True # The class inherits from a class that can generate (recursive check) -> can generate for base in cls.__bases__: + if not hasattr(base, "can_generate"): + continue if "PreTrainedModel" not in str(base) and base.can_generate(): return True # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this