Skip to content

Commit d94dc68

Browse files
author
VinceShieh
committed
[SPARK-19852][PYSPARK][ML] Update Python API for StringIndexer setHandleInvalid
This PR reflect the changes made in SPARK-17498 on pyspark to support a new option 'keep' in StringIndexer to handle unseen labels Signed-off-by: VinceShieh <[email protected]>
1 parent d809cee commit d94dc68

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

python/pyspark/ml/feature.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1917,8 +1917,7 @@ def mean(self):
19171917

19181918

19191919
@inherit_doc
1920-
class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
1921-
JavaMLWritable):
1920+
class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
19221921
"""
19231922
A label indexer that maps a string column of labels to an ML column of label indices.
19241923
If the input column is numeric, we cast it to string and index the string values.
@@ -1936,6 +1935,14 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
19361935
>>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
19371936
... key=lambda x: x[0])
19381937
[(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
1938+
>>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"),
1939+
... Row(id=2, label="e")], 2)
1940+
>>> dfKeep= spark.createDataFrame(testData2)
1941+
>>> tdKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf).transform(dfKeep)
1942+
>>> itdKeep = inverter.transform(tdKeep)
1943+
>>> sorted(set([(i[0], str(i[1])) for i in itdKeep.select(itdKeep.id, itdKeep.label2).collect()]),
1944+
... key=lambda x: x[0])
1945+
[(0, 'a'), (6, 'd'), (6, 'e')]
19391946
>>> stringIndexerPath = temp_path + "/string-indexer"
19401947
>>> stringIndexer.save(stringIndexerPath)
19411948
>>> loadedIndexer = StringIndexer.load(stringIndexerPath)
@@ -1955,6 +1962,11 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
19551962
.. versionadded:: 1.4.0
19561963
"""
19571964

1965+
handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle unseen labels. " +
1966+
"Options are 'skip' (filter out rows with unseen labels), " +
1967+
"error (throw an error), or 'keep' (put unseen labels in a special " +
1968+
"additional bucket, at index numLabels).",
1969+
typeConverter=TypeConverters.toString)
19581970
@keyword_only
19591971
def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"):
19601972
"""
@@ -1979,6 +1991,20 @@ def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"):
19791991
def _create_model(self, java_model):
19801992
return StringIndexerModel(java_model)
19811993

1994+
@since("2.2.0")
1995+
def setHandleInvalid(self, value):
1996+
"""
1997+
Sets the value of :py:attr:`handleInvalid`.
1998+
"""
1999+
return self._set(handleInvalid=value)
2000+
2001+
@since("2.2.0")
2002+
def getHandleInvalid(self):
2003+
"""
2004+
Gets the value of :py:attr:`handleInvalid` or its default value.
2005+
"""
2006+
return self.getOrDefault(self.handleInvalid)
2007+
19822008

19832009
class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
19842010
"""

0 commit comments

Comments
 (0)