diff --git a/build/libs_win.txt b/build/libs_win.txt index e815e645..2b0baca8 100644 --- a/build/libs_win.txt +++ b/build/libs_win.txt @@ -14,3 +14,4 @@ TensorFlow.NET.dll NumSharp.Core.dll System.Drawing.Common.dll Microsoft.ML.* +onnxruntime.dll diff --git a/src/DotNetBridge/Bridge.cs b/src/DotNetBridge/Bridge.cs index 30450540..89e7e652 100644 --- a/src/DotNetBridge/Bridge.cs +++ b/src/DotNetBridge/Bridge.cs @@ -298,7 +298,7 @@ private static unsafe int GenericExec(EnvironmentBlock* penv, sbyte* psz, int cd //env.ComponentCatalog.RegisterAssembly(typeof(AutoInference).Assembly); // ML.PipelineInference env.ComponentCatalog.RegisterAssembly(typeof(DataViewReference).Assembly); env.ComponentCatalog.RegisterAssembly(typeof(ImageLoadingTransformer).Assembly); - //env.ComponentCatalog.RegisterAssembly(typeof(SaveOnnxCommand).Assembly); + env.ComponentCatalog.RegisterAssembly(typeof(OnnxExportExtensions).Assembly); //env.ComponentCatalog.RegisterAssembly(typeof(TimeSeriesProcessingEntryPoints).Assembly); //env.ComponentCatalog.RegisterAssembly(typeof(ParquetLoader).Assembly); env.ComponentCatalog.RegisterAssembly(typeof(SsaChangePointDetector).Assembly); diff --git a/src/DotNetBridge/DotNetBridge.csproj b/src/DotNetBridge/DotNetBridge.csproj index 67ba3209..18b73d32 100644 --- a/src/DotNetBridge/DotNetBridge.csproj +++ b/src/DotNetBridge/DotNetBridge.csproj @@ -38,6 +38,7 @@ + diff --git a/src/DotNetBridge/Entrypoints.cs b/src/DotNetBridge/Entrypoints.cs index 535d9d75..6f7f8d0c 100644 --- a/src/DotNetBridge/Entrypoints.cs +++ b/src/DotNetBridge/Entrypoints.cs @@ -178,5 +178,59 @@ public static ScoringTransformOutput Score(IHostEnvironment env, ScoringTransfor }; } + + public sealed class OnnxTransformInput : TransformInputBase + { + [Argument(ArgumentType.Required, HelpText = "Path to the onnx model file.", ShortName = "model", SortOrder = 0)] + public string ModelFile; + + [Argument(ArgumentType.Multiple, HelpText = "Name of the input column.", SortOrder = 1)] + 
public string[] InputColumns; + + [Argument(ArgumentType.Multiple, HelpText = "Name of the output column.", SortOrder = 2)] + public string[] OutputColumns; + + [Argument(ArgumentType.AtMostOnce, HelpText = "GPU device id to run on (e.g. 0,1,..). Null for CPU. Requires CUDA 9.1.", SortOrder = 3)] + public int? GpuDeviceId = null; + + [Argument(ArgumentType.AtMostOnce, HelpText = "If true, resumes execution on CPU upon GPU error. If false, will raise the GPU execption.", SortOrder = 4)] + public bool FallbackToCpu = false; + } + + public sealed class OnnxTransformOutput + { + [TlcModule.Output(Desc = "ONNX transformed dataset", SortOrder = 1)] + public IDataView OutputData; + + [TlcModule.Output(Desc = "Transform model", SortOrder = 2)] + public TransformModel Model; + } + + [TlcModule.EntryPoint(Name = "Models.OnnxTransformer", + Desc = "Applies an ONNX model to a dataset.", + UserName = "Onnx Transformer", + ShortName = "onnx-xf")] + public static OnnxTransformOutput ApplyOnnxModel(IHostEnvironment env, OnnxTransformInput input) + { + var host = EntryPointUtils.CheckArgsAndCreateHost(env, "OnnxTransform", input); + + var inputColumns = input.InputColumns ?? (Array.Empty<string>()); + var outputColumns = input.OutputColumns ?? 
(Array.Empty<string>()); + + var transformsCatalog = new TransformsCatalog(host); + var onnxScoringEstimator = OnnxCatalog.ApplyOnnxModel(transformsCatalog, + outputColumns, + inputColumns, + input.ModelFile, + input.GpuDeviceId, + input.FallbackToCpu); + + var view = onnxScoringEstimator.Fit(input.Data).Transform(input.Data); + return new OnnxTransformOutput() + { + Model = new TransformModelImpl(host, view, input.Data), + OutputData = view + }; + } } } diff --git a/src/DotNetBridge/ManifestUtils.cs b/src/DotNetBridge/ManifestUtils.cs index c01b8480..b566cf2f 100644 --- a/src/DotNetBridge/ManifestUtils.cs +++ b/src/DotNetBridge/ManifestUtils.cs @@ -43,6 +43,7 @@ public static class ManifestUtils typeof(ImageLoadingTransformer), typeof(SymbolicSgdLogisticRegressionBinaryTrainer), typeof(OnnxContext), + typeof(OnnxExportExtensions), typeof(SsaForecastingTransformer), typeof(VariableColumnTransform), typeof(DateTimeTransformer) diff --git a/src/Platforms/build.csproj b/src/Platforms/build.csproj index 6a0b7ab7..b4ff1889 100644 --- a/src/Platforms/build.csproj +++ b/src/Platforms/build.csproj @@ -17,6 +17,7 @@ + diff --git a/src/python/nimbusml.pyproj b/src/python/nimbusml.pyproj index c80553c0..65925b6c 100644 --- a/src/python/nimbusml.pyproj +++ b/src/python/nimbusml.pyproj @@ -322,6 +322,7 @@ + @@ -359,6 +360,8 @@ + + @@ -672,6 +675,7 @@ + @@ -747,6 +751,7 @@ + diff --git a/src/python/nimbusml/base_predictor.py b/src/python/nimbusml/base_predictor.py index e619b115..1c694312 100644 --- a/src/python/nimbusml/base_predictor.py +++ b/src/python/nimbusml/base_predictor.py @@ -178,7 +178,6 @@ def summary(self): self.model_summary_ = pipeline.summary() return self.model_summary_ - @trace def _get_implicit_transforms( self, features, @@ -354,3 +353,20 @@ def _get_graph_nodes( row_group_column_name=group_id_column) graph_nodes['learner_node'] = [learner_node] return graph_nodes, learner_features + + @trace + def export_to_onnx(self, *args, **kwargs): + """ + Export the model to 
the ONNX format. + + See :py:meth:`nimbusml.Pipeline.export_to_onnx` for accepted arguments. + """ + if not hasattr(self, 'model_') \ + or self.model_ is None \ + or not os.path.isfile(self.model_): + + raise ValueError("Model is not fitted. Train or load a model before " + "export_to_onnx().") + + pipeline = Pipeline([self], model=self.model_) + pipeline.export_to_onnx(*args, **kwargs) diff --git a/src/python/nimbusml/base_transform.py b/src/python/nimbusml/base_transform.py index b227d567..f0c4f861 100644 --- a/src/python/nimbusml/base_transform.py +++ b/src/python/nimbusml/base_transform.py @@ -124,3 +124,20 @@ def transform(self, X, as_binary_data_stream=False, **params): data = pipeline.transform( X, as_binary_data_stream=as_binary_data_stream, **params) return data + + @trace + def export_to_onnx(self, *args, **kwargs): + """ + Export the model to the ONNX format. + + See :py:meth:`nimbusml.Pipeline.export_to_onnx` for accepted arguments. + """ + if not hasattr(self, 'model_') \ + or self.model_ is None \ + or not os.path.isfile(self.model_): + + raise ValueError("Model is not fitted. Train or load a model before " + "export_to_onnx().") + + pipeline = Pipeline([self], model=self.model_) + pipeline.export_to_onnx(*args, **kwargs) diff --git a/src/python/nimbusml/internal/core/preprocessing/onnxrunner.py b/src/python/nimbusml/internal/core/preprocessing/onnxrunner.py new file mode 100644 index 00000000..34ed46ba --- /dev/null +++ b/src/python/nimbusml/internal/core/preprocessing/onnxrunner.py @@ -0,0 +1,71 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------------------------- +# - Generated by tools/entrypoint_compiler.py: do not edit by hand +""" +OnnxRunner +""" + +__all__ = ["OnnxRunner"] + + +from ...entrypoints.models_onnxtransformer import models_onnxtransformer +from ...utils.utils import trace +from ..base_pipeline_item import BasePipelineItem, DefaultSignature + + +class OnnxRunner(BasePipelineItem, DefaultSignature): + """ + **Description** + Applies an ONNX model to a dataset. + + :param model_file: Path to the onnx model file. + + :param input_columns: Name of the input column. + + :param output_columns: Name of the output column. + + :param gpu_device_id: GPU device id to run on (e.g. 0,1,..). Null for CPU. + Requires CUDA 9.1. + + :param fallback_to_cpu: If true, resumes execution on CPU upon GPU error. + If false, will raise the GPU execption. + + :param params: Additional arguments sent to compute engine. + + """ + + @trace + def __init__( + self, + model_file, + input_columns=None, + output_columns=None, + gpu_device_id=None, + fallback_to_cpu=False, + **params): + BasePipelineItem.__init__( + self, type='transform', **params) + + self.model_file = model_file + self.input_columns = input_columns + self.output_columns = output_columns + self.gpu_device_id = gpu_device_id + self.fallback_to_cpu = fallback_to_cpu + + @property + def _entrypoint(self): + return models_onnxtransformer + + @trace + def _get_node(self, **all_args): + algo_args = dict( + model_file=self.model_file, + input_columns=self.input_columns, + output_columns=self.output_columns, + gpu_device_id=self.gpu_device_id, + fallback_to_cpu=self.fallback_to_cpu) + + all_args.update(algo_args) + return self._entrypoint(**all_args) diff --git a/src/python/nimbusml/internal/entrypoints/models_onnxconverter.py b/src/python/nimbusml/internal/entrypoints/models_onnxconverter.py new file mode 100644 index 00000000..3c080eb6 --- /dev/null +++ 
b/src/python/nimbusml/internal/entrypoints/models_onnxconverter.py @@ -0,0 +1,116 @@ +# - Generated by tools/entrypoint_compiler.py: do not edit by hand +""" +Models.OnnxConverter +""" + + +from ..utils.entrypoints import EntryPoint +from ..utils.utils import try_set, unlist + + +def models_onnxconverter( + onnx, + data_file=None, + json=None, + name=None, + domain=None, + inputs_to_drop=None, + outputs_to_drop=None, + model=None, + onnx_version='Stable', + predictive_model=None, + **params): + """ + **Description** + Converts the model to ONNX format. + + :param data_file: The data file (inputs). + :param onnx: The path to write the output ONNX to. (inputs). + :param json: The path to write the output JSON to. (inputs). + :param name: The 'name' property in the output ONNX. By default + this will be the ONNX extension-less name. (inputs). + :param domain: The 'domain' property in the output ONNX. + (inputs). + :param inputs_to_drop: Array of input column names to drop + (inputs). + :param outputs_to_drop: Array of output column names to drop + (inputs). + :param model: Model that needs to be converted to ONNX format. + (inputs). + :param onnx_version: The targeted ONNX version. It can be either + "Stable" or "Experimental". If "Experimental" is used, + produced model can contain components that is not officially + supported in ONNX standard. (inputs). + :param predictive_model: Predictor model that needs to be + converted to ONNX format. (inputs). 
+ """ + + entrypoint_name = 'Models.OnnxConverter' + inputs = {} + outputs = {} + + if data_file is not None: + inputs['DataFile'] = try_set( + obj=data_file, + none_acceptable=True, + is_of_type=str) + if onnx is not None: + inputs['Onnx'] = try_set( + obj=onnx, + none_acceptable=False, + is_of_type=str) + if json is not None: + inputs['Json'] = try_set( + obj=json, + none_acceptable=True, + is_of_type=str) + if name is not None: + inputs['Name'] = try_set( + obj=name, + none_acceptable=True, + is_of_type=str, + is_column=True) + if domain is not None: + inputs['Domain'] = try_set( + obj=domain, + none_acceptable=True, + is_of_type=str) + if inputs_to_drop is not None: + inputs['InputsToDrop'] = try_set( + obj=inputs_to_drop, + none_acceptable=True, + is_of_type=list) + if outputs_to_drop is not None: + inputs['OutputsToDrop'] = try_set( + obj=outputs_to_drop, + none_acceptable=True, + is_of_type=list) + if model is not None: + inputs['Model'] = try_set( + obj=model, + none_acceptable=True, + is_of_type=str) + if onnx_version is not None: + inputs['OnnxVersion'] = try_set( + obj=onnx_version, + none_acceptable=True, + is_of_type=str, + values=[ + 'Stable', + 'Experimental']) + if predictive_model is not None: + inputs['PredictiveModel'] = try_set( + obj=predictive_model, none_acceptable=True, is_of_type=str) + + input_variables = { + x for x in unlist(inputs.values()) + if isinstance(x, str) and x.startswith("$")} + output_variables = { + x for x in unlist(outputs.values()) + if isinstance(x, str) and x.startswith("$")} + + entrypoint = EntryPoint( + name=entrypoint_name, inputs=inputs, outputs=outputs, + input_variables=input_variables, + output_variables=output_variables) + return entrypoint diff --git a/src/python/nimbusml/internal/entrypoints/models_onnxtransformer.py b/src/python/nimbusml/internal/entrypoints/models_onnxtransformer.py new file mode 100644 index 00000000..173c976a --- /dev/null +++ 
b/src/python/nimbusml/internal/entrypoints/models_onnxtransformer.py @@ -0,0 +1,96 @@ +# - Generated by tools/entrypoint_compiler.py: do not edit by hand +""" +Models.OnnxTransformer +""" + +import numbers + +from ..utils.entrypoints import EntryPoint +from ..utils.utils import try_set, unlist + + +def models_onnxtransformer( + model_file, + data, + output_data=None, + model=None, + input_columns=None, + output_columns=None, + gpu_device_id=None, + fallback_to_cpu=False, + **params): + """ + **Description** + Applies an ONNX model to a dataset. + + :param model_file: Path to the onnx model file. (inputs). + :param input_columns: Name of the input column. (inputs). + :param data: Input dataset (inputs). + :param output_columns: Name of the output column. (inputs). + :param gpu_device_id: GPU device id to run on (e.g. 0,1,..). Null + for CPU. Requires CUDA 9.1. (inputs). + :param fallback_to_cpu: If true, resumes execution on CPU upon + GPU error. If false, will raise the GPU execption. (inputs). + :param output_data: ONNX transformed dataset (outputs). + :param model: Transform model (outputs). 
+ """ + + entrypoint_name = 'Models.OnnxTransformer' + inputs = {} + outputs = {} + + if model_file is not None: + inputs['ModelFile'] = try_set( + obj=model_file, + none_acceptable=False, + is_of_type=str) + if input_columns is not None: + inputs['InputColumns'] = try_set( + obj=input_columns, + none_acceptable=True, + is_of_type=list, + is_column=True) + if data is not None: + inputs['Data'] = try_set( + obj=data, + none_acceptable=False, + is_of_type=str) + if output_columns is not None: + inputs['OutputColumns'] = try_set( + obj=output_columns, + none_acceptable=True, + is_of_type=list, + is_column=True) + if gpu_device_id is not None: + inputs['GpuDeviceId'] = try_set( + obj=gpu_device_id, + none_acceptable=True, + is_of_type=numbers.Real) + if fallback_to_cpu is not None: + inputs['FallbackToCpu'] = try_set( + obj=fallback_to_cpu, + none_acceptable=True, + is_of_type=bool) + if output_data is not None: + outputs['OutputData'] = try_set( + obj=output_data, + none_acceptable=False, + is_of_type=str) + if model is not None: + outputs['Model'] = try_set( + obj=model, + none_acceptable=False, + is_of_type=str) + + input_variables = { + x for x in unlist(inputs.values()) + if isinstance(x, str) and x.startswith("$")} + output_variables = { + x for x in unlist(outputs.values()) + if isinstance(x, str) and x.startswith("$")} + + entrypoint = EntryPoint( + name=entrypoint_name, inputs=inputs, outputs=outputs, + input_variables=input_variables, + output_variables=output_variables) + return entrypoint diff --git a/src/python/nimbusml/pipeline.py b/src/python/nimbusml/pipeline.py index 6eb190e3..d7337d37 100644 --- a/src/python/nimbusml/pipeline.py +++ b/src/python/nimbusml/pipeline.py @@ -39,6 +39,7 @@ from .internal.entrypoints.models_regressionevaluator import \ models_regressionevaluator from .internal.entrypoints.models_summarizer import models_summarizer +from .internal.entrypoints.models_onnxconverter import models_onnxconverter from 
.internal.entrypoints.models_schema import models_schema from .internal.entrypoints.transforms_datasetscorerex import \ transforms_datasetscorerex @@ -2519,6 +2520,96 @@ def __setstate__(self, state): else: raise ValueError('Pipeline version not supported.') + @trace + def export_to_onnx(self, + dst, + domain, + dst_json=None, + name=None, + data_file=None, + inputs_to_drop=None, + outputs_to_drop=None, + onnx_version="Stable", + verbose=0): + """ + Export the model to the ONNX format. + + :param str dst: The path to write the output ONNX to. + :param str domain: A reverse-DNS name to indicate the model + namespace or domain, for example, 'org.onnx'. + :param str dst_json: The path to write the output ONNX to + in JSON format. + :param name: The 'graph.name' property in the output ONNX. By default + this will be the ONNX extension-less name. (inputs). + :param data_file: The data file (inputs). + :param inputs_to_drop: Array of input column names to drop + (inputs). + :param outputs_to_drop: Array of output column names to drop + (inputs). + :param onnx_version: The targeted ONNX version. It can be either + "Stable" or "Experimental". If "Experimental" is used, + produced model can contain components that is not officially + supported in ONNX standard. (inputs). + """ + if not domain: + raise ValueError("domain argument must be specified and not empty.") + + if not self._is_fitted: + raise ValueError("Model is not fitted. Train or load a model before " + "export_to_onnx().") + + # start the clock! 
+ start_time = time.time() + + onnx_converter_args = { + 'onnx': dst, + 'json': dst_json, + 'domain': domain, + 'name': name, + 'data_file': data_file, + 'inputs_to_drop': inputs_to_drop, + 'outputs_to_drop': outputs_to_drop, + 'onnx_version': onnx_version + } + + if (len(self.steps) > 0) and (self.last_node.type != "transform"): + onnx_converter_args['predictive_model'] = "$model" + else: + onnx_converter_args['model'] = "$model" + + onnx_converter_node = models_onnxconverter(**onnx_converter_args) + + inputs = dict([('model', self.model)]) + outputs = dict() + + graph = Graph( + inputs, + outputs, + False, + onnx_converter_node) + + class_name = type(self).__name__ + method_name = inspect.currentframe().f_code.co_name + telemetry_info = ".".join([class_name, method_name]) + + try: + graph.run( + X=None, + y=None, + random_state=self.random_state, + model=self.model, + verbose=verbose, + is_summary=False, + no_input_data=True, + telemetry_info=telemetry_info) + except RuntimeError as e: + self._run_time = time.time() - start_time + raise e + + # stop the clock + self._run_time = time.time() - start_time + self._write_csv_time = graph._write_csv_time + @trace def score( self, diff --git a/src/python/nimbusml/preprocessing/__init__.py b/src/python/nimbusml/preprocessing/__init__.py index 728327be..202eb15d 100644 --- a/src/python/nimbusml/preprocessing/__init__.py +++ b/src/python/nimbusml/preprocessing/__init__.py @@ -2,6 +2,7 @@ from .tokey import ToKey from .tensorflowscorer import TensorFlowScorer from .datasettransformer import DatasetTransformer +from .onnxrunner import OnnxRunner from .datetimesplitter import DateTimeSplitter from .tokeyimputer import ToKeyImputer from .tostring import ToString @@ -13,5 +14,6 @@ 'ToKeyImputer', 'ToString', 'TensorFlowScorer', - 'DatasetTransformer' + 'DatasetTransformer', + 'OnnxRunner' ] diff --git a/src/python/nimbusml/preprocessing/onnxrunner.py b/src/python/nimbusml/preprocessing/onnxrunner.py new file mode 100644 index 
00000000..2df2ac75 --- /dev/null +++ b/src/python/nimbusml/preprocessing/onnxrunner.py @@ -0,0 +1,82 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------------------------- +# - Generated by tools/entrypoint_compiler.py: do not edit by hand +""" +OnnxRunner +""" + +__all__ = ["OnnxRunner"] + + +from sklearn.base import TransformerMixin + +from ..base_transform import BaseTransform +from ..internal.core.preprocessing.onnxrunner import OnnxRunner as core +from ..internal.utils.utils import trace + + +class OnnxRunner(core, BaseTransform, TransformerMixin): + """ + **Description** + Applies an ONNX model to a dataset. + + :param columns: see `Columns `_. + + :param model_file: Path to the onnx model file. + + :param input_columns: Name of the input column. + + :param output_columns: Name of the output column. + + :param gpu_device_id: GPU device id to run on (e.g. 0,1,..). Null for CPU. + Requires CUDA 9.1. + + :param fallback_to_cpu: If true, resumes execution on CPU upon GPU error. + If false, will raise the GPU execption. + + :param params: Additional arguments sent to compute engine. 
+ + """ + + @trace + def __init__( + self, + model_file, + input_columns=None, + output_columns=None, + gpu_device_id=None, + fallback_to_cpu=False, + columns=None, + **params): + + if columns: + params['columns'] = columns + if columns: + input_columns = sum( + list( + columns.values()), + []) if isinstance( + list( + columns.values())[0], + list) else list( + columns.values()) + if columns: + output_columns = list(columns.keys()) + BaseTransform.__init__(self, **params) + core.__init__( + self, + model_file=model_file, + input_columns=input_columns, + output_columns=output_columns, + gpu_device_id=gpu_device_id, + fallback_to_cpu=fallback_to_cpu, + **params) + self._columns = columns + + def get_params(self, deep=False): + """ + Get the parameters for this operator. + """ + return core.get_params(self) diff --git a/src/python/tests/test_estimator_checks.py b/src/python/tests/test_estimator_checks.py index 18a5ba31..998585f8 100644 --- a/src/python/tests/test_estimator_checks.py +++ b/src/python/tests/test_estimator_checks.py @@ -280,6 +280,7 @@ # skip SymSgdBinaryClassifier for now, because of crashes. 'SymSgdBinaryClassifier', 'DatasetTransformer', + 'OnnxRunner', 'TimeSeriesImputer' ]) diff --git a/src/python/tests_extended/test_export_to_onnx.py b/src/python/tests_extended/test_export_to_onnx.py new file mode 100644 index 00000000..3cf5fdb4 --- /dev/null +++ b/src/python/tests_extended/test_export_to_onnx.py @@ -0,0 +1,437 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------------------------- +""" +Verify onnx export support +""" +import contextlib +import io +import json +import os +import sys +import tempfile +import numpy as np +import pandas as pd +import pprint + +from nimbusml import Pipeline +from nimbusml.cluster import KMeansPlusPlus +from nimbusml.datasets import get_dataset +from nimbusml.datasets.image import get_RevolutionAnalyticslogo, get_Microsoftlogo +from nimbusml.decomposition import PcaTransformer, PcaAnomalyDetector +from nimbusml.ensemble import FastForestBinaryClassifier, FastTreesTweedieRegressor, LightGbmRanker +from nimbusml.feature_extraction.categorical import OneHotVectorizer, OneHotHashVectorizer +from nimbusml.feature_extraction.image import Loader, Resizer, PixelExtractor +from nimbusml.feature_extraction.text import NGramFeaturizer +from nimbusml.feature_extraction.text.extractor import Ngram +from nimbusml.feature_selection import CountSelector, MutualInformationSelector +from nimbusml.linear_model import FastLinearBinaryClassifier +from nimbusml.naive_bayes import NaiveBayesClassifier +from nimbusml.preprocessing import TensorFlowScorer, FromKey, ToKey +from nimbusml.preprocessing.filter import SkipFilter, TakeFilter, RangeFilter +from nimbusml.preprocessing.missing_values import Filter, Handler, Indicator +from nimbusml.preprocessing.normalization import Binner, GlobalContrastRowScaler, LpScaler +from nimbusml.preprocessing.schema import (ColumnConcatenator, TypeConverter, + ColumnDuplicator, ColumnSelector) +from nimbusml.preprocessing.text import CharTokenizer, WordTokenizer +from nimbusml.timeseries import (IidSpikeDetector, IidChangePointDetector, + SsaSpikeDetector, SsaChangePointDetector, + SsaForecaster) + + +SHOW_ONNX_JSON = False + +script_path = os.path.realpath(__file__) +script_dir = os.path.dirname(script_path) + +# Sepal_Length Sepal_Width Petal_Length Petal_Width Label Species Setosa +# 0 5.1 3.5 1.4 0.2 0 
setosa 1.0 +# 1 4.9 3.0 1.4 0.2 0 setosa 1.0 +iris_df = get_dataset("iris").as_df() +iris_df.drop(['Species'], axis=1, inplace=True) + +iris_no_label_df = iris_df.drop(['Label'], axis=1) +iris_binary_df = iris_no_label_df.rename(columns={'Setosa': 'Label'}) +iris_regression_df = iris_no_label_df.drop(['Setosa'], axis=1).rename(columns={'Petal_Width': 'Label'}) + +# Unnamed: 0 education age parity induced case spontaneous stratum pooled.stratum education_str +# 0 1 0.0 26.0 6.0 1.0 1.0 2.0 1.0 3.0 0-5yrs +# 1 2 0.0 42.0 1.0 1.0 1.0 0.0 2.0 1.0 0-5yrs +infert_df = get_dataset("infert").as_df() +infert_df.columns = [i.replace(': ', '') for i in infert_df.columns] +infert_df.rename(columns={'case': 'Label'}, inplace=True) + +infert_onehot_df = (OneHotVectorizer() << 'education_str').fit_transform(infert_df) + +# rank group carrier price Class dep_day nbr_stops duration +# 0 2 1 AA 240 3 1 0 12.0 +# 1 1 1 AA 300 3 0 1 15.0 +file_path = get_dataset("gen_tickettrain").as_filepath() +gen_tt_df = pd.read_csv(file_path) +gen_tt_df['group'] = gen_tt_df['group'].astype(np.uint32) + +# Unnamed: 0 Label Solar_R Wind Temp Month Day +# 0 1 41.0 190.0 7.4 67 5 1 +# 1 2 36.0 118.0 8.0 72 5 2 +airquality_df = get_dataset("airquality").as_df().fillna(0) +airquality_df = airquality_df[airquality_df.Ozone.notnull()] + +# Sentiment SentimentText +# 0 1 ==RUDE== Dude, you are rude upload that carl... +# 1 1 == OK! == IM GOING TO VANDALIZE WILD ONES W... +file_path = get_dataset("wiki_detox_train").as_filepath() +wiki_detox_df = pd.read_csv(file_path, sep='\t') +wiki_detox_df = wiki_detox_df.head(10) + +# Path Label +# 0 C:\repo\src\python... True +# 1 C:\repo\src\python... 
False +image_paths_df = pd.DataFrame(data=dict( + Path=[get_RevolutionAnalyticslogo(), get_Microsoftlogo()], + Label=[True, False])) + + +SKIP = { + 'DatasetTransformer', + 'LightLda', + 'OneVsRestClassifier', + 'Sentiment', + 'TensorFlowScorer', + 'TreeFeaturizer', + 'WordEmbedding', + 'OnnxRunner' +} + +INSTANCES = { + 'Binner': Binner(num_bins=3), + 'CharTokenizer': CharTokenizer(columns={'SentimentText_Transform': 'SentimentText'}), + 'ColumnConcatenator': ColumnConcatenator(columns={'Features': [ + 'Sepal_Length', + 'Sepal_Width', + 'Petal_Length', + 'Petal_Width', + 'Setosa']}), + 'ColumnSelector': ColumnSelector(columns=['Sepal_Width', 'Sepal_Length']), + 'ColumnDuplicator': ColumnDuplicator(columns={'dup': 'Sepal_Width'}), + 'CountSelector': CountSelector(count=5, columns=['Sepal_Width']), + 'FastForestBinaryClassifier': FastForestBinaryClassifier(feature=['Sepal_Width', 'Sepal_Length'], + label='Setosa'), + 'FastLinearBinaryClassifier': FastLinearBinaryClassifier(feature=['Sepal_Width', 'Sepal_Length'], + label='Setosa'), + 'FastTreesTweedieRegressor': FastTreesTweedieRegressor(label='Ozone'), + 'Filter': Filter(columns=[ 'Petal_Length', 'Petal_Width']), + 'FromKey': Pipeline([ + ToKey(columns=['Setosa']), + FromKey(columns=['Setosa']) + ]), + # GlobalContrastRowScaler currently requires a vector input to work + 'GlobalContrastRowScaler': Pipeline([ + ColumnConcatenator() << { + 'concated_columns': [ + 'Petal_Length', + 'Sepal_Width', + 'Sepal_Length']}, + GlobalContrastRowScaler(columns={'normed_columns': 'concated_columns'}) + ]), + 'Handler': Handler(replace_with='Mean', columns={'NewVals': 'Sepal_Length'}), + 'IidSpikeDetector': IidSpikeDetector(columns=['Sepal_Length']), + 'IidChangePointDetector': IidChangePointDetector(columns=['Sepal_Length']), + 'Indicator': Indicator(columns={'Has_Nan': 'Petal_Length'}), + 'KMeansPlusPlus': KMeansPlusPlus(n_clusters=3, feature=['Sepal_Width', 'Sepal_Length']), + 'LightGbmRanker': LightGbmRanker(feature=['Class', 
'dep_day', 'duration'], + label='rank', + group_id='group'), + 'Loader': Loader(columns={'ImgPath': 'Path'}), + 'LpScaler': Pipeline([ + ColumnConcatenator() << { + 'concated_columns': [ + 'Petal_Length', + 'Sepal_Width', + 'Sepal_Length']}, + LpScaler(columns={'normed_columns': 'concated_columns'}) + ]), + 'MutualInformationSelector': Pipeline([ + ColumnConcatenator(columns={'Features': ['Sepal_Width', 'Sepal_Length', 'Petal_Width']}), + MutualInformationSelector( + columns='Features', + label='Label', + slots_in_output=2) # only accept one column + ]), + 'NaiveBayesClassifier': NaiveBayesClassifier(feature=['Sepal_Width', 'Sepal_Length']), + 'NGramFeaturizer': NGramFeaturizer(word_feature_extractor=Ngram(), + columns={ 'features': ['SentimentText']}), + 'OneHotHashVectorizer': OneHotHashVectorizer(columns=['education_str']), + 'OneHotVectorizer': OneHotVectorizer(columns=['education_str']), + 'PcaAnomalyDetector': PcaAnomalyDetector(rank=3), + 'PcaTransformer': PcaTransformer(rank=2), + 'PixelExtractor': Pipeline([ + Loader(columns={'ImgPath': 'Path'}), + PixelExtractor(columns={'ImgPixels': 'ImgPath'}), + ]), + 'Resizer': Pipeline([ + Loader(columns={'ImgPath': 'Path'}), + Resizer(image_width=227, image_height=227, + columns={'ImgResize': 'ImgPath'}) + ]), + 'SkipFilter': SkipFilter(count=5), + 'SsaSpikeDetector': SsaSpikeDetector(columns=['Sepal_Length'], + seasonal_window_size=2), + 'SsaChangePointDetector': SsaChangePointDetector(columns=['Sepal_Length'], + seasonal_window_size=2), + 'SsaForecaster': SsaForecaster(columns=['Sepal_Length'], + window_size=2, + series_length=5, + train_size=5, + horizon=1), + 'RangeFilter': RangeFilter(min=5.0, max=5.1, columns=['Sepal_Length']), + 'TakeFilter': TakeFilter(count=100), + 'TensorFlowScorer': TensorFlowScorer( + model_location=os.path.join( + script_dir, + '..', + 'nimbusml', + 'examples', + 'frozen_saved_model.pb'), + columns={'c': ['a', 'b']}), + 'ToKey': ToKey(columns={'edu_1': 'education'}), + 'TypeConverter': 
TypeConverter(columns=['age'], result_type='R4'), + 'WordTokenizer': WordTokenizer(char_array_term_separators=[" "]) << {'wt': 'SentimentText'} +} + +DATASETS = { + 'AveragedPerceptronBinaryClassifier': infert_onehot_df, + 'Binner': iris_no_label_df, + 'BootstrapSampler': infert_df, + 'CharTokenizer': wiki_detox_df, + 'EnsembleRegressor': iris_regression_df, + 'FactorizationMachineBinaryClassifier': iris_binary_df, + 'FastForestBinaryClassifier': iris_no_label_df, + 'FastForestRegressor': iris_regression_df, + 'FastLinearBinaryClassifier': iris_no_label_df, + 'FastLinearClassifier': iris_binary_df, + 'FastLinearRegressor': iris_regression_df, + 'FastTreesBinaryClassifier': iris_binary_df, + 'FastTreesRegressor': iris_regression_df, + 'FastTreesTweedieRegressor': airquality_df, + 'Filter': iris_no_label_df, + 'GamBinaryClassifier': iris_binary_df, + 'GamRegressor': iris_regression_df, + 'GlobalContrastRowScaler': iris_df.astype(np.float32), + 'LightGbmRanker': gen_tt_df, + 'LinearSvmBinaryClassifier': iris_binary_df, + 'Loader': image_paths_df, + 'LogisticRegressionBinaryClassifier': iris_binary_df, + 'LogisticRegressionClassifier': iris_binary_df, + 'LogMeanVarianceScaler': iris_no_label_df, + 'LpScaler': iris_no_label_df.drop(['Setosa'], axis=1).astype(np.float32), + 'MeanVarianceScaler': iris_no_label_df, + 'MinMaxScaler': iris_no_label_df, + 'NGramFeaturizer': wiki_detox_df, + 'OneHotHashVectorizer': infert_df, + 'OneHotVectorizer': infert_df, + 'OnlineGradientDescentRegressor': iris_regression_df, + 'OrdinaryLeastSquaresRegressor': iris_regression_df, + 'PcaAnomalyDetector': iris_no_label_df, + 'PcaTransformer': iris_regression_df, + 'PixelExtractor': image_paths_df, + 'PoissonRegressionRegressor': iris_regression_df, + 'Resizer': image_paths_df, + 'SgdBinaryClassifier': iris_binary_df, + 'SymSgdBinaryClassifier': iris_binary_df, + 'ToKey': infert_df, + 'TypeConverter': infert_onehot_df, + 'WordTokenizer': wiki_detox_df +} + +REQUIRES_EXPERIMENTAL = { +} + 
+SUPPORTED_ESTIMATORS = { + 'ColumnConcatenator', + 'ColumnDuplicator', + 'CountSelector', + 'EnsembleClassifier', + 'EnsembleRegressor', + 'FastForestRegressor', + 'FastLinearRegressor', + 'FastTreesRegressor', + 'FastTreesTweedieRegressor', + 'GamRegressor', + 'Indicator', + 'KMeansPlusPlus', + 'LightGbmBinaryClassifier', + 'LightGbmClassifier', + 'LightGbmRegressor', + 'LpScaler', + 'MeanVarianceScaler', + 'MinMaxScaler', + 'NaiveBayesClassifier', + 'OneHotVectorizer', + 'OnlineGradientDescentRegressor', + 'OrdinaryLeastSquaresRegressor', + 'PcaAnomalyDetector', + 'PoissonRegressionRegressor', + 'PrefixColumnConcatenator', + 'TypeConverter', + 'WordTokenizer' +} + + +class CaptureOutputContext(): + """ + Context which can be used for + capturing stdout and stderr. + """ + def __enter__(self): + self.orig_stdout = sys.stdout + self.orig_stderr = sys.stderr + self.stdout_capturer = io.StringIO() + self.stderr_capturer = io.StringIO() + sys.stdout = self.stdout_capturer + sys.stderr = self.stderr_capturer + return self + + def __exit__(self, *args): + sys.stdout = self.orig_stdout + sys.stderr = self.orig_stderr + self.stdout = self.stdout_capturer.getvalue() + self.stderr = self.stderr_capturer.getvalue() + + if self.stdout: + print(self.stdout) + + if self.stderr: + print(self.stderr) + + # free up some memory + del self.stdout_capturer + del self.stderr_capturer + + +def get_tmp_file(suffix=None): + fd, file_name = tempfile.mkstemp(suffix=suffix) + fl = os.fdopen(fd, 'w') + fl.close() + return file_name + + +def get_file_size(file_path): + file_size = 0 + try: + file_size = os.path.getsize(file_path) + except OSError: + pass + return file_size + + +def load_json(file_path): + with open(file_path) as f: + lines = f.readlines() + lines = [l for l in lines if not l.strip().startswith('#')] + content_without_comments = '\n'.join(lines) + return json.loads(content_without_comments) + + +def test_export_to_onnx(estimator, class_name): + """ + Fit and test an estimator and 
determine + if it supports exporting to the ONNX format. + """ + onnx_path = get_tmp_file('.onnx') + onnx_json_path = get_tmp_file('.onnx.json') + + output = None + exported = False + + try: + dataset = DATASETS.get(class_name, iris_df) + estimator.fit(dataset) + + onnx_version = 'Experimental' if class_name in REQUIRES_EXPERIMENTAL else 'Stable' + + with CaptureOutputContext() as output: + estimator.export_to_onnx(onnx_path, + 'com.microsoft.ml', + dst_json=onnx_json_path, + onnx_version=onnx_version) + except Exception as e: + print(e) + + onnx_file_size = get_file_size(onnx_path) + onnx_json_file_size = get_file_size(onnx_json_path) + + if (output and + (onnx_file_size != 0) and + (onnx_json_file_size != 0) and + ('cannot save itself as ONNX' not in output.stdout)): + exported = True + + if exported and SHOW_ONNX_JSON: + with open(onnx_json_path) as f: + print(json.dumps(json.load(f), indent=4)) + + os.remove(onnx_path) + os.remove(onnx_json_path) + return exported + + +manifest_diff = os.path.join(script_dir, '..', 'tools', 'manifest_diff.json') +entry_points = load_json(manifest_diff)['EntryPoints'] +entry_points = sorted(entry_points, key=lambda ep: ep['NewName']) + +exportable_estimators = set() +exportable_experimental_estimators = set() +unexportable_estimators = set() + +for entry_point in entry_points: + class_name = entry_point['NewName'] + + print('\n===========> %s' % class_name) + + if class_name in SKIP: + print("skipped") + continue + + mod = __import__('nimbusml.' 
+ entry_point['Module'], + fromlist=[str(class_name)]) + + if class_name in INSTANCES: + estimator = INSTANCES[class_name] + else: + the_class = getattr(mod, class_name) + estimator = the_class() + + result = test_export_to_onnx(estimator, class_name) + + if result: + if class_name in REQUIRES_EXPERIMENTAL: + exportable_experimental_estimators.add(class_name) + else: + exportable_estimators.add(class_name) + + print('Estimator successfully exported to ONNX.') + + else: + unexportable_estimators.add(class_name) + print('Estimator could NOT be exported to ONNX.') + +print('\nThe following estimators were skipped: ') +pprint.pprint(sorted(SKIP)) + +print('\nThe following estimators were successfully exported to ONNX:') +pprint.pprint(sorted(exportable_estimators)) + +print('\nThe following estimators were successfully exported to experimental ONNX: ') +pprint.pprint(sorted(exportable_experimental_estimators)) + +print('\nThe following estimators could not be exported to ONNX: ') +pprint.pprint(sorted(unexportable_estimators)) + +failed_estimators = SUPPORTED_ESTIMATORS.difference(exportable_estimators) \ + .difference(exportable_experimental_estimators) + +if len(failed_estimators) > 0: + print("The following tests failed exporting to onnx:", sorted(failed_estimators)) + raise RuntimeError("onnx export checks failed") + diff --git a/src/python/tools/manifest.json b/src/python/tools/manifest.json index 5b739e57..491ac877 100644 --- a/src/python/tools/manifest.json +++ b/src/python/tools/manifest.json @@ -2194,6 +2194,203 @@ "ITrainerInput" ] }, + { + "Name": "Models.OnnxConverter", + "Desc": "Converts the model to ONNX format.", + "FriendlyName": "ONNX Converter.", + "ShortName": null, + "Inputs": [ + { + "Name": "DataFile", + "Type": "String", + "Desc": "The data file", + "Aliases": [ + "data" + ], + "Required": false, + "SortOrder": 0.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Onnx", + "Type": "String", + "Desc": "The path to write the output 
ONNX to.", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Json", + "Type": "String", + "Desc": "The path to write the output JSON to.", + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Name", + "Type": "String", + "Desc": "The 'name' property in the output ONNX. By default this will be the ONNX extension-less name.", + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Domain", + "Type": "String", + "Desc": "The 'domain' property in the output ONNX.", + "Required": false, + "SortOrder": 4.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "InputsToDrop", + "Type": { + "Kind": "Array", + "ItemType": "String" + }, + "Desc": "Array of input column names to drop", + "Required": false, + "SortOrder": 6.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "OutputsToDrop", + "Type": { + "Kind": "Array", + "ItemType": "String" + }, + "Desc": "Array of output column names to drop", + "Required": false, + "SortOrder": 8.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Model that needs to be converted to ONNX format.", + "Required": false, + "SortOrder": 10.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "OnnxVersion", + "Type": { + "Kind": "Enum", + "Values": [ + "Stable", + "Experimental" + ] + }, + "Desc": "The targeted ONNX version. It can be either \"Stable\" or \"Experimental\". 
If \"Experimental\" is used, produced model can contain components that is not officially supported in ONNX standard.", + "Required": false, + "SortOrder": 11.0, + "IsNullable": false, + "Default": "Stable" + }, + { + "Name": "PredictiveModel", + "Type": "PredictorModel", + "Desc": "Predictor model that needs to be converted to ONNX format.", + "Required": false, + "SortOrder": 12.0, + "IsNullable": false, + "Default": null + } + ], + "Outputs": [] + }, + { + "Name": "Models.OnnxTransformer", + "Desc": "Applies an ONNX model to a dataset.", + "FriendlyName": "Onnx Transformer", + "ShortName": "onnx-xf", + "Inputs": [ + { + "Name": "ModelFile", + "Type": "String", + "Desc": "Path to the onnx model file.", + "Aliases": [ + "model" + ], + "Required": true, + "SortOrder": 0.0, + "IsNullable": false + }, + { + "Name": "InputColumns", + "Type": { + "Kind": "Array", + "ItemType": "String" + }, + "Desc": "Name of the input column.", + "Required": false, + "SortOrder": 1.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "OutputColumns", + "Type": { + "Kind": "Array", + "ItemType": "String" + }, + "Desc": "Name of the output column.", + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "GpuDeviceId", + "Type": "Int", + "Desc": "GPU device id to run on (e.g. 0,1,..). Null for CPU. Requires CUDA 9.1.", + "Required": false, + "SortOrder": 3.0, + "IsNullable": true, + "Default": null + }, + { + "Name": "FallbackToCpu", + "Type": "Bool", + "Desc": "If true, resumes execution on CPU upon GPU error. 
If false, will raise the GPU execption.", + "Required": false, + "SortOrder": 4.0, + "IsNullable": false, + "Default": false + } + ], + "Outputs": [ + { + "Name": "OutputData", + "Type": "DataView", + "Desc": "ONNX transformed dataset" + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Transform model" + } + ], + "InputKind": [ + "ITransformInput" + ] + }, { "Name": "Models.OvaModelCombiner", "Desc": "Combines a sequence of PredictorModels into a single model", diff --git a/src/python/tools/manifest_diff.json b/src/python/tools/manifest_diff.json index c94a8845..a70489ee 100644 --- a/src/python/tools/manifest_diff.json +++ b/src/python/tools/manifest_diff.json @@ -347,6 +347,12 @@ "Module": "preprocessing", "Type": "Transform" }, + { + "Name": "Models.OnnxTransformer", + "NewName": "OnnxRunner", + "Module": "preprocessing", + "Type": "Transform" + }, { "Name": "Trainers.FieldAwareFactorizationMachineBinaryClassifier", "NewName": "FactorizationMachineBinaryClassifier",