diff --git a/src/transformers/commands/pt_to_tf.py b/src/transformers/commands/pt_to_tf.py index 669f7a98003b..e025db5d1344 100644 --- a/src/transformers/commands/pt_to_tf.py +++ b/src/transformers/commands/pt_to_tf.py @@ -68,6 +68,7 @@ def convert_command_factory(args: Namespace): args.no_pr, args.push, args.extra_commit_description, + args.override_model_class, ) @@ -126,6 +127,13 @@ def register_subcommand(parser: ArgumentParser): default="", help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).", ) + train_parser.add_argument( + "--override-model-class", + type=str, + default=None, + help="If you think you know better than the auto-detector, you can specify the model class here. " + "Can be either an AutoModel class or a specific model class like BertForSequenceClassification.", + ) train_parser.set_defaults(func=convert_command_factory) @staticmethod @@ -175,6 +183,7 @@ def __init__( no_pr: bool, push: bool, extra_commit_description: str, + override_model_class: str, *args, ): self._logger = logging.get_logger("transformers-cli/pt_to_tf") @@ -185,6 +194,7 @@ def __init__( self._no_pr = no_pr self._push = push self._extra_commit_description = extra_commit_description + self._override_model_class = override_model_class def get_inputs(self, pt_model, config): """ @@ -269,7 +279,20 @@ def run(self): # Load config and get the appropriate architecture -- the latter is needed to convert the head's weights config = AutoConfig.from_pretrained(self._local_dir) architectures = config.architectures - if architectures is None: # No architecture defined -- use auto classes + if self._override_model_class is not None: + if self._override_model_class.startswith("TF"): + architectures = [self._override_model_class[2:]] + else: + architectures = [self._override_model_class] + try: + pt_class = getattr(import_module("transformers"), architectures[0]) + except AttributeError: + raise ValueError(f"Model class {self._override_model_class} not found in transformers.") + try: + tf_class = getattr(import_module("transformers"), "TF" + architectures[0]) + except AttributeError: + raise ValueError(f"TF model class TF{self._override_model_class} not found in transformers.") + elif architectures is None: # No architecture defined -- use auto classes pt_class = getattr(import_module("transformers"), "AutoModel") tf_class = getattr(import_module("transformers"), "TFAutoModel") self._logger.warning("No detected architecture, using AutoModel/TFAutoModel") @@ -287,7 +310,6 @@ def run(self): pt_model = pt_class.from_pretrained(self._local_dir) pt_model.eval() - tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True) pt_input, tf_input = self.get_inputs(pt_model, config) with torch.no_grad():