Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1645,6 +1645,12 @@ def can_generate(cls) -> bool:
# Model class overwrites `generate` (e.g. time series models) -> can generate
if str(cls.__name__) in str(cls.generate):
return True
# The class inherits from a class that can generate (recursive check) -> can generate
for base in cls.__bases__:
if not hasattr(base, "can_generate"):
Comment on lines +1649 to +1650
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this only grandfather check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's recursive! (it calls can_generate of the parent, which in turn may call the granparent's, and so on :D )

continue
if "PreTrainedModel" not in str(base) and base.can_generate():
return True
# BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
# was how we detected whether a model could generate.
if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
Expand Down
32 changes: 27 additions & 5 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,29 +1718,51 @@ def test_isin_mps_friendly(self):

def test_can_generate(self):
"""Tests the behavior of `PreTrainedModel.can_generate` method."""
logger = logging.get_logger("transformers.modeling_utils")
logger.warning_once.cache_clear()

# 1 - By default, a model CAN'T generate
self.assertFalse(BertModel.can_generate())
can_generate = BertModel.can_generate()
self.assertFalse(can_generate)

# 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly
class DummyBertWithMixin(BertModel, GenerationMixin):
pass

self.assertTrue(DummyBertWithMixin.can_generate())
with CaptureLogger(logger) as cl:
can_generate = DummyBertWithMixin.can_generate()
self.assertTrue("" == cl.out)
self.assertTrue(can_generate)

# 3 - Alternatively, a model can implement a `generate` method
class DummyBertWithGenerate(BertModel):
def generate(self):
pass

self.assertTrue(DummyBertWithGenerate.can_generate())
with CaptureLogger(logger) as cl:
can_generate = DummyBertWithGenerate.can_generate()
self.assertTrue("" == cl.out)
self.assertTrue(can_generate)

# 4 - Finally, it can inherit from a model that can generate
class DummyBertWithParent(DummyBertWithMixin):
pass

with CaptureLogger(logger) as cl:
can_generate = DummyBertWithParent.can_generate()
self.assertTrue("" == cl.out)
self.assertTrue(can_generate)

# 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
# 5 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
# `GenerationMixin`)
class DummyBertWithPrepareInputs(BertModel):
def prepare_inputs_for_generation(self):
pass

self.assertTrue(DummyBertWithPrepareInputs.can_generate())
with CaptureLogger(logger) as cl:
can_generate = DummyBertWithPrepareInputs.can_generate()
self.assertTrue("it doesn't directly inherit from `GenerationMixin`" in cl.out)
self.assertTrue(can_generate)

def test_save_and_load_config_with_custom_generation(self):
"""
Expand Down