From ad923d872f73e39611309f8fa12ca6bc9e6ffa15 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 3 Nov 2022 13:44:33 +0100
Subject: [PATCH 1/4] Automatic task detection using the HuggingFace Hub

---
 optimum/exporters/onnx/__main__.py | 15 ++++++++++-----
 optimum/exporters/tasks.py         | 38 ++++++++++++++++++++++++++++++++++++--
 2 files changed, 46 insertions(+), 7 deletions(-)

diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index 89a1fc8710..b26794c629 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -36,7 +36,7 @@ def main():
     )
     parser.add_argument(
         "--task",
-        default="default",
+        default="auto",
         help="The type of tasks to export the model with.",
     )
     parser.add_argument("--opset", type=int, default=None, help="ONNX opset version to export the model with.")
@@ -73,20 +73,25 @@ def main():
     if not args.output.parent.exists():
         args.output.parent.mkdir(parents=True)
 
+    # Infer the task
+    task = args.task
+    if task == "auto":
+        task = TasksManager.infer_task_from_model_info(args.model)
+
     # Allocate the model
-    model = TasksManager.get_model_from_task(args.task, args.model, framework=args.framework, cache_dir=args.cache_dir)
+    model = TasksManager.get_model_from_task(task, args.model, framework=args.framework, cache_dir=args.cache_dir)
     model_type = model.config.model_type.replace("_", "-")
     model_name = getattr(model, "name", None)
 
     onnx_config_constructor = TasksManager.get_exporter_config_constructor(
-        model_type, "onnx", task=args.task, model_name=model_name
+        model_type, "onnx", task=task, model_name=model_name
     )
     onnx_config = onnx_config_constructor(model.config)
 
     needs_pad_token_id = (
         isinstance(onnx_config, OnnxConfigWithPast)
         and getattr(model.config, "pad_token_id", None) is None
-        and args.task in ["sequence_classification"]
+        and task in ["sequence_classification"]
     )
     if needs_pad_token_id:
         if args.pad_token_id is not None:
@@ -120,7 +125,7 @@ def main():
         if args.atol is None:
             args.atol = onnx_config.ATOL_FOR_VALIDATION
         if isinstance(args.atol, dict):
-            args.atol = args.atol[args.task.replace("-with-past", "")]
+            args.atol = args.atol[task.replace("-with-past", "")]
 
     validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
     logger.info(f"All good, model saved at: {args.output.as_posix()}")
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 12e2969812..1586638b2f 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -22,6 +22,8 @@
 from transformers import PretrainedConfig, is_tf_available, is_torch_available
 from transformers.utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging
 
+import huggingface_hub
+
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, TFPreTrainedModel
 
@@ -568,7 +570,7 @@ def get_supported_tasks_for_model_type(
         return TasksManager._SUPPORTED_MODEL_TYPE[model_type][exporter]
 
     @staticmethod
-    def task_to_task(task: str) -> str:
+    def format_task(task: str) -> str:
         return task.replace("-with-past", "")
 
     @staticmethod
@@ -598,7 +600,7 @@ def get_model_class_for_task(task: str, framework: str = "pt") -> Type:
         Returns:
             The AutoModel class corresponding to the task.
         """
-        task = TasksManager.task_to_task(task)
+        task = TasksManager.format_task(task)
         TasksManager._validate_framework_choice(framework)
         if framework == "pt":
             task_to_automodel = TasksManager._TASKS_TO_AUTOMODELS
@@ -659,6 +661,36 @@ def determine_framework(model: str, framework: str = None) -> str:
 
         return framework
 
+    @staticmethod
+    def infer_task_from_model_info(model_name_or_path):
+        tasks_to_automodels = {}
+        class_name_prefix = ""
+        if is_torch_available():
+            tasks_to_automodels = TasksManager._TASKS_TO_AUTOMODELS
+        else:
+            tasks_to_automodels = TasksManager._TASKS_TO_TF_AUTOMODELS
+            class_name_prefix = "TF"
+
+        inferred_task_name = None
+        is_local = os.path.isdir(model_name_or_path)
+
+        if is_local:
+            # TODO: implement this.
+            raise NotImplementedError("Cannot infer the task from a local directory yet.")
+        else:
+            model_info = huggingface_hub.model_info(model_name_or_path)
+            transformers_info = model_info.transformersInfo
+            if transformers_info is None or transformers_info.get("auto_model") is None:
+                raise RuntimeError(f"Could not infer the task from the model repo {model_name_or_path}")
+            auto_model_class_name = f"{class_name_prefix}{transformers_info['auto_model']}"
+            for task_name, class_ in tasks_to_automodels.items():
+                if class_.__name__ == auto_model_class_name:
+                    inferred_task_name = task_name
+                    break
+        if inferred_task_name is None:
+            raise KeyError(f"Could not find the proper task name for {auto_model_class_name}.")
+        return inferred_task_name
+
     @staticmethod
     def get_model_from_task(
         task: str, model: str, framework: str = None, cache_dir: str = None
@@ -682,6 +714,8 @@ def get_model_from_task(
 
         """
         framework = TasksManager.determine_framework(model, framework)
+        if task == "auto":
+            task = TasksManager.infer_task_from_model_info(model)
         model_class = TasksManager.get_model_class_for_task(task, framework)
         try:
             model = model_class.from_pretrained(model, cache_dir=cache_dir)
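The core of patch 1 is a reverse lookup: the Hub API reports which AutoModel class a repo is tagged with, and the exporter maps that class name back to a task. A minimal standalone sketch of that lookup, with a hypothetical two-entry table standing in for TasksManager._TASKS_TO_AUTOMODELS and an illustrative repo id:

    import huggingface_hub
    from transformers import AutoModelForQuestionAnswering, AutoModelForSequenceClassification

    # Hypothetical stand-in for TasksManager._TASKS_TO_AUTOMODELS.
    TASKS_TO_AUTOMODELS = {
        "sequence_classification": AutoModelForSequenceClassification,
        "question_answering": AutoModelForQuestionAnswering,
    }

    def infer_task(model_repo: str) -> str:
        # Ask the Hub for the repo metadata, then match the reported
        # AutoModel class name against the task table.
        info = huggingface_hub.model_info(model_repo)
        auto_model_class_name = info.transformersInfo["auto_model"]
        for task_name, cls in TASKS_TO_AUTOMODELS.items():
            if cls.__name__ == auto_model_class_name:
                return task_name
        raise KeyError(f"Could not find the proper task name for {auto_model_class_name}.")

    print(infer_task("distilbert-base-uncased-finetuned-sst-2-english"))  # e.g. "sequence_classification"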
""" - task = TasksManager.task_to_task(task) + task = TasksManager.format_task(task) TasksManager._validate_framework_choice(framework) if framework == "pt": task_to_automodel = TasksManager._TASKS_TO_AUTOMODELS @@ -659,6 +661,36 @@ def determine_framework(model: str, framework: str = None) -> str: return framework + @staticmethod + def infer_task_from_model_info(model_name_or_path): + tasks_to_automodels = {} + class_name_prefix = "" + if is_torch_available(): + tasks_to_automodels = TasksManager._TASKS_TO_AUTOMODELS + else: + tasks_to_automodels = TasksManager._TASKS_TO_TF_AUTOMODELS + class_name_prefix = "TF" + + inferred_task_name = None + is_local = os.path.isdir(model_name_or_path) + + if is_local: + # TODO: implement this. + raise NotImplementedError("Cannot infer the task from a local directory yet.") + else: + model_info = huggingface_hub.model_info(model_name_or_path) + transformers_info = model_info.transformersInfo + if transformers_info is None or transformers_info.get("auto_model") is None: + raise RuntimeError(f"Could not infer the task from the model repo {model_name_or_path}") + auto_model_class_name = f"{class_name_prefix}{transformers_info['auto_model']}" + for task_name, class_ in tasks_to_automodels.items(): + if class_.__name__ == auto_model_class_name: + inferred_task_name = task_name + break + if inferred_task_name is None: + raise KeyError(f"Could not find the proper task name for {auto_model_class_name}.") + return inferred_task_name + @staticmethod def get_model_from_task( task: str, model: str, framework: str = None, cache_dir: str = None @@ -682,6 +714,8 @@ def get_model_from_task( """ framework = TasksManager.determine_framework(model, framework) + if task == "auto": + task = TasksManager.infer_task_from_model_info(model) model_class = TasksManager.get_model_class_for_task(task, framework) try: model = model_class.from_pretrained(model, cache_dir=cache_dir) From c58e6f7cf4e2ace331105d0c0c97bc6d8055390e Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 3 Nov 2022 13:45:53 +0100 Subject: [PATCH 2/4] Add logging --- optimum/exporters/tasks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 1586638b2f..ae92ae1b1b 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -689,6 +689,7 @@ def infer_task_from_model_info(model_name_or_path): break if inferred_task_name is None: raise KeyError(f"Could not find the proper task name for {auto_model_class_name}.") + logger.info(f"Automatic task detection to {inferred_task_name}") return inferred_task_name @staticmethod From 351a66b8aa85a4c43b1fa206c903e987765757bc Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 3 Nov 2022 14:01:04 +0100 Subject: [PATCH 3/4] Fixed logging issues --- optimum/exporters/onnx/__main__.py | 8 ++++++-- optimum/exporters/onnx/convert.py | 11 ++++++----- optimum/exporters/tasks.py | 6 +++--- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index b26794c629..6df20bf038 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -76,7 +76,7 @@ def main(): # Infer the task task = args.task if task == "auto": - task = TasksManager.infer_task_from_model_info(args.model) + task = TasksManager.infer_task_from_model(args.model) # Allocate the model model = TasksManager.get_model_from_task(task, args.model, framework=args.framework, cache_dir=args.cache_dir) @@ -127,7 +127,11 @@ def 
From 351a66b8aa85a4c43b1fa206c903e987765757bc Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 3 Nov 2022 14:01:04 +0100
Subject: [PATCH 3/4] Fixed logging issues

---
 optimum/exporters/onnx/__main__.py |  8 ++++++--
 optimum/exporters/onnx/convert.py  | 11 ++++++-----
 optimum/exporters/tasks.py         |  6 +++---
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index b26794c629..6df20bf038 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -76,7 +76,7 @@ def main():
     # Infer the task
     task = args.task
     if task == "auto":
-        task = TasksManager.infer_task_from_model_info(args.model)
+        task = TasksManager.infer_task_from_model(args.model)
 
     # Allocate the model
     model = TasksManager.get_model_from_task(task, args.model, framework=args.framework, cache_dir=args.cache_dir)
@@ -127,7 +127,11 @@ def main():
         if isinstance(args.atol, dict):
             args.atol = args.atol[task.replace("-with-past", "")]
 
-    validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
+    try:
+        validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
+    except ValueError:
+        logger.error(f"An error occurred, but the model was saved at: {args.output.as_posix()}")
+        return
     logger.info(f"All good, model saved at: {args.output.as_posix()}")
diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py
index 840a907cd6..6c92576c2d 100644
--- a/optimum/exporters/onnx/convert.py
+++ b/optimum/exporters/onnx/convert.py
@@ -324,7 +324,8 @@ def validate_model_outputs(
             f"Difference: {onnx_outputs_set.difference(ref_outputs_set)}"
         )
     else:
-        logger.info(f"\t-[✓] ONNX model output names match reference model ({onnx_outputs_set})")
+        onnx_output_names = ", ".join(onnx_outputs_set)
+        logger.info(f"\t-[✓] ONNX model output names match reference model ({onnx_output_names})")
 
     # Check the shape and values match
     shape_failures = []
@@ -352,9 +353,9 @@ def validate_model_outputs(
             logger.info(f"\t\t-[✓] all values close (atol: {atol})")
 
     if shape_failures:
-        msg = "\n\t".join(f"- {t[0]}: got {t[1]} (reference) and {t[2]} (ONNX)" for t in shape_failures)
-        raise ValueError("Output shapes do not match between reference model and ONNX exported model:\n" f"{msg}")
+        msg = "\n".join(f"- {t[0]}: got {t[1]} (reference) and {t[2]} (ONNX)" for t in shape_failures)
+        raise ValueError(f"Output shapes do not match between reference model and ONNX exported model:\n{msg}")
 
     if value_failures:
-        msg = "\n\t".join(f"- {t[0]}: max diff = {t[1]}" for t in value_failures)
-        raise ValueError("Output values do not match between reference model and ONNX exported model:\n" f"{msg}")
+        msg = "\n".join(f"- {t[0]}: max diff = {t[1]}" for t in value_failures)
+        raise ValueError(f"Output values do not match between reference model and ONNX exported model:\n{msg}")
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index ae92ae1b1b..62c8a26f7e 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -662,7 +662,7 @@ def determine_framework(model: str, framework: str = None) -> str:
         return framework
 
     @staticmethod
-    def infer_task_from_model_info(model_name_or_path):
+    def infer_task_from_model(model_name_or_path):
        tasks_to_automodels = {}
         class_name_prefix = ""
         if is_torch_available():
@@ -689,7 +689,7 @@ def infer_task_from_model_info(model_name_or_path):
                     break
         if inferred_task_name is None:
             raise KeyError(f"Could not find the proper task name for {auto_model_class_name}.")
-        logger.info(f"Automatic task detection to {inferred_task_name}")
+        logger.info(f"Automatic task detection to {inferred_task_name}.")
         return inferred_task_name
 
     @staticmethod
@@ -716,7 +716,7 @@ def get_model_from_task(
         """
         framework = TasksManager.determine_framework(model, framework)
         if task == "auto":
-            task = TasksManager.infer_task_from_model_info(model)
+            task = TasksManager.infer_task_from_model(model)
         model_class = TasksManager.get_model_class_for_task(task, framework)
         try:
             model = model_class.from_pretrained(model, cache_dir=cache_dir)
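Patch 3 only rewraps the validation call site and tightens the failure messages; the checks themselves collect shape and value mismatches per output. A simplified sketch of that per-output comparison, assuming NumPy arrays for both the reference and the ONNX Runtime results (the function name is illustrative):

    import numpy as np

    def check_output(name: str, ref: np.ndarray, onnx: np.ndarray, atol: float):
        # Mirrors the shape_failures / value_failures split in convert.py:
        # a shape mismatch short-circuits, otherwise values are compared to atol.
        if ref.shape != onnx.shape:
            return ("shape", name, ref.shape, onnx.shape)
        max_diff = float(np.max(np.abs(ref - onnx)))
        if max_diff > atol:
            return ("value", name, max_diff)
        return None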
From ef3644660875b7860e95b850657a653ca968c68b Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 3 Nov 2022 15:50:10 +0100
Subject: [PATCH 4/4] Apply suggestions

---
 optimum/exporters/onnx/__main__.py |  2 +-
 optimum/exporters/tasks.py         | 21 +++++++++++++++++----
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index 6df20bf038..3948802caf 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -37,7 +37,7 @@ def main():
     parser.add_argument(
         "--task",
         default="auto",
-        help="The type of tasks to export the model with.",
+        help="The type of task to export the model with.",
     )
     parser.add_argument("--opset", type=int, default=None, help="ONNX opset version to export the model with.")
     parser.add_argument(
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 62c8a26f7e..e79ce64999 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -662,7 +662,18 @@ def determine_framework(model: str, framework: str = None) -> str:
         return framework
 
     @staticmethod
-    def infer_task_from_model(model_name_or_path):
+    def infer_task_from_model(model_name_or_path: str) -> str:
+        """
+        Infers the task from the model repo.
+
+        Args:
+            model_name_or_path (`str`):
+                The model repo or local path (not supported for now).
+
+        Returns:
+            `str`: The task name automatically detected from the model repo.
+        """
+
         tasks_to_automodels = {}
         class_name_prefix = ""
         if is_torch_available():
@@ -675,14 +686,16 @@ def infer_task_from_model(model_name_or_path):
         is_local = os.path.isdir(model_name_or_path)
 
         if is_local:
-            # TODO: implement this.
-            raise NotImplementedError("Cannot infer the task from a local directory yet.")
+            # TODO: maybe implement that.
+            raise RuntimeError("Cannot infer the task from a local directory yet, please specify the task manually.")
         else:
             model_info = huggingface_hub.model_info(model_name_or_path)
             transformers_info = model_info.transformersInfo
             if transformers_info is None or transformers_info.get("auto_model") is None:
                 raise RuntimeError(f"Could not infer the task from the model repo {model_name_or_path}")
-            auto_model_class_name = f"{class_name_prefix}{transformers_info['auto_model']}"
+            auto_model_class_name = transformers_info["auto_model"]
+            if not auto_model_class_name.startswith("TF"):
+                auto_model_class_name = f"{class_name_prefix}{auto_model_class_name}"
             for task_name, class_ in tasks_to_automodels.items():
                 if class_.__name__ == auto_model_class_name:
                     inferred_task_name = task_name
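With all four patches applied, the detected task flows from the Hub metadata into both the CLI (via the new `--task auto` default) and the Python API. A hedged usage sketch; the repo id is illustrative and network access to the Hub is required:

    from optimum.exporters.tasks import TasksManager

    # Resolve the task explicitly (method added in patch 1, renamed in patch 3).
    task = TasksManager.infer_task_from_model("distilbert-base-uncased-finetuned-sst-2-english")

    # Or let get_model_from_task resolve "auto" internally before loading weights.
    model = TasksManager.get_model_from_task("auto", "distilbert-base-uncased-finetuned-sst-2-english")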