diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index 89a1fc8710..3948802caf 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -36,8 +36,8 @@ def main():
     )
     parser.add_argument(
         "--task",
-        default="default",
-        help="The type of tasks to export the model with.",
+        default="auto",
+        help="The type of task to export the model with.",
     )
     parser.add_argument("--opset", type=int, default=None, help="ONNX opset version to export the model with.")
     parser.add_argument(
@@ -73,20 +73,25 @@ def main():
     if not args.output.parent.exists():
         args.output.parent.mkdir(parents=True)
 
+    # Infer the task
+    task = args.task
+    if task == "auto":
+        task = TasksManager.infer_task_from_model(args.model)
+
     # Allocate the model
-    model = TasksManager.get_model_from_task(args.task, args.model, framework=args.framework, cache_dir=args.cache_dir)
+    model = TasksManager.get_model_from_task(task, args.model, framework=args.framework, cache_dir=args.cache_dir)
 
     model_type = model.config.model_type.replace("_", "-")
     model_name = getattr(model, "name", None)
 
     onnx_config_constructor = TasksManager.get_exporter_config_constructor(
-        model_type, "onnx", task=args.task, model_name=model_name
+        model_type, "onnx", task=task, model_name=model_name
     )
     onnx_config = onnx_config_constructor(model.config)
 
     needs_pad_token_id = (
         isinstance(onnx_config, OnnxConfigWithPast)
         and getattr(model.config, "pad_token_id", None) is None
-        and args.task in ["sequence_classification"]
+        and task in ["sequence_classification"]
     )
     if needs_pad_token_id:
         if args.pad_token_id is not None:
@@ -120,9 +125,13 @@ def main():
     if args.atol is None:
         args.atol = onnx_config.ATOL_FOR_VALIDATION
         if isinstance(args.atol, dict):
-            args.atol = args.atol[args.task.replace("-with-past", "")]
+            args.atol = args.atol[task.replace("-with-past", "")]
 
-    validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
+    try:
+        validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
+    except ValueError:
+        logger.error(f"An error occurred during validation, but the model was saved at: {args.output.as_posix()}")
+        return
     logger.info(f"All good, model saved at: {args.output.as_posix()}")
diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py
index 840a907cd6..6c92576c2d 100644
--- a/optimum/exporters/onnx/convert.py
+++ b/optimum/exporters/onnx/convert.py
@@ -324,7 +324,8 @@ def validate_model_outputs(
             f"Difference: {onnx_outputs_set.difference(ref_outputs_set)}"
         )
     else:
-        logger.info(f"\t-[✓] ONNX model output names match reference model ({onnx_outputs_set})")
+        onnx_output_names = ", ".join(onnx_outputs_set)
+        logger.info(f"\t-[✓] ONNX model output names match reference model ({onnx_output_names})")
 
     # Check the shape and values match
     shape_failures = []
@@ -352,9 +353,9 @@ def validate_model_outputs(
         logger.info(f"\t\t-[✓] all values close (atol: {atol})")
 
     if shape_failures:
-        msg = "\n\t".join(f"- {t[0]}: got {t[1]} (reference) and {t[2]} (ONNX)" for t in shape_failures)
-        raise ValueError("Output shapes do not match between reference model and ONNX exported model:\n" f"{msg}")
+        msg = "\n".join(f"- {t[0]}: got {t[1]} (reference) and {t[2]} (ONNX)" for t in shape_failures)
+        raise ValueError(f"Output shapes do not match between reference model and ONNX exported model:\n{msg}")
 
     if value_failures:
-        msg = "\n\t".join(f"- {t[0]}: max diff = {t[1]}" for t in value_failures)
-        raise ValueError("Output values do not match between reference model and ONNX exported model:\n" f"{msg}")
+        msg = "\n".join(f"- {t[0]}: max diff = {t[1]}" for t in value_failures)
+        raise ValueError(f"Output values do not match between reference model and ONNX exported model:\n{msg}")
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 12e2969812..e79ce64999 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -22,6 +22,8 @@
 from transformers import PretrainedConfig, is_tf_available, is_torch_available
 from transformers.utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging
 
+import huggingface_hub
+
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, TFPreTrainedModel
@@ -568,7 +570,7 @@ def get_supported_tasks_for_model_type(
         return TasksManager._SUPPORTED_MODEL_TYPE[model_type][exporter]
 
     @staticmethod
-    def task_to_task(task: str) -> str:
+    def format_task(task: str) -> str:
         return task.replace("-with-past", "")
 
     @staticmethod
@@ -598,7 +600,7 @@ def get_model_class_for_task(task: str, framework: str = "pt") -> Type:
         Returns:
             The AutoModel class corresponding to the task.
         """
-        task = TasksManager.task_to_task(task)
+        task = TasksManager.format_task(task)
         TasksManager._validate_framework_choice(framework)
         if framework == "pt":
             task_to_automodel = TasksManager._TASKS_TO_AUTOMODELS
@@ -659,6 +661,50 @@ def determine_framework(model: str, framework: str = None) -> str:
 
         return framework
 
+    @staticmethod
+    def infer_task_from_model(model_name_or_path: str) -> str:
+        """
+        Infers the task from the model repo.
+
+        Args:
+            model_name_or_path (`str`):
+                The name of the model repo on the Hugging Face Hub (local paths are not supported yet).
+
+        Returns:
+            `str`: The task name automatically detected from the model repo.
+        """
+
+        tasks_to_automodels = {}
+        class_name_prefix = ""
+        if is_torch_available():
+            tasks_to_automodels = TasksManager._TASKS_TO_AUTOMODELS
+        else:
+            tasks_to_automodels = TasksManager._TASKS_TO_TF_AUTOMODELS
+            class_name_prefix = "TF"
+
+        inferred_task_name = None
+        is_local = os.path.isdir(model_name_or_path)
+
+        if is_local:
+            # TODO: maybe implement that.
+            raise RuntimeError("Cannot infer the task from a local directory yet, please specify the task manually.")
+        else:
+            model_info = huggingface_hub.model_info(model_name_or_path)
+            transformers_info = model_info.transformersInfo
+            if transformers_info is None or transformers_info.get("auto_model") is None:
+                raise RuntimeError(f"Could not infer the task from the model repo {model_name_or_path}")
+            auto_model_class_name = transformers_info["auto_model"]
+            if not auto_model_class_name.startswith("TF"):
+                auto_model_class_name = f"{class_name_prefix}{auto_model_class_name}"
+            for task_name, class_ in tasks_to_automodels.items():
+                if class_.__name__ == auto_model_class_name:
+                    inferred_task_name = task_name
+                    break
+        if inferred_task_name is None:
+            raise KeyError(f"Could not find the proper task name for {auto_model_class_name}.")
+        logger.info(f"Automatically detected task: {inferred_task_name}.")
+        return inferred_task_name
+
     @staticmethod
     def get_model_from_task(
         task: str, model: str, framework: str = None, cache_dir: str = None
@@ -682,6 +728,8 @@ def get_model_from_task(
         """
         framework = TasksManager.determine_framework(model, framework)
+        if task == "auto":
+            task = TasksManager.infer_task_from_model(model)
         model_class = TasksManager.get_model_class_for_task(task, framework)
         try:
             model = model_class.from_pretrained(model, cache_dir=cache_dir)
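
A minimal sketch of how the new automatic task detection is meant to be exercised. The checkpoint name and the printed task below are only illustrative: the actual result depends on the repo's `transformersInfo` metadata on the Hub and on the `_TASKS_TO_AUTOMODELS` mapping.

```python
from optimum.exporters.tasks import TasksManager

# Only Hub repos are supported for now; a local directory raises a
# RuntimeError asking for the task to be passed explicitly.
task = TasksManager.infer_task_from_model("distilbert-base-uncased-finetuned-sst-2-english")
print(task)  # e.g. "sequence-classification"
```

The CLI picks this up through the new `--task auto` default, so an export without an explicit task should look something like `python -m optimum.exporters.onnx --model distilbert-base-uncased-finetuned-sst-2-english sst2_onnx/`.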