23 changes: 16 additions & 7 deletions optimum/exporters/onnx/__main__.py
@@ -36,8 +36,8 @@ def main():
     )
     parser.add_argument(
         "--task",
-        default="default",
-        help="The type of tasks to export the model with.",
+        default="auto",
+        help="The type of task to export the model with.",
     )
     parser.add_argument("--opset", type=int, default=None, help="ONNX opset version to export the model with.")
     parser.add_argument(
@@ -73,20 +73,25 @@ def main():
     if not args.output.parent.exists():
         args.output.parent.mkdir(parents=True)

+    # Infer the task
+    task = args.task
+    if task == "auto":
+        task = TasksManager.infer_task_from_model(args.model)
+
     # Allocate the model
-    model = TasksManager.get_model_from_task(args.task, args.model, framework=args.framework, cache_dir=args.cache_dir)
+    model = TasksManager.get_model_from_task(task, args.model, framework=args.framework, cache_dir=args.cache_dir)
     model_type = model.config.model_type.replace("_", "-")
     model_name = getattr(model, "name", None)

     onnx_config_constructor = TasksManager.get_exporter_config_constructor(
-        model_type, "onnx", task=args.task, model_name=model_name
+        model_type, "onnx", task=task, model_name=model_name
     )
     onnx_config = onnx_config_constructor(model.config)

     needs_pad_token_id = (
         isinstance(onnx_config, OnnxConfigWithPast)
         and getattr(model.config, "pad_token_id", None) is None
-        and task in ["sequence_classification"]
+        and task in ["sequence_classification"]
     )
     if needs_pad_token_id:
         if args.pad_token_id is not None:
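
Taken together, the hunk above resolves the task once and then threads the result through every later call. A minimal sketch of the same flow used directly, with an illustrative Hub model id (not part of this PR):

from optimum.exporters.tasks import TasksManager

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative
# "auto" defers task selection to the Hub metadata of the repo
task = TasksManager.infer_task_from_model(model_id)
# the resolved task (e.g. a sequence classification task) is then reused
model = TasksManager.get_model_from_task(task, model_id)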
@@ -120,9 +125,13 @@ def main():
     if args.atol is None:
         args.atol = onnx_config.ATOL_FOR_VALIDATION
     if isinstance(args.atol, dict):
-        args.atol = args.atol[args.task.replace("-with-past", "")]
+        args.atol = args.atol[task.replace("-with-past", "")]

-    validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
+    try:
+        validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
+    except ValueError:
+        logger.error(f"An error occurred, but the model was saved at: {args.output.as_posix()}")
+        return
     logger.info(f"All good, model saved at: {args.output.as_posix()}")
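
Note the atol lookup: ATOL_FOR_VALIDATION may be a per-task dict, so the "-with-past" suffix has to be stripped before indexing — the same normalization format_task() performs in tasks.py below. A small sketch with made-up tolerance values:

# Illustrative only; real ATOL_FOR_VALIDATION values live on each OnnxConfig.
atol_for_validation = {"default": 1e-5, "causal-lm": 1e-4}
task = "causal-lm-with-past"
atol = atol_for_validation[task.replace("-with-past", "")]
print(atol)  # 0.0001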


11 changes: 6 additions & 5 deletions optimum/exporters/onnx/convert.py
@@ -324,7 +324,8 @@ def validate_model_outputs(
f"Difference: {onnx_outputs_set.difference(ref_outputs_set)}"
)
else:
logger.info(f"\t-[✓] ONNX model output names match reference model ({onnx_outputs_set})")
onnx_output_names = ", ".join(onnx_outputs_set)
logger.info(f"\t-[✓] ONNX model output names match reference model ({onnx_output_names})")

# Check the shape and values match
shape_failures = []
@@ -352,9 +353,9 @@
         logger.info(f"\t\t-[✓] all values close (atol: {atol})")

     if shape_failures:
-        msg = "\n\t".join(f"- {t[0]}: got {t[1]} (reference) and {t[2]} (ONNX)" for t in shape_failures)
-        raise ValueError("Output shapes do not match between reference model and ONNX exported model:\n" f"{msg}")
+        msg = "\n".join(f"- {t[0]}: got {t[1]} (reference) and {t[2]} (ONNX)" for t in shape_failures)
+        raise ValueError(f"Output shapes do not match between reference model and ONNX exported model:\n{msg}")

     if value_failures:
-        msg = "\n\t".join(f"- {t[0]}: max diff = {t[1]}" for t in value_failures)
-        raise ValueError("Output values do not match between reference model and ONNX exported model:\n" f"{msg}")
+        msg = "\n".join(f"- {t[0]}: max diff = {t[1]}" for t in value_failures)
+        raise ValueError(f"Output values do not match between reference model and ONNX exported model:\n{msg}")
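
Dropping the embedded "\t" from the join keeps each failure on its own flush-left bullet line in the exception text. A quick, runnable illustration with hypothetical failure tuples (same (name, reference, onnx) shape as above):

# Hypothetical data, shaped like the shape_failures tuples above
shape_failures = [("logits", (1, 2), (1, 3))]
msg = "\n".join(f"- {t[0]}: got {t[1]} (reference) and {t[2]} (ONNX)" for t in shape_failures)
print(f"Output shapes do not match between reference model and ONNX exported model:\n{msg}")
# Output shapes do not match between reference model and ONNX exported model:
# - logits: got (1, 2) (reference) and (1, 3) (ONNX)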
52 changes: 50 additions & 2 deletions optimum/exporters/tasks.py
@@ -22,6 +22,8 @@
 from transformers import PretrainedConfig, is_tf_available, is_torch_available
 from transformers.utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging

+import huggingface_hub
+

 if TYPE_CHECKING:
     from transformers import PreTrainedModel, TFPreTrainedModel
@@ -568,7 +570,7 @@ def get_supported_tasks_for_model_type(
         return TasksManager._SUPPORTED_MODEL_TYPE[model_type][exporter]

     @staticmethod
-    def task_to_task(task: str) -> str:
+    def format_task(task: str) -> str:
         return task.replace("-with-past", "")

     @staticmethod
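
The rename from task_to_task to format_task makes the helper's intent clear: it strips the "-with-past" decoration so a task can be mapped back to its base AutoModel entry. For example (task names assumed from the surrounding code):

TasksManager.format_task("causal-lm-with-past")      # -> "causal-lm"
TasksManager.format_task("sequence-classification")  # unchanged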
@@ -598,7 +600,7 @@ def get_model_class_for_task(task: str, framework: str = "pt") -> Type:
         Returns:
             The AutoModel class corresponding to the task.
         """
-        task = TasksManager.task_to_task(task)
+        task = TasksManager.format_task(task)
         TasksManager._validate_framework_choice(framework)
         if framework == "pt":
             task_to_automodel = TasksManager._TASKS_TO_AUTOMODELS
@@ -659,6 +661,50 @@ def determine_framework(model: str, framework: str = None) -> str:

         return framework

+    @staticmethod
+    def infer_task_from_model(model_name_or_path: str) -> str:
+        """
+        Infers the task from the model repo.
+
+        Args:
+            model_name_or_path (`str`):
+                The model repo or local path (not supported for now).
+
+        Returns:
+            `str`: The task name automatically detected from the model repo.
+        """
+
+        tasks_to_automodels = {}
+        class_name_prefix = ""
+        if is_torch_available():
+            tasks_to_automodels = TasksManager._TASKS_TO_AUTOMODELS
+        else:
+            tasks_to_automodels = TasksManager._TASKS_TO_TF_AUTOMODELS
+            class_name_prefix = "TF"
+
+        inferred_task_name = None
+        is_local = os.path.isdir(model_name_or_path)
+
+        if is_local:
+            # TODO: maybe implement that.
+            raise RuntimeError("Cannot infer the task from a local directory yet, please specify the task manually.")
+        else:
+            model_info = huggingface_hub.model_info(model_name_or_path)
+            transformers_info = model_info.transformersInfo
+            if transformers_info is None or transformers_info.get("auto_model") is None:
+                raise RuntimeError(f"Could not infer the task from the model repo {model_name_or_path}")
+            auto_model_class_name = transformers_info["auto_model"]
+            if not auto_model_class_name.startswith("TF"):
+                auto_model_class_name = f"{class_name_prefix}{auto_model_class_name}"
+            for task_name, class_ in tasks_to_automodels.items():
+                if class_.__name__ == auto_model_class_name:
+                    inferred_task_name = task_name
+                    break
+            if inferred_task_name is None:
+                raise KeyError(f"Could not find the proper task name for {auto_model_class_name}.")
+        logger.info(f"Automatic task detection to {inferred_task_name}.")
+        return inferred_task_name
+
     @staticmethod
     def get_model_from_task(
         task: str, model: str, framework: str = None, cache_dir: str = None
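
The detection relies on the transformersInfo metadata the Hub exposes for transformers checkpoints. A minimal sketch of that lookup on its own, with an illustrative model id (the metadata can be missing, which is exactly the RuntimeError path above):

import huggingface_hub

info = huggingface_hub.model_info("distilbert-base-uncased-finetuned-sst-2-english")
# transformersInfo carries the auto class a pipeline would use, e.g.
# {"auto_model": "AutoModelForSequenceClassification", "pipeline_tag": "text-classification", ...}
print(info.transformersInfo["auto_model"])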
@@ -682,6 +728,8 @@ def get_model_from_task(

"""
framework = TasksManager.determine_framework(model, framework)
if task == "auto":
task = TasksManager.infer_task_from_model(model)
model_class = TasksManager.get_model_class_for_task(task, framework)
try:
model = model_class.from_pretrained(model, cache_dir=cache_dir)