diff --git a/src/python/nimbusml.pyproj b/src/python/nimbusml.pyproj
index fb4a8e03..5daea049 100644
--- a/src/python/nimbusml.pyproj
+++ b/src/python/nimbusml.pyproj
@@ -293,6 +293,7 @@
+
@@ -611,6 +612,7 @@
+
@@ -666,6 +668,7 @@
+
diff --git a/src/python/nimbusml/internal/core/preprocessing/datasettransformer.py b/src/python/nimbusml/internal/core/preprocessing/datasettransformer.py
new file mode 100644
index 00000000..545e6e36
--- /dev/null
+++ b/src/python/nimbusml/internal/core/preprocessing/datasettransformer.py
@@ -0,0 +1,49 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------------------------
+# - Generated by tools/entrypoint_compiler.py: do not edit by hand
+"""
+DatasetTransformer
+"""
+
+__all__ = ["DatasetTransformer"]
+
+
+from ...entrypoints.models_datasettransformer import models_datasettransformer
+from ...utils.utils import trace
+from ..base_pipeline_item import BasePipelineItem, DefaultSignature
+
+
+class DatasetTransformer(BasePipelineItem, DefaultSignature):
+ """
+ **Description**
+ Applies a TransformModel to a dataset.
+
+ :param transform_model: Transform model.
+
+ :param params: Additional arguments sent to compute engine.
+
+ """
+
+ @trace
+ def __init__(
+ self,
+ transform_model,
+ **params):
+ BasePipelineItem.__init__(
+ self, type='transform', **params)
+
+ self.transform_model = transform_model
+
+ @property
+ def _entrypoint(self):
+ return models_datasettransformer
+
+ @trace
+ def _get_node(self, **all_args):
+ algo_args = dict(
+ transform_model=self.transform_model)
+
+ all_args.update(algo_args)
+ return self._entrypoint(**all_args)
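For orientation, `models_datasettransformer` emits one node of the entrypoint graph that NimbusML hands to ML.NET. Below is a minimal sketch of the shape such a node could take; the `Data`/`OutputData` port names and the `$variable` references are assumptions based on ML.NET entrypoint graphs, not taken from this diff:

    # Hypothetical shape of a Models.DatasetTransformer graph node.
    # Port names ('Data', 'OutputData') are assumptions; TransformModel
    # gets wired to a graph-level input by pipeline.py (see below).
    node = {
        'Name': 'Models.DatasetTransformer',
        'Inputs': {
            'Data': '$input_data',
            'TransformModel': '$dataset_transformer_model0',
        },
        'Outputs': {'OutputData': '$output_data'},
    }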
diff --git a/src/python/nimbusml/pipeline.py b/src/python/nimbusml/pipeline.py
index c1143160..4824e83c 100644
--- a/src/python/nimbusml/pipeline.py
+++ b/src/python/nimbusml/pipeline.py
@@ -599,6 +599,14 @@ def _init_graph_nodes(
output_data=output_data,
output_model=output_model,
strategy_iosklearn=strategy_iosklearn)
+
+        for i, node in enumerate(n for n in transform_nodes
+                                 if n.name == 'Models.DatasetTransformer'):
+            input_name = 'dataset_transformer_model' + str(i)
+            inputs[input_name] = node.inputs['TransformModel']
+            node.inputs['TransformModel'] = '$' + input_name
+            node.input_variables.add(node.inputs['TransformModel'])
+
graph_nodes['transform_nodes'] = transform_nodes
return graph_nodes, feature_columns, inputs, transform_nodes, \
columns_out
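The loop added above lifts each DatasetTransformer's pre-fitted model (at this point a file path) into a numbered graph-level input, then rewires the node to reference that input as a `$` variable. A self-contained toy version of the same rewiring, with a hypothetical `Node` class standing in for NimbusML's internal graph-node objects:

    # Toy rewiring; Node is a hypothetical stand-in for NimbusML's
    # internal graph nodes.
    class Node:
        def __init__(self, name, inputs):
            self.name = name
            self.inputs = inputs
            self.input_variables = set()

    inputs = {}
    transform_nodes = [
        Node('Models.DatasetTransformer', {'TransformModel': '/tmp/model0.zip'}),
        Node('Models.DatasetTransformer', {'TransformModel': '/tmp/model1.zip'}),
    ]

    for i, node in enumerate(n for n in transform_nodes
                             if n.name == 'Models.DatasetTransformer'):
        input_name = 'dataset_transformer_model' + str(i)
        inputs[input_name] = node.inputs['TransformModel']   # path becomes a graph input
        node.inputs['TransformModel'] = '$' + input_name     # node references the variable
        node.input_variables.add(node.inputs['TransformModel'])

    assert inputs == {'dataset_transformer_model0': '/tmp/model0.zip',
                      'dataset_transformer_model1': '/tmp/model1.zip'}
    assert transform_nodes[0].inputs['TransformModel'] == '$dataset_transformer_model0'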
@@ -778,9 +786,13 @@ def _fit_graph(self, X, y, verbose, **params):
graph_nodes = list(itertools.chain(*graph_nodes.values()))
# combine output models
-        transform_models = [node.outputs["Model"]
-                            for node in graph_nodes if
-                            "Model" in node.outputs]
+        transform_models = []
+        for node in graph_nodes:
+            if node.name == 'Models.DatasetTransformer':
+                transform_models.append(node.inputs['TransformModel'])
+            elif "Model" in node.outputs:
+                transform_models.append(node.outputs["Model"])
+
if learner_node and len(
transform_models) > 0: # no need to combine if there is
# only 1 model
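This collection rule matters because a DatasetTransformer trains nothing during fit: when the fitted stages are chained into one model, it must contribute the pre-fitted model it wraps (its `TransformModel` input) rather than a freshly trained output model, so the combined chain stays in pipeline order. A toy walk-through with hypothetical `SimpleNamespace` nodes (the trainer name here is illustrative):

    from types import SimpleNamespace

    graph_nodes = [
        SimpleNamespace(name='Models.DatasetTransformer',
                        inputs={'TransformModel': '$dataset_transformer_model0'},
                        outputs={}),
        SimpleNamespace(name='Trainers.OnlineGradientDescentRegressor',
                        inputs={}, outputs={'Model': '$predictor_model'}),
    ]

    transform_models = []
    for node in graph_nodes:
        if node.name == 'Models.DatasetTransformer':
            # reuse the pre-fitted model this node wraps
            transform_models.append(node.inputs['TransformModel'])
        elif 'Model' in node.outputs:
            # take the model this run just produced
            transform_models.append(node.outputs['Model'])

    assert transform_models == ['$dataset_transformer_model0', '$predictor_model']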
diff --git a/src/python/nimbusml/preprocessing/__init__.py b/src/python/nimbusml/preprocessing/__init__.py
index e3d98fca..26b41b8e 100644
--- a/src/python/nimbusml/preprocessing/__init__.py
+++ b/src/python/nimbusml/preprocessing/__init__.py
@@ -1,9 +1,11 @@
from .fromkey import FromKey
from .tokey import ToKey
from .tensorflowscorer import TensorFlowScorer
+from .datasettransformer import DatasetTransformer
__all__ = [
'FromKey',
'ToKey',
-    'TensorFlowScorer'
+    'TensorFlowScorer',
+    'DatasetTransformer'
]
diff --git a/src/python/nimbusml/preprocessing/datasettransformer.py b/src/python/nimbusml/preprocessing/datasettransformer.py
new file mode 100644
index 00000000..0a07dbdb
--- /dev/null
+++ b/src/python/nimbusml/preprocessing/datasettransformer.py
@@ -0,0 +1,54 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------------------------
+# - Generated by tools/entrypoint_compiler.py: do not edit by hand
+"""
+DatasetTransformer
+"""
+
+__all__ = ["DatasetTransformer"]
+
+
+from sklearn.base import TransformerMixin
+
+from ..base_transform import BaseTransform
+from ..internal.core.preprocessing.datasettransformer import \
+    DatasetTransformer as core
+from ..internal.utils.utils import trace
+
+
+class DatasetTransformer(core, BaseTransform, TransformerMixin):
+    """
+    **Description**
+        Applies a TransformModel to a dataset.
+
+    :param columns: see `Columns `_.
+
+    :param transform_model: Transform model.
+
+    :param params: Additional arguments sent to compute engine.
+
+    """
+
+    @trace
+    def __init__(
+            self,
+            transform_model,
+            columns=None,
+            **params):
+
+        if columns:
+            params['columns'] = columns
+        BaseTransform.__init__(self, **params)
+        core.__init__(
+            self,
+            transform_model=transform_model,
+            **params)
+        self._columns = columns
+
+    def get_params(self, deep=False):
+        """
+        Get the parameters for this operator.
+        """
+        return core.get_params(self)
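The intended usage, mirrored by the tests below: fit a transform-only `Pipeline` once, then splice its saved model ahead of new stages via `DatasetTransformer`. A sketch assuming the `train_df`/`test_df` frames defined in the tests:

    from nimbusml import Pipeline
    from nimbusml.feature_extraction.categorical import OneHotVectorizer
    from nimbusml.linear_model import OnlineGradientDescentRegressor
    from nimbusml.preprocessing import DatasetTransformer

    # Fit the reusable transform and keep its saved model file.
    transform_pipeline = Pipeline([OneHotVectorizer() << 'c0'])
    transform_pipeline.fit(train_df)          # train_df as in the tests below

    # Replay the pre-fitted transform, then train a new predictor after it.
    combined_pipeline = Pipeline([
        DatasetTransformer(transform_model=transform_pipeline.model),
        OnlineGradientDescentRegressor(label='c2', feature=['c0', 'c1']),
    ])
    combined_pipeline.fit(train_df)
    predictions = combined_pipeline.predict(test_df)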
diff --git a/src/python/nimbusml/tests/preprocessing/test_datasettransformer.py b/src/python/nimbusml/tests/preprocessing/test_datasettransformer.py
new file mode 100644
index 00000000..197119c6
--- /dev/null
+++ b/src/python/nimbusml/tests/preprocessing/test_datasettransformer.py
@@ -0,0 +1,182 @@
+# --------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------------------------
+import os
+import unittest
+
+import numpy as np
+import pandas as pd
+from nimbusml import Pipeline, FileDataStream
+from nimbusml.feature_extraction.categorical import OneHotVectorizer
+from nimbusml.linear_model import OnlineGradientDescentRegressor
+from nimbusml.preprocessing import DatasetTransformer
+from nimbusml.preprocessing.filter import RangeFilter
+
+seed = 0
+
+train_data = {'c0': ['a', 'b', 'a', 'b'],
+              'c1': [1, 2, 3, 4],
+              'c2': [2, 3, 4, 5]}
+train_df = pd.DataFrame(train_data).astype({'c1': np.float64,
+                                            'c2': np.float64})
+
+test_data = {'c0': ['a', 'b', 'b'],
+             'c1': [1.5, 2.3, 3.7],
+             'c2': [2.2, 4.9, 2.7]}
+test_df = pd.DataFrame(test_data).astype({'c1': np.float64,
+                                          'c2': np.float64})
+
+
+class TestDatasetTransformer(unittest.TestCase):
+
+    def test_same_schema_with_dataframe_input(self):
+        train_df_updated = train_df.drop(['c0'], axis=1)
+        test_df_updated = test_df.drop(['c0'], axis=1)
+
+        rf_max = 4.5
+
+        # Create reference pipeline
+        std_pipeline = Pipeline([
+            RangeFilter(min=0.0, max=rf_max) << 'c2',
+            OnlineGradientDescentRegressor(label='c2', feature=['c1'])
+        ], random_state=seed)
+
+        std_pipeline.fit(train_df_updated)
+        result_1 = std_pipeline.predict(test_df_updated)
+
+        # Create combined pipeline
+        transform_pipeline = Pipeline([RangeFilter(min=0.0, max=rf_max) << 'c2'])
+        transform_pipeline.fit(train_df_updated)
+
+        combined_pipeline = Pipeline([
+            DatasetTransformer(transform_model=transform_pipeline.model),
+            OnlineGradientDescentRegressor(label='c2', feature=['c1'])
+        ], random_state=seed)
+        combined_pipeline.fit(train_df_updated)
+
+        os.remove(transform_pipeline.model)
+
+        result_2 = combined_pipeline.predict(test_df_updated)
+
+        self.assertTrue(result_1.equals(result_2))
+
+
+    def test_different_schema_with_dataframe_input(self):
+        # Create reference pipeline
+        std_pipeline = Pipeline([
+            OneHotVectorizer() << 'c0',
+            OnlineGradientDescentRegressor(label='c2', feature=['c0', 'c1'])
+        ], random_state=seed)
+
+        std_pipeline.fit(train_df)
+        result_1 = std_pipeline.predict(test_df)
+
+        # Create combined pipeline
+        transform_pipeline = Pipeline([OneHotVectorizer() << 'c0'], random_state=seed)
+        transform_pipeline.fit(train_df)
+
+        combined_pipeline = Pipeline([
+            DatasetTransformer(transform_model=transform_pipeline.model),
+            OnlineGradientDescentRegressor(label='c2', feature=['c0', 'c1'])
+        ], random_state=seed)
+        combined_pipeline.fit(train_df)
+
+        os.remove(transform_pipeline.model)
+
+        result_2 = combined_pipeline.predict(test_df)
+
+        self.assertTrue(result_1.equals(result_2))
+
+
+    def test_different_schema_with_filedatastream_input(self):
+        train_filename = "train-data.csv"
+        train_df.to_csv(train_filename, index=False, header=True)
+        train_data_stream = FileDataStream.read_csv(train_filename, sep=',', header=True)
+
+        test_filename = "test-data.csv"
+        test_df.to_csv(test_filename, index=False, header=True)
+        test_data_stream = FileDataStream.read_csv(test_filename, sep=',', header=True)
+
+        # Create reference pipeline
+        std_pipeline = Pipeline([
+            OneHotVectorizer() << 'c0',
+            OnlineGradientDescentRegressor(label='c2', feature=['c0', 'c1'])
+        ], random_state=seed)
+
+        std_pipeline.fit(train_data_stream)
+        result_1 = std_pipeline.predict(test_data_stream)
+
+        # Create combined pipeline
+        transform_pipeline = Pipeline([OneHotVectorizer() << 'c0'], random_state=seed)
+        transform_pipeline.fit(train_data_stream)
+
+        combined_pipeline = Pipeline([
+            DatasetTransformer(transform_model=transform_pipeline.model),
+            OnlineGradientDescentRegressor(label='c2', feature=['c0', 'c1'])
+        ], random_state=seed)
+        combined_pipeline.fit(train_data_stream)
+
+        os.remove(transform_pipeline.model)
+
+        result_2 = combined_pipeline.predict(test_data_stream)
+
+        self.assertTrue(result_1.equals(result_2))
+
+        os.remove(train_filename)
+        os.remove(test_filename)
+
+
+    def test_combining_two_dataset_transformers(self):
+        rf_max = 4.5
+
+        # Create reference pipeline
+        std_pipeline = Pipeline([
+            RangeFilter(min=0.0, max=rf_max) << 'c2',
+            OneHotVectorizer() << 'c0',
+            OnlineGradientDescentRegressor(label='c2', feature=['c0', 'c1'])
+        ], random_state=seed)
+
+        std_pipeline.fit(train_df)
+        result_1 = std_pipeline.predict(test_df)
+
+        # Create combined pipeline
+        transform_pipeline1 = Pipeline([RangeFilter(min=0.0, max=rf_max) << 'c2'])
+        transform_pipeline1.fit(train_df)
+
+        transform_pipeline2 = Pipeline([OneHotVectorizer() << 'c0'], random_state=seed)
+        transform_pipeline2.fit(train_df)
+
+        combined_pipeline = Pipeline([
+            DatasetTransformer(transform_model=transform_pipeline1.model),
+            DatasetTransformer(transform_model=transform_pipeline2.model),
+            OnlineGradientDescentRegressor(label='c2', feature=['c0', 'c1'])
+        ], random_state=seed)
+        combined_pipeline.fit(train_df)
+
+        os.remove(transform_pipeline1.model)
+        os.remove(transform_pipeline2.model)
+
+        result_2 = combined_pipeline.predict(test_df)
+
+        self.assertTrue(result_1.equals(result_2))
+
+
+    def test_get_fit_info(self):
+        transform_pipeline = Pipeline([RangeFilter(min=0.0, max=4.5) << 'c2'])
+        transform_pipeline.fit(train_df)
+
+        combined_pipeline = Pipeline([
+            DatasetTransformer(transform_model=transform_pipeline.model),
+            OnlineGradientDescentRegressor(label='c2', feature=['c1'])
+        ], random_state=seed)
+        combined_pipeline.fit(train_df)
+
+        info = combined_pipeline.get_fit_info(train_df)
+
+        self.assertEqual(info[0][1]['name'], 'DatasetTransformer')
+
+
+if __name__ == '__main__':
+    unittest.main()
+
diff --git a/src/python/tests/test_estimator_checks.py b/src/python/tests/test_estimator_checks.py
index 7fe31334..89cfcda9 100644
--- a/src/python/tests/test_estimator_checks.py
+++ b/src/python/tests/test_estimator_checks.py
@@ -260,7 +260,14 @@ def load_json(file_path):
return json.loads(content_without_comments)
-skip_epoints = set(['OneVsRestClassifier', 'TreeFeaturizer'])
+skip_epoints = set([
+    'OneVsRestClassifier',
+    'TreeFeaturizer',
+    # skip SymSgdBinaryClassifier for now, because of crashes.
+    'SymSgdBinaryClassifier',
+    'DatasetTransformer'
+])
+
epoints = []
my_path = os.path.realpath(__file__)
my_dir = os.path.dirname(my_path)
@@ -287,9 +294,6 @@ def load_json(file_path):
# skip LighGbm for now, because of random crashes.
if 'LightGbm' in class_name:
continue
-    # skip SymSgdBinaryClassifier for now, because of crashes.
-    if 'SymSgdBinaryClassifier' in class_name:
-        continue
mod = __import__('nimbusml.' + e[0], fromlist=[str(class_name)])
the_class = getattr(mod, class_name)
diff --git a/src/python/tools/manifest_diff.json b/src/python/tools/manifest_diff.json
index acff52df..d8a64d82 100644
--- a/src/python/tools/manifest_diff.json
+++ b/src/python/tools/manifest_diff.json
@@ -317,6 +317,12 @@
}
]
},
+    {
+      "Name": "Models.DatasetTransformer",
+      "NewName": "DatasetTransformer",
+      "Module": "preprocessing",
+      "Type": "Transform"
+    },
{
"Name": "Trainers.FieldAwareFactorizationMachineBinaryClassifier",
"NewName": "FactorizationMachineBinaryClassifier",