Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/transformers/commands/pt_to_tf.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -68,6 +68,7 @@ def convert_command_factory(args: Namespace):
args.no_pr,
args.push,
args.extra_commit_description,
args.override_model_class,
)


Expand Down Expand Up @@ -126,6 +127,13 @@ def register_subcommand(parser: ArgumentParser):
default="",
help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
)
train_parser.add_argument(
"--override-model-class",
type=str,
default=None,
help="If you think you know better than the auto-detector, you can specify the model class here. "
"Can be either an AutoModel class or a specific model class like BertForSequenceClassification.",
)
train_parser.set_defaults(func=convert_command_factory)

@staticmethod
Expand Down Expand Up @@ -175,6 +183,7 @@ def __init__(
no_pr: bool,
push: bool,
extra_commit_description: str,
override_model_class: str,
*args,
):
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
Expand All @@ -185,6 +194,7 @@ def __init__(
self._no_pr = no_pr
self._push = push
self._extra_commit_description = extra_commit_description
self._override_model_class = override_model_class

def get_inputs(self, pt_model, config):
"""
Expand Down Expand Up @@ -269,7 +279,20 @@ def run(self):
# Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
config = AutoConfig.from_pretrained(self._local_dir)
architectures = config.architectures
if architectures is None: # No architecture defined -- use auto classes
if self._override_model_class is not None:
if self._override_model_class.startswith("TF"):
architectures = [self._override_model_class[2:]]
else:
architectures = [self._override_model_class]
try:
pt_class = getattr(import_module("transformers"), architectures[0])
except AttributeError:
raise ValueError(f"Model class {self._override_model_class} not found in transformers.")
try:
tf_class = getattr(import_module("transformers"), "TF" + architectures[0])
except AttributeError:
raise ValueError(f"TF model class TF{self._override_model_class} not found in transformers.")
elif architectures is None: # No architecture defined -- use auto classes
pt_class = getattr(import_module("transformers"), "AutoModel")
tf_class = getattr(import_module("transformers"), "TFAutoModel")
self._logger.warning("No detected architecture, using AutoModel/TFAutoModel")
Expand All @@ -287,7 +310,6 @@ def run(self):
pt_model = pt_class.from_pretrained(self._local_dir)
pt_model.eval()

tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
pt_input, tf_input = self.get_inputs(pt_model, config)

with torch.no_grad():
Expand Down