diff --git a/build/libs_win.txt b/build/libs_win.txt
index e815e645..2b0baca8 100644
--- a/build/libs_win.txt
+++ b/build/libs_win.txt
@@ -14,3 +14,4 @@ TensorFlow.NET.dll
NumSharp.Core.dll
System.Drawing.Common.dll
Microsoft.ML.*
+onnxruntime.dll
diff --git a/src/DotNetBridge/Bridge.cs b/src/DotNetBridge/Bridge.cs
index 30450540..89e7e652 100644
--- a/src/DotNetBridge/Bridge.cs
+++ b/src/DotNetBridge/Bridge.cs
@@ -298,7 +298,7 @@ private static unsafe int GenericExec(EnvironmentBlock* penv, sbyte* psz, int cd
//env.ComponentCatalog.RegisterAssembly(typeof(AutoInference).Assembly); // ML.PipelineInference
env.ComponentCatalog.RegisterAssembly(typeof(DataViewReference).Assembly);
env.ComponentCatalog.RegisterAssembly(typeof(ImageLoadingTransformer).Assembly);
- //env.ComponentCatalog.RegisterAssembly(typeof(SaveOnnxCommand).Assembly);
+ env.ComponentCatalog.RegisterAssembly(typeof(OnnxExportExtensions).Assembly);
//env.ComponentCatalog.RegisterAssembly(typeof(TimeSeriesProcessingEntryPoints).Assembly);
//env.ComponentCatalog.RegisterAssembly(typeof(ParquetLoader).Assembly);
env.ComponentCatalog.RegisterAssembly(typeof(SsaChangePointDetector).Assembly);
diff --git a/src/DotNetBridge/DotNetBridge.csproj b/src/DotNetBridge/DotNetBridge.csproj
index 67ba3209..18b73d32 100644
--- a/src/DotNetBridge/DotNetBridge.csproj
+++ b/src/DotNetBridge/DotNetBridge.csproj
@@ -38,6 +38,7 @@
+
diff --git a/src/DotNetBridge/Entrypoints.cs b/src/DotNetBridge/Entrypoints.cs
index 535d9d75..6f7f8d0c 100644
--- a/src/DotNetBridge/Entrypoints.cs
+++ b/src/DotNetBridge/Entrypoints.cs
@@ -178,5 +178,59 @@ public static ScoringTransformOutput Score(IHostEnvironment env, ScoringTransfor
};
}
+
+ public sealed class OnnxTransformInput : TransformInputBase
+ {
+ [Argument(ArgumentType.Required, HelpText = "Path to the onnx model file.", ShortName = "model", SortOrder = 0)]
+ public string ModelFile;
+
+ [Argument(ArgumentType.Multiple, HelpText = "Name of the input column.", SortOrder = 1)]
+ public string[] InputColumns;
+
+ [Argument(ArgumentType.Multiple, HelpText = "Name of the output column.", SortOrder = 2)]
+ public string[] OutputColumns;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "GPU device id to run on (e.g. 0,1,..). Null for CPU. Requires CUDA 9.1.", SortOrder = 3)]
+ public int? GpuDeviceId = null;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "If true, resumes execution on CPU upon GPU error. If false, will raise the GPU execption.", SortOrder = 4)]
+ public bool FallbackToCpu = false;
+ }
+
+ public sealed class OnnxTransformOutput
+ {
+ [TlcModule.Output(Desc = "ONNX transformed dataset", SortOrder = 1)]
+ public IDataView OutputData;
+
+ [TlcModule.Output(Desc = "Transform model", SortOrder = 2)]
+ public TransformModel Model;
+ }
+
+ [TlcModule.EntryPoint(Name = "Models.OnnxTransformer",
+ Desc = "Applies an ONNX model to a dataset.",
+ UserName = "Onnx Transformer",
+ ShortName = "onnx-xf")]
+ public static OnnxTransformOutput ApplyOnnxModel(IHostEnvironment env, OnnxTransformInput input)
+ {
+ var host = EntryPointUtils.CheckArgsAndCreateHost(env, "OnnxTransform", input);
+
+ var inputColumns = input.InputColumns ?? (Array.Empty());
+ var outputColumns = input.OutputColumns ?? (Array.Empty());
+
+ var transformsCatalog = new TransformsCatalog(host);
+ var onnxScoringEstimator = OnnxCatalog.ApplyOnnxModel(transformsCatalog,
+ outputColumns,
+ inputColumns,
+ input.ModelFile,
+ input.GpuDeviceId,
+ input.FallbackToCpu);
+
+ var view = onnxScoringEstimator.Fit(input.Data).Transform(input.Data);
+ return new OnnxTransformOutput()
+ {
+ Model = new TransformModelImpl(host, view, input.Data),
+ OutputData = view
+ };
+ }
}
}
diff --git a/src/DotNetBridge/ManifestUtils.cs b/src/DotNetBridge/ManifestUtils.cs
index c01b8480..b566cf2f 100644
--- a/src/DotNetBridge/ManifestUtils.cs
+++ b/src/DotNetBridge/ManifestUtils.cs
@@ -43,6 +43,7 @@ public static class ManifestUtils
typeof(ImageLoadingTransformer),
typeof(SymbolicSgdLogisticRegressionBinaryTrainer),
typeof(OnnxContext),
+ typeof(OnnxExportExtensions),
typeof(SsaForecastingTransformer),
typeof(VariableColumnTransform),
typeof(DateTimeTransformer)
diff --git a/src/Platforms/build.csproj b/src/Platforms/build.csproj
index 6a0b7ab7..b4ff1889 100644
--- a/src/Platforms/build.csproj
+++ b/src/Platforms/build.csproj
@@ -17,6 +17,7 @@
+
diff --git a/src/python/nimbusml.pyproj b/src/python/nimbusml.pyproj
index c80553c0..65925b6c 100644
--- a/src/python/nimbusml.pyproj
+++ b/src/python/nimbusml.pyproj
@@ -322,6 +322,7 @@
+
@@ -359,6 +360,8 @@
+
+
@@ -672,6 +675,7 @@
+
@@ -747,6 +751,7 @@
+
diff --git a/src/python/nimbusml/base_predictor.py b/src/python/nimbusml/base_predictor.py
index e619b115..1c694312 100644
--- a/src/python/nimbusml/base_predictor.py
+++ b/src/python/nimbusml/base_predictor.py
@@ -178,7 +178,6 @@ def summary(self):
self.model_summary_ = pipeline.summary()
return self.model_summary_
- @trace
def _get_implicit_transforms(
self,
features,
@@ -354,3 +353,20 @@ def _get_graph_nodes(
row_group_column_name=group_id_column)
graph_nodes['learner_node'] = [learner_node]
return graph_nodes, learner_features
+
+ @trace
+ def export_to_onnx(self, *args, **kwargs):
+ """
+ Export the model to the ONNX format.
+
+ See :py:meth:`nimbusml.Pipeline.export_to_onnx` for accepted arguments.
+ """
+ if not hasattr(self, 'model_') \
+ or self.model_ is None \
+ or not os.path.isfile(self.model_):
+
+ raise ValueError("Model is not fitted. Train or load a model before "
+ "export_to_onnx().")
+
+ pipeline = Pipeline([self], model=self.model_)
+ pipeline.export_to_onnx(*args, **kwargs)
diff --git a/src/python/nimbusml/base_transform.py b/src/python/nimbusml/base_transform.py
index b227d567..f0c4f861 100644
--- a/src/python/nimbusml/base_transform.py
+++ b/src/python/nimbusml/base_transform.py
@@ -124,3 +124,20 @@ def transform(self, X, as_binary_data_stream=False, **params):
data = pipeline.transform(
X, as_binary_data_stream=as_binary_data_stream, **params)
return data
+
+ @trace
+ def export_to_onnx(self, *args, **kwargs):
+ """
+ Export the model to the ONNX format.
+
+ See :py:meth:`nimbusml.Pipeline.export_to_onnx` for accepted arguments.
+ """
+ if not hasattr(self, 'model_') \
+ or self.model_ is None \
+ or not os.path.isfile(self.model_):
+
+ raise ValueError("Model is not fitted. Train or load a model before "
+ "export_to_onnx().")
+
+ pipeline = Pipeline([self], model=self.model_)
+ pipeline.export_to_onnx(*args, **kwargs)
diff --git a/src/python/nimbusml/internal/core/preprocessing/onnxrunner.py b/src/python/nimbusml/internal/core/preprocessing/onnxrunner.py
new file mode 100644
index 00000000..34ed46ba
--- /dev/null
+++ b/src/python/nimbusml/internal/core/preprocessing/onnxrunner.py
@@ -0,0 +1,71 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------------------------
+# - Generated by tools/entrypoint_compiler.py: do not edit by hand
+"""
+OnnxRunner
+"""
+
+__all__ = ["OnnxRunner"]
+
+
+from ...entrypoints.models_onnxtransformer import models_onnxtransformer
+from ...utils.utils import trace
+from ..base_pipeline_item import BasePipelineItem, DefaultSignature
+
+
+class OnnxRunner(BasePipelineItem, DefaultSignature):
+ """
+ **Description**
+ Applies an ONNX model to a dataset.
+
+ :param model_file: Path to the onnx model file.
+
+ :param input_columns: Name of the input column.
+
+ :param output_columns: Name of the output column.
+
+ :param gpu_device_id: GPU device id to run on (e.g. 0,1,..). Null for CPU.
+ Requires CUDA 9.1.
+
+ :param fallback_to_cpu: If true, resumes execution on CPU upon GPU error.
+ If false, will raise the GPU execption.
+
+ :param params: Additional arguments sent to compute engine.
+
+ """
+
+ @trace
+ def __init__(
+ self,
+ model_file,
+ input_columns=None,
+ output_columns=None,
+ gpu_device_id=None,
+ fallback_to_cpu=False,
+ **params):
+ BasePipelineItem.__init__(
+ self, type='transform', **params)
+
+ self.model_file = model_file
+ self.input_columns = input_columns
+ self.output_columns = output_columns
+ self.gpu_device_id = gpu_device_id
+ self.fallback_to_cpu = fallback_to_cpu
+
+ @property
+ def _entrypoint(self):
+ return models_onnxtransformer
+
+ @trace
+ def _get_node(self, **all_args):
+ algo_args = dict(
+ model_file=self.model_file,
+ input_columns=self.input_columns,
+ output_columns=self.output_columns,
+ gpu_device_id=self.gpu_device_id,
+ fallback_to_cpu=self.fallback_to_cpu)
+
+ all_args.update(algo_args)
+ return self._entrypoint(**all_args)
diff --git a/src/python/nimbusml/internal/entrypoints/models_onnxconverter.py b/src/python/nimbusml/internal/entrypoints/models_onnxconverter.py
new file mode 100644
index 00000000..3c080eb6
--- /dev/null
+++ b/src/python/nimbusml/internal/entrypoints/models_onnxconverter.py
@@ -0,0 +1,116 @@
+# - Generated by tools/entrypoint_compiler.py: do not edit by hand
+"""
+Models.OnnxConverter
+"""
+
+
+from ..utils.entrypoints import EntryPoint
+from ..utils.utils import try_set, unlist
+
+
+def models_onnxconverter(
+ onnx,
+ data_file=None,
+ json=None,
+ name=None,
+ domain=None,
+ inputs_to_drop=None,
+ outputs_to_drop=None,
+ model=None,
+ onnx_version='Stable',
+ predictive_model=None,
+ **params):
+ """
+ **Description**
+ Converts the model to ONNX format.
+
+ :param data_file: The data file (inputs).
+ :param onnx: The path to write the output ONNX to. (inputs).
+ :param json: The path to write the output JSON to. (inputs).
+ :param name: The 'name' property in the output ONNX. By default
+ this will be the ONNX extension-less name. (inputs).
+ :param domain: The 'domain' property in the output ONNX.
+ (inputs).
+ :param inputs_to_drop: Array of input column names to drop
+ (inputs).
+ :param outputs_to_drop: Array of output column names to drop
+ (inputs).
+ :param model: Model that needs to be converted to ONNX format.
+ (inputs).
+ :param onnx_version: The targeted ONNX version. It can be either
+ "Stable" or "Experimental". If "Experimental" is used,
+ produced model can contain components that is not officially
+ supported in ONNX standard. (inputs).
+ :param predictive_model: Predictor model that needs to be
+ converted to ONNX format. (inputs).
+ """
+
+ entrypoint_name = 'Models.OnnxConverter'
+ inputs = {}
+ outputs = {}
+
+ if data_file is not None:
+ inputs['DataFile'] = try_set(
+ obj=data_file,
+ none_acceptable=True,
+ is_of_type=str)
+ if onnx is not None:
+ inputs['Onnx'] = try_set(
+ obj=onnx,
+ none_acceptable=False,
+ is_of_type=str)
+ if json is not None:
+ inputs['Json'] = try_set(
+ obj=json,
+ none_acceptable=True,
+ is_of_type=str)
+ if name is not None:
+ inputs['Name'] = try_set(
+ obj=name,
+ none_acceptable=True,
+ is_of_type=str,
+ is_column=True)
+ if domain is not None:
+ inputs['Domain'] = try_set(
+ obj=domain,
+ none_acceptable=True,
+ is_of_type=str)
+ if inputs_to_drop is not None:
+ inputs['InputsToDrop'] = try_set(
+ obj=inputs_to_drop,
+ none_acceptable=True,
+ is_of_type=list)
+ if outputs_to_drop is not None:
+ inputs['OutputsToDrop'] = try_set(
+ obj=outputs_to_drop,
+ none_acceptable=True,
+ is_of_type=list)
+ if model is not None:
+ inputs['Model'] = try_set(
+ obj=model,
+ none_acceptable=True,
+ is_of_type=str)
+ if onnx_version is not None:
+ inputs['OnnxVersion'] = try_set(
+ obj=onnx_version,
+ none_acceptable=True,
+ is_of_type=str,
+ values=[
+ 'Stable',
+ 'Experimental'])
+ if predictive_model is not None:
+ inputs['PredictiveModel'] = try_set(
+ obj=predictive_model, none_acceptable=True, is_of_type=str)
+
+ input_variables = {
+ x for x in unlist(inputs.values())
+ if isinstance(x, str) and x.startswith("$")}
+ output_variables = {
+ x for x in unlist(outputs.values())
+ if isinstance(x, str) and x.startswith("$")}
+
+ entrypoint = EntryPoint(
+ name=entrypoint_name, inputs=inputs, outputs=outputs,
+ input_variables=input_variables,
+ output_variables=output_variables)
+ return entrypoint
diff --git a/src/python/nimbusml/internal/entrypoints/models_onnxtransformer.py b/src/python/nimbusml/internal/entrypoints/models_onnxtransformer.py
new file mode 100644
index 00000000..173c976a
--- /dev/null
+++ b/src/python/nimbusml/internal/entrypoints/models_onnxtransformer.py
@@ -0,0 +1,96 @@
+# - Generated by tools/entrypoint_compiler.py: do not edit by hand
+"""
+Models.OnnxTransformer
+"""
+
+import numbers
+
+from ..utils.entrypoints import EntryPoint
+from ..utils.utils import try_set, unlist
+
+
+def models_onnxtransformer(
+ model_file,
+ data,
+ output_data=None,
+ model=None,
+ input_columns=None,
+ output_columns=None,
+ gpu_device_id=None,
+ fallback_to_cpu=False,
+ **params):
+ """
+ **Description**
+ Applies an ONNX model to a dataset.
+
+ :param model_file: Path to the onnx model file. (inputs).
+ :param input_columns: Name of the input column. (inputs).
+ :param data: Input dataset (inputs).
+ :param output_columns: Name of the output column. (inputs).
+ :param gpu_device_id: GPU device id to run on (e.g. 0,1,..). Null
+ for CPU. Requires CUDA 9.1. (inputs).
+ :param fallback_to_cpu: If true, resumes execution on CPU upon
+ GPU error. If false, will raise the GPU execption. (inputs).
+ :param output_data: ONNX transformed dataset (outputs).
+ :param model: Transform model (outputs).
+ """
+
+ entrypoint_name = 'Models.OnnxTransformer'
+ inputs = {}
+ outputs = {}
+
+ if model_file is not None:
+ inputs['ModelFile'] = try_set(
+ obj=model_file,
+ none_acceptable=False,
+ is_of_type=str)
+ if input_columns is not None:
+ inputs['InputColumns'] = try_set(
+ obj=input_columns,
+ none_acceptable=True,
+ is_of_type=list,
+ is_column=True)
+ if data is not None:
+ inputs['Data'] = try_set(
+ obj=data,
+ none_acceptable=False,
+ is_of_type=str)
+ if output_columns is not None:
+ inputs['OutputColumns'] = try_set(
+ obj=output_columns,
+ none_acceptable=True,
+ is_of_type=list,
+ is_column=True)
+ if gpu_device_id is not None:
+ inputs['GpuDeviceId'] = try_set(
+ obj=gpu_device_id,
+ none_acceptable=True,
+ is_of_type=numbers.Real)
+ if fallback_to_cpu is not None:
+ inputs['FallbackToCpu'] = try_set(
+ obj=fallback_to_cpu,
+ none_acceptable=True,
+ is_of_type=bool)
+ if output_data is not None:
+ outputs['OutputData'] = try_set(
+ obj=output_data,
+ none_acceptable=False,
+ is_of_type=str)
+ if model is not None:
+ outputs['Model'] = try_set(
+ obj=model,
+ none_acceptable=False,
+ is_of_type=str)
+
+ input_variables = {
+ x for x in unlist(inputs.values())
+ if isinstance(x, str) and x.startswith("$")}
+ output_variables = {
+ x for x in unlist(outputs.values())
+ if isinstance(x, str) and x.startswith("$")}
+
+ entrypoint = EntryPoint(
+ name=entrypoint_name, inputs=inputs, outputs=outputs,
+ input_variables=input_variables,
+ output_variables=output_variables)
+ return entrypoint
diff --git a/src/python/nimbusml/pipeline.py b/src/python/nimbusml/pipeline.py
index 6eb190e3..d7337d37 100644
--- a/src/python/nimbusml/pipeline.py
+++ b/src/python/nimbusml/pipeline.py
@@ -39,6 +39,7 @@
from .internal.entrypoints.models_regressionevaluator import \
models_regressionevaluator
from .internal.entrypoints.models_summarizer import models_summarizer
+from .internal.entrypoints.models_onnxconverter import models_onnxconverter
from .internal.entrypoints.models_schema import models_schema
from .internal.entrypoints.transforms_datasetscorerex import \
transforms_datasetscorerex
@@ -2519,6 +2520,96 @@ def __setstate__(self, state):
else:
raise ValueError('Pipeline version not supported.')
+ @trace
+ def export_to_onnx(self,
+ dst,
+ domain,
+ dst_json=None,
+ name=None,
+ data_file=None,
+ inputs_to_drop=None,
+ outputs_to_drop=None,
+ onnx_version="Stable",
+ verbose=0):
+ """
+ Export the model to the ONNX format.
+
+ :param str dst: The path to write the output ONNX to.
+ :param str domain: A reverse-DNS name to indicate the model
+ namespace or domain, for example, 'org.onnx'.
+ :param str dst_json: The path to write the output ONNX to
+ in JSON format.
+ :param name: The 'graph.name' property in the output ONNX. By default
+ this will be the ONNX extension-less name. (inputs).
+ :param data_file: The data file (inputs).
+ :param inputs_to_drop: Array of input column names to drop
+ (inputs).
+ :param outputs_to_drop: Array of output column names to drop
+ (inputs).
+ :param onnx_version: The targeted ONNX version. It can be either
+ "Stable" or "Experimental". If "Experimental" is used,
+ produced model can contain components that is not officially
+ supported in ONNX standard. (inputs).
+ """
+ if not domain:
+ raise ValueError("domain argument must be specified and not empty.")
+
+ if not self._is_fitted:
+ raise ValueError("Model is not fitted. Train or load a model before "
+ "export_to_onnx().")
+
+ # start the clock!
+ start_time = time.time()
+
+ onnx_converter_args = {
+ 'onnx': dst,
+ 'json': dst_json,
+ 'domain': domain,
+ 'name': name,
+ 'data_file': data_file,
+ 'inputs_to_drop': inputs_to_drop,
+ 'outputs_to_drop': outputs_to_drop,
+ 'onnx_version': onnx_version
+ }
+
+ if (len(self.steps) > 0) and (self.last_node.type != "transform"):
+ onnx_converter_args['predictive_model'] = "$model"
+ else:
+ onnx_converter_args['model'] = "$model"
+
+ onnx_converter_node = models_onnxconverter(**onnx_converter_args)
+
+ inputs = dict([('model', self.model)])
+ outputs = dict()
+
+ graph = Graph(
+ inputs,
+ outputs,
+ False,
+ onnx_converter_node)
+
+ class_name = type(self).__name__
+ method_name = inspect.currentframe().f_code.co_name
+ telemetry_info = ".".join([class_name, method_name])
+
+ try:
+ graph.run(
+ X=None,
+ y=None,
+ random_state=self.random_state,
+ model=self.model,
+ verbose=verbose,
+ is_summary=False,
+ no_input_data=True,
+ telemetry_info=telemetry_info)
+ except RuntimeError as e:
+ self._run_time = time.time() - start_time
+ raise e
+
+ # stop the clock
+ self._run_time = time.time() - start_time
+ self._write_csv_time = graph._write_csv_time
+
@trace
def score(
self,
diff --git a/src/python/nimbusml/preprocessing/__init__.py b/src/python/nimbusml/preprocessing/__init__.py
index 728327be..202eb15d 100644
--- a/src/python/nimbusml/preprocessing/__init__.py
+++ b/src/python/nimbusml/preprocessing/__init__.py
@@ -2,6 +2,7 @@
from .tokey import ToKey
from .tensorflowscorer import TensorFlowScorer
from .datasettransformer import DatasetTransformer
+from .onnxrunner import OnnxRunner
from .datetimesplitter import DateTimeSplitter
from .tokeyimputer import ToKeyImputer
from .tostring import ToString
@@ -13,5 +14,6 @@
'ToKeyImputer',
'ToString',
'TensorFlowScorer',
- 'DatasetTransformer'
+ 'DatasetTransformer',
+ 'OnnxRunner'
]
diff --git a/src/python/nimbusml/preprocessing/onnxrunner.py b/src/python/nimbusml/preprocessing/onnxrunner.py
new file mode 100644
index 00000000..2df2ac75
--- /dev/null
+++ b/src/python/nimbusml/preprocessing/onnxrunner.py
@@ -0,0 +1,82 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------------------------
+# - Generated by tools/entrypoint_compiler.py: do not edit by hand
+"""
+OnnxRunner
+"""
+
+__all__ = ["OnnxRunner"]
+
+
+from sklearn.base import TransformerMixin
+
+from ..base_transform import BaseTransform
+from ..internal.core.preprocessing.onnxrunner import OnnxRunner as core
+from ..internal.utils.utils import trace
+
+
+class OnnxRunner(core, BaseTransform, TransformerMixin):
+ """
+ **Description**
+ Applies an ONNX model to a dataset.
+
+ :param columns: see `Columns `_.
+
+ :param model_file: Path to the onnx model file.
+
+ :param input_columns: Name of the input column.
+
+ :param output_columns: Name of the output column.
+
+ :param gpu_device_id: GPU device id to run on (e.g. 0,1,..). Null for CPU.
+ Requires CUDA 9.1.
+
+ :param fallback_to_cpu: If true, resumes execution on CPU upon GPU error.
+ If false, will raise the GPU execption.
+
+ :param params: Additional arguments sent to compute engine.
+
+ """
+
+ @trace
+ def __init__(
+ self,
+ model_file,
+ input_columns=None,
+ output_columns=None,
+ gpu_device_id=None,
+ fallback_to_cpu=False,
+ columns=None,
+ **params):
+
+ if columns:
+ params['columns'] = columns
+ if columns:
+ input_columns = sum(
+ list(
+ columns.values()),
+ []) if isinstance(
+ list(
+ columns.values())[0],
+ list) else list(
+ columns.values())
+ if columns:
+ output_columns = list(columns.keys())
+ BaseTransform.__init__(self, **params)
+ core.__init__(
+ self,
+ model_file=model_file,
+ input_columns=input_columns,
+ output_columns=output_columns,
+ gpu_device_id=gpu_device_id,
+ fallback_to_cpu=fallback_to_cpu,
+ **params)
+ self._columns = columns
+
+ def get_params(self, deep=False):
+ """
+ Get the parameters for this operator.
+ """
+ return core.get_params(self)
diff --git a/src/python/tests/test_estimator_checks.py b/src/python/tests/test_estimator_checks.py
index 18a5ba31..998585f8 100644
--- a/src/python/tests/test_estimator_checks.py
+++ b/src/python/tests/test_estimator_checks.py
@@ -280,6 +280,7 @@
# skip SymSgdBinaryClassifier for now, because of crashes.
'SymSgdBinaryClassifier',
'DatasetTransformer',
+ 'OnnxRunner'
'TimeSeriesImputer'
])
diff --git a/src/python/tests_extended/test_export_to_onnx.py b/src/python/tests_extended/test_export_to_onnx.py
new file mode 100644
index 00000000..3cf5fdb4
--- /dev/null
+++ b/src/python/tests_extended/test_export_to_onnx.py
@@ -0,0 +1,437 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------------------------
+"""
+Verify onnx export support
+"""
+import contextlib
+import io
+import json
+import os
+import sys
+import tempfile
+import numpy as np
+import pandas as pd
+import pprint
+
+from nimbusml import Pipeline
+from nimbusml.cluster import KMeansPlusPlus
+from nimbusml.datasets import get_dataset
+from nimbusml.datasets.image import get_RevolutionAnalyticslogo, get_Microsoftlogo
+from nimbusml.decomposition import PcaTransformer, PcaAnomalyDetector
+from nimbusml.ensemble import FastForestBinaryClassifier, FastTreesTweedieRegressor, LightGbmRanker
+from nimbusml.feature_extraction.categorical import OneHotVectorizer, OneHotHashVectorizer
+from nimbusml.feature_extraction.image import Loader, Resizer, PixelExtractor
+from nimbusml.feature_extraction.text import NGramFeaturizer
+from nimbusml.feature_extraction.text.extractor import Ngram
+from nimbusml.feature_selection import CountSelector, MutualInformationSelector
+from nimbusml.linear_model import FastLinearBinaryClassifier
+from nimbusml.naive_bayes import NaiveBayesClassifier
+from nimbusml.preprocessing import TensorFlowScorer, FromKey, ToKey
+from nimbusml.preprocessing.filter import SkipFilter, TakeFilter, RangeFilter
+from nimbusml.preprocessing.missing_values import Filter, Handler, Indicator
+from nimbusml.preprocessing.normalization import Binner, GlobalContrastRowScaler, LpScaler
+from nimbusml.preprocessing.schema import (ColumnConcatenator, TypeConverter,
+ ColumnDuplicator, ColumnSelector)
+from nimbusml.preprocessing.text import CharTokenizer, WordTokenizer
+from nimbusml.timeseries import (IidSpikeDetector, IidChangePointDetector,
+ SsaSpikeDetector, SsaChangePointDetector,
+ SsaForecaster)
+
+
+SHOW_ONNX_JSON = False
+
+script_path = os.path.realpath(__file__)
+script_dir = os.path.dirname(script_path)
+
+# Sepal_Length Sepal_Width Petal_Length Petal_Width Label Species Setosa
+# 0 5.1 3.5 1.4 0.2 0 setosa 1.0
+# 1 4.9 3.0 1.4 0.2 0 setosa 1.0
+iris_df = get_dataset("iris").as_df()
+iris_df.drop(['Species'], axis=1, inplace=True)
+
+iris_no_label_df = iris_df.drop(['Label'], axis=1)
+iris_binary_df = iris_no_label_df.rename(columns={'Setosa': 'Label'})
+iris_regression_df = iris_no_label_df.drop(['Setosa'], axis=1).rename(columns={'Petal_Width': 'Label'})
+
+# Unnamed: 0 education age parity induced case spontaneous stratum pooled.stratum education_str
+# 0 1 0.0 26.0 6.0 1.0 1.0 2.0 1.0 3.0 0-5yrs
+# 1 2 0.0 42.0 1.0 1.0 1.0 0.0 2.0 1.0 0-5yrs
+infert_df = get_dataset("infert").as_df()
+infert_df.columns = [i.replace(': ', '') for i in infert_df.columns]
+infert_df.rename(columns={'case': 'Label'}, inplace=True)
+
+infert_onehot_df = (OneHotVectorizer() << 'education_str').fit_transform(infert_df)
+
+# rank group carrier price Class dep_day nbr_stops duration
+# 0 2 1 AA 240 3 1 0 12.0
+# 1 1 1 AA 300 3 0 1 15.0
+file_path = get_dataset("gen_tickettrain").as_filepath()
+gen_tt_df = pd.read_csv(file_path)
+gen_tt_df['group'] = gen_tt_df['group'].astype(np.uint32)
+
+# Unnamed: 0 Label Solar_R Wind Temp Month Day
+# 0 1 41.0 190.0 7.4 67 5 1
+# 1 2 36.0 118.0 8.0 72 5 2
+airquality_df = get_dataset("airquality").as_df().fillna(0)
+airquality_df = airquality_df[airquality_df.Ozone.notnull()]
+
+# Sentiment SentimentText
+# 0 1 ==RUDE== Dude, you are rude upload that carl...
+# 1 1 == OK! == IM GOING TO VANDALIZE WILD ONES W...
+file_path = get_dataset("wiki_detox_train").as_filepath()
+wiki_detox_df = pd.read_csv(file_path, sep='\t')
+wiki_detox_df = wiki_detox_df.head(10)
+
+# Path Label
+# 0 C:\repo\src\python... True
+# 1 C:\repo\src\python... False
+image_paths_df = pd.DataFrame(data=dict(
+ Path=[get_RevolutionAnalyticslogo(), get_Microsoftlogo()],
+ Label=[True, False]))
+
+
+SKIP = {
+ 'DatasetTransformer',
+ 'LightLda',
+ 'OneVsRestClassifier',
+ 'Sentiment',
+ 'TensorFlowScorer',
+ 'TreeFeaturizer',
+ 'WordEmbedding',
+ 'OnnxRunner'
+}
+
+INSTANCES = {
+ 'Binner': Binner(num_bins=3),
+ 'CharTokenizer': CharTokenizer(columns={'SentimentText_Transform': 'SentimentText'}),
+ 'ColumnConcatenator': ColumnConcatenator(columns={'Features': [
+ 'Sepal_Length',
+ 'Sepal_Width',
+ 'Petal_Length',
+ 'Petal_Width',
+ 'Setosa']}),
+ 'ColumnSelector': ColumnSelector(columns=['Sepal_Width', 'Sepal_Length']),
+ 'ColumnDuplicator': ColumnDuplicator(columns={'dup': 'Sepal_Width'}),
+ 'CountSelector': CountSelector(count=5, columns=['Sepal_Width']),
+ 'FastForestBinaryClassifier': FastForestBinaryClassifier(feature=['Sepal_Width', 'Sepal_Length'],
+ label='Setosa'),
+ 'FastLinearBinaryClassifier': FastLinearBinaryClassifier(feature=['Sepal_Width', 'Sepal_Length'],
+ label='Setosa'),
+ 'FastTreesTweedieRegressor': FastTreesTweedieRegressor(label='Ozone'),
+ 'Filter': Filter(columns=[ 'Petal_Length', 'Petal_Width']),
+ 'FromKey': Pipeline([
+ ToKey(columns=['Setosa']),
+ FromKey(columns=['Setosa'])
+ ]),
+ # GlobalContrastRowScaler currently requires a vector input to work
+ 'GlobalContrastRowScaler': Pipeline([
+ ColumnConcatenator() << {
+ 'concated_columns': [
+ 'Petal_Length',
+ 'Sepal_Width',
+ 'Sepal_Length']},
+ GlobalContrastRowScaler(columns={'normed_columns': 'concated_columns'})
+ ]),
+ 'Handler': Handler(replace_with='Mean', columns={'NewVals': 'Sepal_Length'}),
+ 'IidSpikeDetector': IidSpikeDetector(columns=['Sepal_Length']),
+ 'IidChangePointDetector': IidChangePointDetector(columns=['Sepal_Length']),
+ 'Indicator': Indicator(columns={'Has_Nan': 'Petal_Length'}),
+ 'KMeansPlusPlus': KMeansPlusPlus(n_clusters=3, feature=['Sepal_Width', 'Sepal_Length']),
+ 'LightGbmRanker': LightGbmRanker(feature=['Class', 'dep_day', 'duration'],
+ label='rank',
+ group_id='group'),
+ 'Loader': Loader(columns={'ImgPath': 'Path'}),
+ 'LpScaler': Pipeline([
+ ColumnConcatenator() << {
+ 'concated_columns': [
+ 'Petal_Length',
+ 'Sepal_Width',
+ 'Sepal_Length']},
+ LpScaler(columns={'normed_columns': 'concated_columns'})
+ ]),
+ 'MutualInformationSelector': Pipeline([
+ ColumnConcatenator(columns={'Features': ['Sepal_Width', 'Sepal_Length', 'Petal_Width']}),
+ MutualInformationSelector(
+ columns='Features',
+ label='Label',
+ slots_in_output=2) # only accept one column
+ ]),
+ 'NaiveBayesClassifier': NaiveBayesClassifier(feature=['Sepal_Width', 'Sepal_Length']),
+ 'NGramFeaturizer': NGramFeaturizer(word_feature_extractor=Ngram(),
+ columns={ 'features': ['SentimentText']}),
+ 'OneHotHashVectorizer': OneHotHashVectorizer(columns=['education_str']),
+ 'OneHotVectorizer': OneHotVectorizer(columns=['education_str']),
+ 'PcaAnomalyDetector': PcaAnomalyDetector(rank=3),
+ 'PcaTransformer': PcaTransformer(rank=2),
+ 'PixelExtractor': Pipeline([
+ Loader(columns={'ImgPath': 'Path'}),
+ PixelExtractor(columns={'ImgPixels': 'ImgPath'}),
+ ]),
+ 'Resizer': Pipeline([
+ Loader(columns={'ImgPath': 'Path'}),
+ Resizer(image_width=227, image_height=227,
+ columns={'ImgResize': 'ImgPath'})
+ ]),
+ 'SkipFilter': SkipFilter(count=5),
+ 'SsaSpikeDetector': SsaSpikeDetector(columns=['Sepal_Length'],
+ seasonal_window_size=2),
+ 'SsaChangePointDetector': SsaChangePointDetector(columns=['Sepal_Length'],
+ seasonal_window_size=2),
+ 'SsaForecaster': SsaForecaster(columns=['Sepal_Length'],
+ window_size=2,
+ series_length=5,
+ train_size=5,
+ horizon=1),
+ 'RangeFilter': RangeFilter(min=5.0, max=5.1, columns=['Sepal_Length']),
+ 'TakeFilter': TakeFilter(count=100),
+ 'TensorFlowScorer': TensorFlowScorer(
+ model_location=os.path.join(
+ script_dir,
+ '..',
+ 'nimbusml',
+ 'examples',
+ 'frozen_saved_model.pb'),
+ columns={'c': ['a', 'b']}),
+ 'ToKey': ToKey(columns={'edu_1': 'education'}),
+ 'TypeConverter': TypeConverter(columns=['age'], result_type='R4'),
+ 'WordTokenizer': WordTokenizer(char_array_term_separators=[" "]) << {'wt': 'SentimentText'}
+}
+
+DATASETS = {
+ 'AveragedPerceptronBinaryClassifier': infert_onehot_df,
+ 'Binner': iris_no_label_df,
+ 'BootstrapSampler': infert_df,
+ 'CharTokenizer': wiki_detox_df,
+ 'EnsembleRegressor': iris_regression_df,
+ 'FactorizationMachineBinaryClassifier': iris_binary_df,
+ 'FastForestBinaryClassifier': iris_no_label_df,
+ 'FastForestRegressor': iris_regression_df,
+ 'FastLinearBinaryClassifier': iris_no_label_df,
+ 'FastLinearClassifier': iris_binary_df,
+ 'FastLinearRegressor': iris_regression_df,
+ 'FastTreesBinaryClassifier': iris_binary_df,
+ 'FastTreesRegressor': iris_regression_df,
+ 'FastTreesTweedieRegressor': airquality_df,
+ 'Filter': iris_no_label_df,
+ 'GamBinaryClassifier': iris_binary_df,
+ 'GamRegressor': iris_regression_df,
+ 'GlobalContrastRowScaler': iris_df.astype(np.float32),
+ 'LightGbmRanker': gen_tt_df,
+ 'LinearSvmBinaryClassifier': iris_binary_df,
+ 'Loader': image_paths_df,
+ 'LogisticRegressionBinaryClassifier': iris_binary_df,
+ 'LogisticRegressionClassifier': iris_binary_df,
+ 'LogMeanVarianceScaler': iris_no_label_df,
+ 'LpScaler': iris_no_label_df.drop(['Setosa'], axis=1).astype(np.float32),
+ 'MeanVarianceScaler': iris_no_label_df,
+ 'MinMaxScaler': iris_no_label_df,
+ 'NGramFeaturizer': wiki_detox_df,
+ 'OneHotHashVectorizer': infert_df,
+ 'OneHotVectorizer': infert_df,
+ 'OnlineGradientDescentRegressor': iris_regression_df,
+ 'OrdinaryLeastSquaresRegressor': iris_regression_df,
+ 'PcaAnomalyDetector': iris_no_label_df,
+ 'PcaTransformer': iris_regression_df,
+ 'PixelExtractor': image_paths_df,
+ 'PoissonRegressionRegressor': iris_regression_df,
+ 'Resizer': image_paths_df,
+ 'SgdBinaryClassifier': iris_binary_df,
+ 'SymSgdBinaryClassifier': iris_binary_df,
+ 'ToKey': infert_df,
+ 'TypeConverter': infert_onehot_df,
+ 'WordTokenizer': wiki_detox_df
+}
+
+REQUIRES_EXPERIMENTAL = {
+}
+
+SUPPORTED_ESTIMATORS = {
+ 'ColumnConcatenator',
+ 'ColumnDuplicator',
+ 'CountSelector',
+ 'EnsembleClassifier',
+ 'EnsembleRegressor',
+ 'FastForestRegressor',
+ 'FastLinearRegressor',
+ 'FastTreesRegressor',
+ 'FastTreesTweedieRegressor',
+ 'GamRegressor',
+ 'Indicator',
+ 'KMeansPlusPlus',
+ 'LightGbmBinaryClassifier',
+ 'LightGbmClassifier',
+ 'LightGbmRegressor',
+ 'LpScaler',
+ 'MeanVarianceScaler',
+ 'MinMaxScaler',
+ 'NaiveBayesClassifier',
+ 'OneHotVectorizer',
+ 'OnlineGradientDescentRegressor',
+ 'OrdinaryLeastSquaresRegressor',
+ 'PcaAnomalyDetector',
+ 'PoissonRegressionRegressor',
+ 'PrefixColumnConcatenator',
+ 'TypeConverter',
+ 'WordTokenizer'
+}
+
+
+class CaptureOutputContext():
+ """
+ Context which can be used for
+ capturing stdout and stderr.
+ """
+ def __enter__(self):
+ self.orig_stdout = sys.stdout
+ self.orig_stderr = sys.stderr
+ self.stdout_capturer = io.StringIO()
+ self.stderr_capturer = io.StringIO()
+ sys.stdout = self.stdout_capturer
+ sys.stderr = self.stderr_capturer
+ return self
+
+ def __exit__(self, *args):
+ sys.stdout = self.orig_stdout
+ sys.stderr = self.orig_stderr
+ self.stdout = self.stdout_capturer.getvalue()
+ self.stderr = self.stderr_capturer.getvalue()
+
+ if self.stdout:
+ print(self.stdout)
+
+ if self.stderr:
+ print(self.stderr)
+
+ # free up some memory
+ del self.stdout_capturer
+ del self.stderr_capturer
+
+
+def get_tmp_file(suffix=None):
+ fd, file_name = tempfile.mkstemp(suffix=suffix)
+ fl = os.fdopen(fd, 'w')
+ fl.close()
+ return file_name
+
+
+def get_file_size(file_path):
+ file_size = 0
+ try:
+ file_size = os.path.getsize(file_path)
+ except:
+ pass
+ return file_size
+
+
+def load_json(file_path):
+ with open(file_path) as f:
+ lines = f.readlines()
+ lines = [l for l in lines if not l.strip().startswith('#')]
+ content_without_comments = '\n'.join(lines)
+ return json.loads(content_without_comments)
+
+
+def test_export_to_onnx(estimator, class_name):
+ """
+ Fit and test an estimator and determine
+ if it supports exporting to the ONNX format.
+ """
+ onnx_path = get_tmp_file('.onnx')
+ onnx_json_path = get_tmp_file('.onnx.json')
+
+ output = None
+ exported = False
+
+ try:
+ dataset = DATASETS.get(class_name, iris_df)
+ estimator.fit(dataset)
+
+ onnx_version = 'Experimental' if class_name in REQUIRES_EXPERIMENTAL else 'Stable'
+
+ with CaptureOutputContext() as output:
+ estimator.export_to_onnx(onnx_path,
+ 'com.microsoft.ml',
+ dst_json=onnx_json_path,
+ onnx_version=onnx_version)
+ except Exception as e:
+ print(e)
+
+ onnx_file_size = get_file_size(onnx_path)
+ onnx_json_file_size = get_file_size(onnx_json_path)
+
+ if (output and
+ (onnx_file_size != 0) and
+ (onnx_json_file_size != 0) and
+ (not 'cannot save itself as ONNX' in output.stdout)):
+ exported = True
+
+ if exported and SHOW_ONNX_JSON:
+ with open(onnx_json_path) as f:
+ print(json.dumps(json.load(f), indent=4))
+
+ os.remove(onnx_path)
+ os.remove(onnx_json_path)
+ return exported
+
+
+manifest_diff = os.path.join(script_dir, '..', 'tools', 'manifest_diff.json')
+entry_points = load_json(manifest_diff)['EntryPoints']
+entry_points = sorted(entry_points, key=lambda ep: ep['NewName'])
+
+exportable_estimators = set()
+exportable_experimental_estimators = set()
+unexportable_estimators = set()
+
+for entry_point in entry_points:
+ class_name = entry_point['NewName']
+
+ print('\n===========> %s' % class_name)
+
+ if class_name in SKIP:
+ print("skipped")
+ continue
+
+ mod = __import__('nimbusml.' + entry_point['Module'],
+ fromlist=[str(class_name)])
+
+ if class_name in INSTANCES:
+ estimator = INSTANCES[class_name]
+ else:
+ the_class = getattr(mod, class_name)
+ estimator = the_class()
+
+ result = test_export_to_onnx(estimator, class_name)
+
+ if result:
+ if class_name in REQUIRES_EXPERIMENTAL:
+ exportable_experimental_estimators.add(class_name)
+ else:
+ exportable_estimators.add(class_name)
+
+ print('Estimator successfully exported to ONNX.')
+
+ else:
+ unexportable_estimators.add(class_name)
+ print('Estimator could NOT be exported to ONNX.')
+
+print('\nThe following estimators were skipped: ')
+pprint.pprint(sorted(SKIP))
+
+print('\nThe following estimators were successfully exported to ONNX:')
+pprint.pprint(sorted(exportable_estimators))
+
+print('\nThe following estimators were successfully exported to experimental ONNX: ')
+pprint.pprint(sorted(exportable_experimental_estimators))
+
+print('\nThe following estimators could not be exported to ONNX: ')
+pprint.pprint(sorted(unexportable_estimators))
+
+failed_estimators = SUPPORTED_ESTIMATORS.difference(exportable_estimators) \
+ .difference(exportable_experimental_estimators)
+
+if len(failed_estimators) > 0:
+ print("The following tests failed exporting to onnx:", sorted(failed_estimators))
+ raise RuntimeError("onnx export checks failed")
+
diff --git a/src/python/tools/manifest.json b/src/python/tools/manifest.json
index 5b739e57..491ac877 100644
--- a/src/python/tools/manifest.json
+++ b/src/python/tools/manifest.json
@@ -2194,6 +2194,203 @@
"ITrainerInput"
]
},
+ {
+ "Name": "Models.OnnxConverter",
+ "Desc": "Converts the model to ONNX format.",
+ "FriendlyName": "ONNX Converter.",
+ "ShortName": null,
+ "Inputs": [
+ {
+ "Name": "DataFile",
+ "Type": "String",
+ "Desc": "The data file",
+ "Aliases": [
+ "data"
+ ],
+ "Required": false,
+ "SortOrder": 0.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "Onnx",
+ "Type": "String",
+ "Desc": "The path to write the output ONNX to.",
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "Json",
+ "Type": "String",
+ "Desc": "The path to write the output JSON to.",
+ "Required": false,
+ "SortOrder": 2.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "Name",
+ "Type": "String",
+ "Desc": "The 'name' property in the output ONNX. By default this will be the ONNX extension-less name.",
+ "Required": false,
+ "SortOrder": 3.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "Domain",
+ "Type": "String",
+ "Desc": "The 'domain' property in the output ONNX.",
+ "Required": false,
+ "SortOrder": 4.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "InputsToDrop",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": "String"
+ },
+ "Desc": "Array of input column names to drop",
+ "Required": false,
+ "SortOrder": 6.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "OutputsToDrop",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": "String"
+ },
+ "Desc": "Array of output column names to drop",
+ "Required": false,
+ "SortOrder": 8.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "Model",
+ "Type": "TransformModel",
+ "Desc": "Model that needs to be converted to ONNX format.",
+ "Required": false,
+ "SortOrder": 10.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "OnnxVersion",
+ "Type": {
+ "Kind": "Enum",
+ "Values": [
+ "Stable",
+ "Experimental"
+ ]
+ },
+ "Desc": "The targeted ONNX version. It can be either \"Stable\" or \"Experimental\". If \"Experimental\" is used, produced model can contain components that is not officially supported in ONNX standard.",
+ "Required": false,
+ "SortOrder": 11.0,
+ "IsNullable": false,
+ "Default": "Stable"
+ },
+ {
+ "Name": "PredictiveModel",
+ "Type": "PredictorModel",
+ "Desc": "Predictor model that needs to be converted to ONNX format.",
+ "Required": false,
+ "SortOrder": 12.0,
+ "IsNullable": false,
+ "Default": null
+ }
+ ],
+ "Outputs": []
+ },
+ {
+ "Name": "Models.OnnxTransformer",
+ "Desc": "Applies an ONNX model to a dataset.",
+ "FriendlyName": "Onnx Transformer",
+ "ShortName": "onnx-xf",
+ "Inputs": [
+ {
+ "Name": "ModelFile",
+ "Type": "String",
+ "Desc": "Path to the onnx model file.",
+ "Aliases": [
+ "model"
+ ],
+ "Required": true,
+ "SortOrder": 0.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "InputColumns",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": "String"
+ },
+ "Desc": "Name of the input column.",
+ "Required": false,
+ "SortOrder": 1.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "Data",
+ "Type": "DataView",
+ "Desc": "Input dataset",
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "OutputColumns",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": "String"
+ },
+ "Desc": "Name of the output column.",
+ "Required": false,
+ "SortOrder": 2.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "GpuDeviceId",
+ "Type": "Int",
+ "Desc": "GPU device id to run on (e.g. 0,1,..). Null for CPU. Requires CUDA 9.1.",
+ "Required": false,
+ "SortOrder": 3.0,
+ "IsNullable": true,
+ "Default": null
+ },
+ {
+ "Name": "FallbackToCpu",
+ "Type": "Bool",
+ "Desc": "If true, resumes execution on CPU upon GPU error. If false, will raise the GPU execption.",
+ "Required": false,
+ "SortOrder": 4.0,
+ "IsNullable": false,
+ "Default": false
+ }
+ ],
+ "Outputs": [
+ {
+ "Name": "OutputData",
+ "Type": "DataView",
+ "Desc": "ONNX transformed dataset"
+ },
+ {
+ "Name": "Model",
+ "Type": "TransformModel",
+ "Desc": "Transform model"
+ }
+ ],
+ "InputKind": [
+ "ITransformInput"
+ ]
+ },
{
"Name": "Models.OvaModelCombiner",
"Desc": "Combines a sequence of PredictorModels into a single model",
diff --git a/src/python/tools/manifest_diff.json b/src/python/tools/manifest_diff.json
index c94a8845..a70489ee 100644
--- a/src/python/tools/manifest_diff.json
+++ b/src/python/tools/manifest_diff.json
@@ -347,6 +347,12 @@
"Module": "preprocessing",
"Type": "Transform"
},
+ {
+ "Name": "Models.OnnxTransformer",
+ "NewName": "OnnxRunner",
+ "Module": "preprocessing",
+ "Type": "Transform"
+ },
{
"Name": "Trainers.FieldAwareFactorizationMachineBinaryClassifier",
"NewName": "FactorizationMachineBinaryClassifier",