From 576e530867cd6f2f7cc740ceb4cffd1f77721c1e Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Mar 2017 14:50:41 +0800 Subject: [PATCH 1/7] [SPARK-19852][PYSPARK][ML] Update Python API for StringIndexer setHandleInvalid This PR reflect the changes made in SPARK-17498 on pyspark to support a new option 'keep' in StringIndexer to handle unseen labels Signed-off-by: VinceShieh --- python/pyspark/ml/feature.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 77de1cc18246..347e1a8c5b58 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2077,8 +2077,7 @@ def mean(self): @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, - JavaMLWritable): +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. @@ -2098,6 +2097,14 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), ... key=lambda x: x[0]) [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] + >>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"), + ... Row(id=2, label="e")], 2) + >>> dfKeep= spark.createDataFrame(testData2) + >>> tdKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf).transform(dfKeep) + >>> itdKeep = inverter.transform(tdKeep) + >>> sorted(set([(i[0], str(i[1])) for i in itdKeep.select(itdKeep.id, itdKeep.label2).collect()]), + ... key=lambda x: x[0]) + [(0, 'a'), (6, 'd'), (6, 'e')] >>> stringIndexerPath = temp_path + "/string-indexer" >>> stringIndexer.save(stringIndexerPath) >>> loadedIndexer = StringIndexer.load(stringIndexerPath) @@ -2132,6 +2139,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", typeConverter=TypeConverters.toString) + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle unseen labels. " + + "Options are 'skip' (filter out rows with unseen labels), " + + "error (throw an error), or 'keep' (put unseen labels in a special " + + "additional bucket, at index numLabels).", + typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", stringOrderType="frequencyDesc"): @@ -2174,6 +2187,20 @@ def getStringOrderType(self): """ return self.getOrDefault(self.stringOrderType) + @since("2.2.0") + def setHandleInvalid(self, value): + """ + Sets the value of :py:attr:`handleInvalid`. + """ + return self._set(handleInvalid=value) + + @since("2.2.0") + def getHandleInvalid(self): + """ + Gets the value of :py:attr:`handleInvalid` or its default value. + """ + return self.getOrDefault(self.handleInvalid) + class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ From 8c0483017e7d5bed880a2f1917e0be4b0024a3c2 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Mar 2017 15:21:47 +0800 Subject: [PATCH 2/7] fix compilation issues Signed-off-by: VinceShieh --- python/pyspark/ml/feature.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 347e1a8c5b58..fb552e74c767 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2100,9 +2100,9 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja >>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"), ... Row(id=2, label="e")], 2) >>> dfKeep= spark.createDataFrame(testData2) - >>> tdKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf).transform(dfKeep) - >>> itdKeep = inverter.transform(tdKeep) - >>> sorted(set([(i[0], str(i[1])) for i in itdKeep.select(itdKeep.id, itdKeep.label2).collect()]), + >>> tdK = stringIndexer.setHandleInvalid("keep").fit(stringIndDf).transform(dfKeep) + >>> itdK = inverter.transform(tdK) + >>> sorted(set([(i[0], str(i[1])) for i in itdK.select(itdK.id, itdK.label2).collect()]), ... key=lambda x: x[0]) [(0, 'a'), (6, 'd'), (6, 'e')] >>> stringIndexerPath = temp_path + "/string-indexer" From 79dbf81fb26d4c1d85f68d0b9dec73b8810f5157 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Mar 2017 16:20:14 +0800 Subject: [PATCH 3/7] doctest Signed-off-by: VinceShieh --- python/pyspark/ml/feature.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index fb552e74c767..be99a0299f06 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2100,8 +2100,10 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja >>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"), ... Row(id=2, label="e")], 2) >>> dfKeep= spark.createDataFrame(testData2) - >>> tdK = stringIndexer.setHandleInvalid("keep").fit(stringIndDf).transform(dfKeep) - >>> itdK = inverter.transform(tdK) + >>> modelKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf) + >>> tdK = modelKeep.transform(dfKeep) + >>> itdK = IndexToString(inputCol="indexed", outputCol="label2", + ... labels=modelKeep.labels).transform(tdK) >>> sorted(set([(i[0], str(i[1])) for i in itdK.select(itdK.id, itdK.label2).collect()]), ... key=lambda x: x[0]) [(0, 'a'), (6, 'd'), (6, 'e')] From 764e099aadf36b18ea08eb03040d44c71659ba08 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 10 Mar 2017 16:50:09 +0800 Subject: [PATCH 4/7] update doctest Signed-off-by: VinceShieh --- python/pyspark/ml/feature.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index be99a0299f06..1d50ac87566a 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2102,11 +2102,9 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja >>> dfKeep= spark.createDataFrame(testData2) >>> modelKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf) >>> tdK = modelKeep.transform(dfKeep) - >>> itdK = IndexToString(inputCol="indexed", outputCol="label2", - ... labels=modelKeep.labels).transform(tdK) - >>> sorted(set([(i[0], str(i[1])) for i in itdK.select(itdK.id, itdK.label2).collect()]), + >>> sorted(set([(i[0], i[1]) for i in tdK.select(tdK.id, tdK.indexed).collect()]), ... key=lambda x: x[0]) - [(0, 'a'), (6, 'd'), (6, 'e')] + [(0, 0.0), (1, 3.0), (2, 3.0)] >>> stringIndexerPath = temp_path + "/string-indexer" >>> stringIndexer.save(stringIndexerPath) >>> loadedIndexer = StringIndexer.load(stringIndexerPath) From 9809be98c9e9fdb2d2e22900341d20821716dc94 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Fri, 17 Mar 2017 10:55:15 +0800 Subject: [PATCH 5/7] include changes made by SPARK-11569 Signed-off-by: VinceShieh --- python/pyspark/ml/feature.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 1d50ac87566a..a1062b06aede 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2098,7 +2098,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja ... key=lambda x: x[0]) [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] >>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"), - ... Row(id=2, label="e")], 2) + ... Row(id=2, label=None)], 2) >>> dfKeep= spark.createDataFrame(testData2) >>> modelKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf) >>> tdK = modelKeep.transform(dfKeep) @@ -2133,16 +2133,17 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja .. versionadded:: 1.4.0 """ + stringOrderType = Param(Params._dummy(), "stringOrderType", "How to order labels of string column. The first label after " + "ordering is assigned an index of 0. Supported options: " + "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", typeConverter=TypeConverters.toString) - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle unseen labels. " + - "Options are 'skip' (filter out rows with unseen labels), " + - "error (throw an error), or 'keep' (put unseen labels in a special " + - "additional bucket, at index numLabels).", + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + + "labels or NULL values). Options are 'skip' (filter out rows with " + + "invalid data), error (throw an error), or 'keep' (put invalid data " + + "in a special additional bucket, at index numLabels).", typeConverter=TypeConverters.toString) @keyword_only From 6f3bea8f1e4fcb9c45dc22a542c8e43801bfa5f8 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 28 Jun 2017 18:12:33 +0800 Subject: [PATCH 6/7] Move doc tests to tests.py. --- python/pyspark/ml/feature.py | 13 ++----------- python/pyspark/ml/tests.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index a1062b06aede..17802f438d2c 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2097,14 +2097,6 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), ... key=lambda x: x[0]) [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] - >>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"), - ... Row(id=2, label=None)], 2) - >>> dfKeep= spark.createDataFrame(testData2) - >>> modelKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf) - >>> tdK = modelKeep.transform(dfKeep) - >>> sorted(set([(i[0], i[1]) for i in tdK.select(tdK.id, tdK.indexed).collect()]), - ... key=lambda x: x[0]) - [(0, 0.0), (1, 3.0), (2, 3.0)] >>> stringIndexerPath = temp_path + "/string-indexer" >>> stringIndexer.save(stringIndexerPath) >>> loadedIndexer = StringIndexer.load(stringIndexerPath) @@ -2133,7 +2125,6 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja .. versionadded:: 1.4.0 """ - stringOrderType = Param(Params._dummy(), "stringOrderType", "How to order labels of string column. The first label after " + "ordering is assigned an index of 0. Supported options: " + @@ -2188,14 +2179,14 @@ def getStringOrderType(self): """ return self.getOrDefault(self.stringOrderType) - @since("2.2.0") + @since("2.3.0") def setHandleInvalid(self, value): """ Sets the value of :py:attr:`handleInvalid`. """ return self._set(handleInvalid=value) - @since("2.2.0") + @since("2.3.0") def getHandleInvalid(self): """ Gets the value of :py:attr:`handleInvalid` or its default value. diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 17a39472e1fe..ffb8b0a890ff 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -551,6 +551,27 @@ def test_rformula_string_indexer_order_type(self): for i in range(0, len(expected)): self.assertTrue(all(observed[i]["features"].toArray() == expected[i])) + def test_string_indexer_handle_invalid(self): + df = self.spark.createDataFrame([ + (0, "a"), + (1, "d"), + (2, None)], ["id", "label"]) + + si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep", + stringOrderType="alphabetAsc") + model1 = si1.fit(df) + td1 = model1.transform(df) + actual1 = td1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)] + self.assertEqual(actual1, expected1) + + si2 = si1.setHandleInvalid("skip") + model2 = si2.fit(df) + td2 = model2.transform(df) + actual2 = td2.select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] + self.assertEqual(actual2, expected2) + class HasInducedError(Params): From 320fe08fd16554400b59abc9287ac5280ab4a3df Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sat, 1 Jul 2017 16:19:17 +0800 Subject: [PATCH 7/7] To override param. --- python/pyspark/ml/feature.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 17802f438d2c..25ad06f682ed 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2077,7 +2077,8 @@ def mean(self): @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): """ A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. @@ -2179,20 +2180,6 @@ def getStringOrderType(self): """ return self.getOrDefault(self.stringOrderType) - @since("2.3.0") - def setHandleInvalid(self, value): - """ - Sets the value of :py:attr:`handleInvalid`. - """ - return self._set(handleInvalid=value) - - @since("2.3.0") - def getHandleInvalid(self): - """ - Gets the value of :py:attr:`handleInvalid` or its default value. - """ - return self.getOrDefault(self.handleInvalid) - class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """