Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
7731177
[WIP] ONNX export in optimum => Bert works
michaelbenayoun Sep 23, 2022
dd0ab63
[WIP] ONNX export in optimum => GPT-2 / GPT Neo
michaelbenayoun Sep 23, 2022
15e9b8c
[WIP] ONNX export in optimum => T5
michaelbenayoun Sep 26, 2022
0b896bf
[WIP] ONNX export in optimum
michaelbenayoun Sep 27, 2022
2023c69
[WIP] Add support for TensorFlow export
michaelbenayoun Sep 27, 2022
180d7dd
[WIP] onnx export
michaelbenayoun Sep 27, 2022
e63a030
Update file headers
michaelbenayoun Sep 27, 2022
91dca2e
[WIP] cache features.py
michaelbenayoun Oct 10, 2022
a967899
Refactor features.py
michaelbenayoun Oct 10, 2022
f0b2505
Bert like models ready
michaelbenayoun Oct 10, 2022
1c88966
Apply suggestions
michaelbenayoun Oct 10, 2022
e90cb18
Clean up
michaelbenayoun Oct 11, 2022
b0bdd46
Backup input_generators.py version
michaelbenayoun Oct 12, 2022
2cab267
Most seq2seq models work
michaelbenayoun Oct 24, 2022
c6f043f
Add vision models
michaelbenayoun Oct 24, 2022
25b8c5c
Support for GPT-J
michaelbenayoun Oct 24, 2022
8bf6c80
CLIP is working
michaelbenayoun Oct 24, 2022
1a85ef9
Add support for the remaining models
michaelbenayoun Oct 26, 2022
918947f
Make style
michaelbenayoun Oct 26, 2022
9b8c3ee
Fix workflow and remove commented code
michaelbenayoun Oct 26, 2022
3b81d6c
nit
michaelbenayoun Oct 26, 2022
fd5654d
Renamed dynamic axes
michaelbenayoun Oct 26, 2022
c01f5ec
Fix BART
michaelbenayoun Oct 27, 2022
fd0a302
Drop support for PyTorch 1.10 and lower
michaelbenayoun Oct 27, 2022
25c253f
Add re-order inputs method
michaelbenayoun Oct 27, 2022
fbcbfb4
Small fixes
michaelbenayoun Oct 27, 2022
0bb9af2
Changed FeaturesManager to TasksManager
michaelbenayoun Oct 27, 2022
0e8f19c
Apply suggestions
michaelbenayoun Oct 27, 2022
e5e0152
Fix Bloom export
michaelbenayoun Oct 28, 2022
55f683a
Fix mt5 export
michaelbenayoun Oct 28, 2022
89705c8
Change ATOL_FOR_VALIDATION for ResNet
michaelbenayoun Oct 28, 2022
40ad484
Update github actions for exporters to use pytest
michaelbenayoun Oct 28, 2022
96f2e0a
Styling
michaelbenayoun Oct 28, 2022
057958e
Remove old comment
michaelbenayoun Oct 28, 2022
97e9f08
Fix github action for exporters
michaelbenayoun Oct 28, 2022
217a7f0
Fix tests that are not passing
michaelbenayoun Oct 28, 2022
e1862d1
Styling
michaelbenayoun Oct 28, 2022
3ffb7e5
Change Blenderbot model to use
michaelbenayoun Nov 2, 2022
868e1a2
Change Resnet ATOL value
michaelbenayoun Nov 2, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions .github/workflows/test_exporters.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# CI workflow for the optimum.exporters test suite.
# Runs the (slow) exporters tests with pytest on pushes and PRs targeting main.
name: Exporters / Python - Test

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

# Cancel any in-flight run for the same branch/PR so only the latest commit is tested.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  build:
    strategy:
      fail-fast: false  # let the other matrix entries finish even if one fails
      matrix:
        python-version: [3.8, 3.9]
        os: [ubuntu-20.04]

    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v2
      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          # exporters-tf extra pulls in TensorFlow so the TF export path is tested too
          pip install .[tests,exporters-tf]
      - name: Test with unittest
        working-directory: tests
        run: |
          # RUN_SLOW=1 enables the slow-marked export tests; --durations=0 lists all test timings
          RUN_SLOW=1 pytest exporters --durations=0
2 changes: 0 additions & 2 deletions .github/workflows/test_fx.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: FX / Python - Test

on:
Expand Down
4 changes: 1 addition & 3 deletions .github/workflows/test_onnx.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: ONNX / Python - Test

on:
Expand Down Expand Up @@ -29,7 +27,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install .[onnxruntime] tensorflow tf2onnx
pip install .[tests,onnxruntime] tensorflow tf2onnx
- name: Test with unittest
working-directory: tests
run: |
Expand Down
15 changes: 15 additions & 0 deletions optimum/exporters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import onnx # noqa
21 changes: 21 additions & 0 deletions optimum/exporters/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base exporters config."""

from abc import ABC


class ExportConfig(ABC):
    """Abstract base class that every exporter-specific configuration derives from."""
18 changes: 18 additions & 0 deletions optimum/exporters/onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast # noqa
from .config import DecoderOnnxConfig, EncoderOnnxConfig, Seq2SeqOnnxConfig # noqa
from .convert import export, validate_model_outputs # noqa
130 changes: 130 additions & 0 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Entry point to the optimum.exporters.onnx command line."""

from argparse import ArgumentParser
from pathlib import Path

from transformers import AutoTokenizer

from ...utils import logging
from ..tasks import TasksManager
from .base import OnnxConfigWithPast
from .convert import export, validate_model_outputs


logger = logging.get_logger() # pylint: disable=invalid-name
logger.setLevel(logging.INFO)


def main():
    """CLI entry point: export a Transformers model to ONNX, then validate the exported model.

    Parses command-line arguments, loads the model for the requested task,
    builds the matching ONNX config, exports with ``export`` and checks the
    outputs with ``validate_model_outputs``.

    Raises:
        ValueError: if the pad token id is needed but cannot be inferred, or if
            the requested opset is lower than the model's minimum supported opset.
    """
    parser = ArgumentParser("Hugging Face Optimum ONNX exporter")
    parser.add_argument(
        "-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from."
    )
    parser.add_argument(
        "--task",
        default="default",
        help="The type of tasks to export the model with.",
    )
    parser.add_argument("--opset", type=int, default=None, help="ONNX opset version to export the model with.")
    parser.add_argument(
        "--atol", type=float, default=None, help="Absolute difference tolerance when validating the model."
    )
    parser.add_argument(
        "--framework",
        type=str,
        choices=["pt", "tf"],
        default=None,
        help=(
            "The framework to use for the ONNX export."
            " If not provided, will attempt to use the local checkpoint's original framework"
            " or what is available in the environment."
        ),
    )
    parser.add_argument(
        "--pad_token_id",
        type=int,
        default=None,
        help=(
            "This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess"
            " it."
        ),
    )
    parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
    parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")

    # Retrieve CLI arguments
    args = parser.parse_args()
    # If a directory was given, store the model under a default file name inside it.
    args.output = args.output if args.output.is_file() else args.output.joinpath("model.onnx")

    # exist_ok=True makes directory creation race-free (no exists()-then-mkdir window).
    args.output.parent.mkdir(parents=True, exist_ok=True)

    # Allocate the model
    model = TasksManager.get_model_from_task(args.task, args.model, framework=args.framework, cache_dir=args.cache_dir)
    model_type = model.config.model_type.replace("_", "-")
    model_name = getattr(model, "name", None)

    onnx_config_constructor = TasksManager.get_exporter_config_constructor(
        model_type, "onnx", task=args.task, model_name=model_name
    )
    onnx_config = onnx_config_constructor(model.config)

    # Some tasks on past-key-values configs require a pad token id that the config may not define.
    needs_pad_token_id = (
        isinstance(onnx_config, OnnxConfigWithPast)
        and getattr(model.config, "pad_token_id", None) is None
        and args.task in ["sequence_classification"]
    )
    if needs_pad_token_id:
        if args.pad_token_id is not None:
            model.config.pad_token_id = args.pad_token_id
        else:
            try:
                tok = AutoTokenizer.from_pretrained(args.model)
                model.config.pad_token_id = tok.pad_token_id
            except Exception as error:
                # Chain the original error so the root cause stays visible in the traceback.
                raise ValueError(
                    "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
                ) from error
            if model.config.pad_token_id is None:
                # The tokenizer loaded fine but does not define a pad token either:
                # fail loudly instead of silently leaving pad_token_id unset.
                raise ValueError(
                    "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
                )

    # Ensure the requested opset is sufficient
    if args.opset is None:
        args.opset = onnx_config.DEFAULT_ONNX_OPSET

    if args.opset < onnx_config.DEFAULT_ONNX_OPSET:
        raise ValueError(
            f"Opset {args.opset} is not sufficient to export {model.config.model_type}. "
            f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required."
        )

    onnx_inputs, onnx_outputs = export(
        model,
        onnx_config,
        args.opset,
        args.output,
    )

    # ATOL_FOR_VALIDATION may be a per-task dict; look it up with the "-with-past" suffix stripped.
    if args.atol is None:
        args.atol = onnx_config.ATOL_FOR_VALIDATION
        if isinstance(args.atol, dict):
            args.atol = args.atol[args.task.replace("-with-past", "")]

    validate_model_outputs(onnx_config, model, args.output, onnx_outputs, args.atol)
    logger.info(f"All good, model saved at: {args.output.as_posix()}")


if __name__ == "__main__":
    main()
Loading