diff --git a/src/python/tests_extended/test_export_to_onnx.py b/src/python/tests_extended/test_export_to_onnx.py
index 3cf5fdb4..61aca090 100644
--- a/src/python/tests_extended/test_export_to_onnx.py
+++ b/src/python/tests_extended/test_export_to_onnx.py
@@ -3,7 +3,7 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------------------------
 """
-Verify onnx export support
+Verify onnx export and transform support
 """
 import contextlib
 import io
@@ -16,6 +16,7 @@
 import pprint
 from nimbusml import Pipeline
+from nimbusml.base_predictor import BasePredictor
 from nimbusml.cluster import KMeansPlusPlus
 from nimbusml.datasets import get_dataset
 from nimbusml.datasets.image import get_RevolutionAnalyticslogo, get_Microsoftlogo
@@ -28,7 +29,8 @@
 from nimbusml.feature_selection import CountSelector, MutualInformationSelector
 from nimbusml.linear_model import FastLinearBinaryClassifier
 from nimbusml.naive_bayes import NaiveBayesClassifier
-from nimbusml.preprocessing import TensorFlowScorer, FromKey, ToKey
+from nimbusml.preprocessing import (TensorFlowScorer, FromKey, ToKey,
+                                    DateTimeSplitter, OnnxRunner)
 from nimbusml.preprocessing.filter import SkipFilter, TakeFilter, RangeFilter
 from nimbusml.preprocessing.missing_values import Filter, Handler, Indicator
 from nimbusml.preprocessing.normalization import Binner, GlobalContrastRowScaler, LpScaler
@@ -41,6 +43,13 @@
 SHOW_ONNX_JSON = False
+SHOW_TRANSFORMED_RESULTS = True
+SHOW_FULL_PANDAS_OUTPUT = False
+
+if SHOW_FULL_PANDAS_OUTPUT:
+    pd.set_option('display.max_columns', None)
+    pd.set_option('display.max_rows', None)
+    pd.set_option('display.width', 10000)
 
 script_path = os.path.realpath(__file__)
 script_dir = os.path.dirname(script_path)
@@ -51,6 +60,9 @@
 iris_df = get_dataset("iris").as_df()
 iris_df.drop(['Species'], axis=1, inplace=True)
 
+iris_with_nan_df = iris_df.copy()
+iris_with_nan_df.loc[1, 'Petal_Length'] = np.nan
+
 iris_no_label_df = iris_df.drop(['Label'], axis=1)
 iris_binary_df = iris_no_label_df.rename(columns={'Setosa': 'Label'})
 iris_regression_df = iris_no_label_df.drop(['Setosa'], axis=1).rename(columns={'Petal_Width': 'Label'})
@@ -95,12 +107,15 @@
 SKIP = {
     'DatasetTransformer',
     'LightLda',
+    'NGramExtractor', # Crashes
+    'NGramFeaturizer', # Crashes
     'OneVsRestClassifier',
+    'OnnxRunner',
     'Sentiment',
     'TensorFlowScorer',
+    'TimeSeriesImputer',
     'TreeFeaturizer',
     'WordEmbedding',
-    'OnnxRunner'
 }
 
 INSTANCES = {
@@ -115,6 +130,7 @@
     'ColumnSelector': ColumnSelector(columns=['Sepal_Width', 'Sepal_Length']),
     'ColumnDuplicator': ColumnDuplicator(columns={'dup': 'Sepal_Width'}),
     'CountSelector': CountSelector(count=5, columns=['Sepal_Width']),
+    'DateTimeSplitter': DateTimeSplitter(prefix='dt'),
     'FastForestBinaryClassifier': FastForestBinaryClassifier(feature=['Sepal_Width', 'Sepal_Length'],
                                                              label='Setosa'),
     'FastLinearBinaryClassifier': FastLinearBinaryClassifier(feature=['Sepal_Width', 'Sepal_Length'],
@@ -122,8 +138,8 @@
     'FastTreesTweedieRegressor': FastTreesTweedieRegressor(label='Ozone'),
     'Filter': Filter(columns=[ 'Petal_Length', 'Petal_Width']),
     'FromKey': Pipeline([
-        ToKey(columns=['Setosa']),
-        FromKey(columns=['Setosa'])
+        ToKey(columns=['Sepal_Length']),
+        FromKey(columns=['Sepal_Length'])
     ]),
     # GlobalContrastRowScaler currently requires a vector input to work
     'GlobalContrastRowScaler': Pipeline([
@@ -134,7 +150,7 @@
                 'Sepal_Length']},
        GlobalContrastRowScaler(columns={'normed_columns': 'concated_columns'})
     ]),
-    'Handler': Handler(replace_with='Mean', columns={'NewVals': 'Sepal_Length'}),
+    'Handler': Handler(replace_with='Mean', columns={'NewVals': 'Petal_Length'}),
     'IidSpikeDetector': IidSpikeDetector(columns=['Sepal_Length']),
     'IidChangePointDetector': IidChangePointDetector(columns=['Sepal_Length']),
     'Indicator': Indicator(columns={'Has_Nan': 'Petal_Length'}),
@@ -218,11 +234,13 @@
     'GamBinaryClassifier': iris_binary_df,
     'GamRegressor': iris_regression_df,
     'GlobalContrastRowScaler': iris_df.astype(np.float32),
+    'Handler': iris_with_nan_df,
+    'Indicator': iris_with_nan_df,
     'LightGbmRanker': gen_tt_df,
     'LinearSvmBinaryClassifier': iris_binary_df,
     'Loader': image_paths_df,
     'LogisticRegressionBinaryClassifier': iris_binary_df,
-    'LogisticRegressionClassifier': iris_binary_df,
+    'LogisticRegressionClassifier': iris_df,
     'LogMeanVarianceScaler': iris_no_label_df,
     'LpScaler': iris_no_label_df.drop(['Setosa'], axis=1).astype(np.float32),
     'MeanVarianceScaler': iris_no_label_df,
@@ -244,33 +262,107 @@
     'WordTokenizer': wiki_detox_df
 }
 
+EXPECTED_RESULTS = {
+    'AveragedPerceptronBinaryClassifier': {'cols': [('PredictedLabel', 'PredictedLabel.0')]},
+    'CharTokenizer': {'cols': [('SentimentText_Transform.%03d' % i, 'SentimentText_Transform.%03d' % i)
+                               for i in range(0, 422)]},
+    'ColumnDuplicator': {'cols': [('dup', 'dup.0')]},
+    'ColumnSelector': {
+        'num_cols': 2,
+        'cols': [('Sepal_Width', 'Sepal_Width.0'), ('Sepal_Length', 'Sepal_Length.0')]
+    },
+    #'EnsembleClassifier': {'cols': [('PredictedLabel', 'PredictedLabel.0')]},
+    #'EnsembleRegressor': {'cols': [('Score', 'Score.0')]},
+    'FastForestBinaryClassifier': {'cols': [('PredictedLabel', 'PredictedLabel.0')]},
+    'FastForestRegressor': {'cols': [('Score', 'Score.0')]},
+    'FastLinearBinaryClassifier': {'cols': [('PredictedLabel', 'PredictedLabel.0')]},
+    'FastLinearClassifier': {'cols': [('PredictedLabel', 'PredictedLabel.0')]},
+    'FastLinearRegressor': {'cols': [('Score', 'Score.0')]},
+    'FastTreesBinaryClassifier': {'cols': [('PredictedLabel', 'PredictedLabel.0')]},
+    'FastTreesRegressor': {'cols': [('Score', 'Score.0')]},
+    'FastTreesTweedieRegressor': {'cols': [('Score', 'Score.0')]},
+    'FromKey': {'cols': [('Sepal_Length', 'Sepal_Length.0'), ('Label', 'Label.0')]},
+    'GlobalContrastRowScaler': {'cols': [
+        ('normed_columns.Petal_Length', 'normed_columns.0'),
+        ('normed_columns.Sepal_Width', 'normed_columns.1'),
+        ('normed_columns.Sepal_Length', 'normed_columns.2')
+    ]},
+    'Handler': {'cols': [
+        ('NewVals.NewVals', 'NewVals.0'),
+        ('NewVals.IsMissing.NewVals', 'NewVals.1')
+    ]},
+    'Indicator': {'cols': [('Has_Nan', 'Has_Nan.0')]},
+    'KMeansPlusPlus': {'cols': [('PredictedLabel', 'PredictedLabel.0')]},
+    'LightGbmBinaryClassifier': {'cols': [('PredictedLabel', 'PredictedLabel.0')]},
+    'LightGbmClassifier': {'cols': [('PredictedLabel', 'PredictedLabel.0')]},
+    'LightGbmRanker': {'cols': [('Score', 'Score.0')]},
+    'LightGbmRegressor': {'cols': [('Score', 'Score.0')]},
+    'LinearSvmBinaryClassifier': {'cols': [('PredictedLabel', 'PredictedLabel.0')]},
+    'LogisticRegressionBinaryClassifier': {'cols': [('PredictedLabel', 'PredictedLabel.0')]},
+    'LogisticRegressionClassifier': {'cols': [('PredictedLabel', 'PredictedLabel.0')]},
+    'LpScaler': {'cols': [
+        ('normed_columns.Petal_Length', 'normed_columns.0'),
+        ('normed_columns.Sepal_Width', 'normed_columns.1'),
+        ('normed_columns.Sepal_Length', 'normed_columns.2')
+    ]},
+    'MeanVarianceScaler': {'cols': list(zip(
+        ['Sepal_Length', 'Sepal_Width', 'Petal_Length', 'Petal_Width', 'Setosa'],
+        ['Sepal_Length.0', 'Sepal_Width.0', 'Petal_Length.0', 'Petal_Width.0', 'Setosa.0']
+    ))},
+    'MinMaxScaler': {'cols': list(zip(
+        ['Sepal_Length', 'Sepal_Width', 'Petal_Length', 'Petal_Width', 'Setosa'],
+        ['Sepal_Length.0', 'Sepal_Width.0', 'Petal_Length.0', 'Petal_Width.0', 'Setosa.0']
+    ))},
+    #'MutualInformationSelector',
+    'NaiveBayesClassifier': {'cols': [('PredictedLabel', 'PredictedLabel.0')]},
+    'OneHotVectorizer': {'cols': list(zip(
+        ['education_str.0-5yrs', 'education_str.6-11yrs', 'education_str.12+ yrs'],
+        ['education_str.0', 'education_str.1', 'education_str.2']
+    ))},
+    'OnlineGradientDescentRegressor': {'cols': [('Score', 'Score.0')]},
+    'OrdinaryLeastSquaresRegressor': {'cols': [('Score', 'Score.0')]},
+}
+
 REQUIRES_EXPERIMENTAL = {
 }
 
 SUPPORTED_ESTIMATORS = {
+    'AveragedPerceptronBinaryClassifier',
+    'CharTokenizer',
     'ColumnConcatenator',
     'ColumnDuplicator',
+    'ColumnSelector',
     'CountSelector',
-    'EnsembleClassifier',
-    'EnsembleRegressor',
+    #'EnsembleClassifier',
+    #'EnsembleRegressor',
+    'FastForestBinaryClassifier',
     'FastForestRegressor',
+    'FastLinearBinaryClassifier',
+    'FastLinearClassifier',
     'FastLinearRegressor',
+    'FastTreesBinaryClassifier',
     'FastTreesRegressor',
     'FastTreesTweedieRegressor',
-    'GamRegressor',
+    'FromKey',
+    'GlobalContrastRowScaler',
+    'Handler',
     'Indicator',
     'KMeansPlusPlus',
     'LightGbmBinaryClassifier',
     'LightGbmClassifier',
+    'LightGbmRanker',
     'LightGbmRegressor',
+    'LinearSvmBinaryClassifier',
+    'LogisticRegressionBinaryClassifier',
+    'LogisticRegressionClassifier',
     'LpScaler',
     'MeanVarianceScaler',
     'MinMaxScaler',
+    #'MutualInformationSelector',
     'NaiveBayesClassifier',
     'OneHotVectorizer',
     'OnlineGradientDescentRegressor',
     'OrdinaryLeastSquaresRegressor',
-    'PcaAnomalyDetector',
     'PoissonRegressionRegressor',
     'PrefixColumnConcatenator',
     'TypeConverter',
@@ -333,6 +425,48 @@ def load_json(file_path):
     return json.loads(content_without_comments)
 
 
+def print_results(result_expected, result_onnx):
+    print("\nML.Net Output (Expected Result):")
+    print(result_expected)
+    if not isinstance(result_expected, pd.Series):
+        print('Columns', result_expected.columns)
+
+    print("\nOnnxRunner Result:")
+    print(result_onnx)
+    if not isinstance(result_onnx, pd.Series):
+        print('Columns', result_onnx.columns)
+
+
+def validate_results(class_name, result_expected, result_onnx):
+    if not class_name in EXPECTED_RESULTS:
+        raise RuntimeError("ERROR: ONNX model executed but no results specified for comparison.")
+
+    if 'num_cols' in EXPECTED_RESULTS[class_name]:
+        num_cols = EXPECTED_RESULTS[class_name]['num_cols']
+
+        if len(result_expected.columns) != num_cols:
+            raise RuntimeError("ERROR: The ML.Net output does not contain the expected number of columns.")
+
+        if len(result_onnx.columns) != num_cols:
+            raise RuntimeError("ERROR: The ONNX output does not contain the expected number of columns.")
+
+    for col_pair in EXPECTED_RESULTS[class_name]['cols']:
+        col_expected = result_expected.loc[:, col_pair[0]]
+        col_onnx = result_onnx.loc[:, col_pair[1]]
+
+        try:
+            pd.testing.assert_series_equal(col_expected,
+                                           col_onnx,
+                                           check_names=False,
+                                           check_exact=False,
+                                           check_less_precise=True)
+        except Exception as e:
+            print(e)
+            raise RuntimeError("ERROR: OnnxRunner result does not match expected result.")
+
+    return True
+
+
 def test_export_to_onnx(estimator, class_name):
     """
     Fit and test an estimator and determine
@@ -343,12 +477,14 @@
 
     output = None
     exported = False
+    export_valid = False
 
     try:
         dataset = DATASETS.get(class_name, iris_df)
         estimator.fit(dataset)
 
-        onnx_version = 'Experimental' if class_name in REQUIRES_EXPERIMENTAL else 'Stable'
+        onnx_version = 'Experimental' if class_name in REQUIRES_EXPERIMENTAL \
+            else 'Stable'
 
         with CaptureOutputContext() as output:
             estimator.export_to_onnx(onnx_path,
@@ -365,15 +501,41 @@
             (onnx_file_size != 0) and
             (onnx_json_file_size != 0) and
             (not 'cannot save itself as ONNX' in output.stdout)):
-        exported = True
-
-    if exported and SHOW_ONNX_JSON:
-        with open(onnx_json_path) as f:
-            print(json.dumps(json.load(f), indent=4))
+        exported = True
+
+        print('ONNX model path:', onnx_path)
+
+        if SHOW_ONNX_JSON:
+            with open(onnx_json_path) as f:
+                print(json.dumps(json.load(f), indent=4))
+
+        # Verify that the output of the exported onnx graph
+        # produces the same results as the standard estimators.
+        if isinstance(estimator, BasePredictor):
+            result_expected = estimator.predict(dataset)
+        else:
+            result_expected = estimator.transform(dataset)
+
+        if isinstance(result_expected, pd.Series):
+            result_expected = pd.DataFrame(result_expected)
+
+        try:
+            onnxrunner = OnnxRunner(model_file=onnx_path)
+            result_onnx = onnxrunner.fit_transform(dataset)
+
+            if SHOW_TRANSFORMED_RESULTS:
+                print_results(result_expected, result_onnx)
+
+            export_valid = validate_results(class_name,
+                                            result_expected,
+                                            result_onnx)
+        except Exception as e:
+            print(e)
 
     os.remove(onnx_path)
     os.remove(onnx_json_path)
 
-    return exported
+    return {'exported': exported, 'export_valid': export_valid}
 
 
 manifest_diff = os.path.join(script_dir, '..', 'tools', 'manifest_diff.json')
@@ -383,10 +545,14 @@
 exportable_estimators = set()
 exportable_experimental_estimators = set()
 unexportable_estimators = set()
+runable_estimators = set()
 
 for entry_point in entry_points:
     class_name = entry_point['NewName']
 
+#    if not class_name in ['PcaTransformer']:
+#        continue
+
     print('\n===========> %s' % class_name)
 
     if class_name in SKIP:
@@ -404,7 +570,7 @@
     result = test_export_to_onnx(estimator, class_name)
 
-    if result:
+    if result['exported']:
         if class_name in REQUIRES_EXPERIMENTAL:
             exportable_experimental_estimators.add(class_name)
         else:
             exportable_estimators.add(class_name)
@@ -416,6 +582,10 @@
         unexportable_estimators.add(class_name)
         print('Estimator could NOT be exported to ONNX.')
 
+    if result['export_valid']:
+        runable_estimators.add(class_name)
+        print('Exported ONNX model successfully transformed with OnnxRunner.')
+
 
 print('\nThe following estimators were skipped: ')
 pprint.pprint(sorted(SKIP))
@@ -428,10 +598,14 @@
 print('\nThe following estimators could not be exported to ONNX: ')
 pprint.pprint(sorted(unexportable_estimators))
 
-failed_estimators = SUPPORTED_ESTIMATORS.difference(exportable_estimators) \
-    .difference(exportable_experimental_estimators)
+failed_estimators = SUPPORTED_ESTIMATORS.difference(runable_estimators)
+print("\nThe following tests failed exporting to ONNX:")
+pprint.pprint(sorted(failed_estimators))
+
+print('\nThe following estimators successfully completed the end to end test: ')
+pprint.pprint(sorted(runable_estimators))
+print()
 
 if len(failed_estimators) > 0:
-    print("The following tests failed exporting to onnx:", sorted(failed_estimators))
-    raise RuntimeError("onnx export checks failed")
+    raise RuntimeError("ONNX export checks failed")
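Reviewer note: the end-to-end check this patch introduces boils down to the pattern sketched below. This is a minimal, illustrative sketch only, not part of the patch; MeanVarianceScaler, the iris column names, and the 'model.onnx' path are arbitrary choices, and the exact export_to_onnx arguments should mirror the full test above.

# Sketch of the round-trip check (assumes a nimbusml build with ONNX export support).
import pandas as pd
from nimbusml.datasets import get_dataset
from nimbusml.preprocessing import OnnxRunner
from nimbusml.preprocessing.normalization import MeanVarianceScaler

# Fit a transform and capture the reference (ML.Net) output.
iris_df = get_dataset("iris").as_df().drop(['Species'], axis=1)
estimator = MeanVarianceScaler()
estimator.fit(iris_df)
result_expected = estimator.transform(iris_df)

# Export the fitted transform to ONNX, then replay the exported graph through OnnxRunner.
estimator.export_to_onnx('model.onnx', onnx_version='Stable')
result_onnx = OnnxRunner(model_file='model.onnx').fit_transform(iris_df)

# Column names differ between the two outputs ('Sepal_Length' vs 'Sepal_Length.0'),
# so matching columns are compared pairwise with a numeric tolerance.
pd.testing.assert_series_equal(result_expected['Sepal_Length'],
                               result_onnx['Sepal_Length.0'],
                               check_names=False, check_exact=False)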