@@ -1917,8 +1917,7 @@ def mean(self):
19171917
19181918
19191919@inherit_doc
1920- class StringIndexer (JavaEstimator , HasInputCol , HasOutputCol , HasHandleInvalid , JavaMLReadable ,
1921- JavaMLWritable ):
1920+ class StringIndexer (JavaEstimator , HasInputCol , HasOutputCol , JavaMLReadable , JavaMLWritable ):
19221921 """
19231922 A label indexer that maps a string column of labels to an ML column of label indices.
19241923 If the input column is numeric, we cast it to string and index the string values.
@@ -1936,6 +1935,14 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
19361935 >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
19371936 ... key=lambda x: x[0])
19381937 [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
1938+ >>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"),
1939+ ... Row(id=2, label="e")], 2)
1940+ >>> dfKeep= spark.createDataFrame(testData2)
1941+ >>> tdKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf).transform(dfKeep)
1942+ >>> itdKeep = inverter.transform(tdKeep)
1943+ >>> sorted(set([(i[0], str(i[1])) for i in itdKeep.select(itdKeep.id, itdKeep.label2).collect()]),
1944+ ... key=lambda x: x[0])
1945+ [(0, 'a'), (6, 'd'), (6, 'e')]
19391946 >>> stringIndexerPath = temp_path + "/string-indexer"
19401947 >>> stringIndexer.save(stringIndexerPath)
19411948 >>> loadedIndexer = StringIndexer.load(stringIndexerPath)
@@ -1955,6 +1962,11 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
19551962 .. versionadded:: 1.4.0
19561963 """
19571964
1965+ handleInvalid = Param (Params ._dummy (), "handleInvalid" , "how to handle unseen labels. " +
1966+ "Options are 'skip' (filter out rows with unseen labels), " +
1967+ "error (throw an error), or 'keep' (put unseen labels in a special " +
1968+ "additional bucket, at index numLabels)." ,
1969+ typeConverter = TypeConverters .toString )
19581970 @keyword_only
19591971 def __init__ (self , inputCol = None , outputCol = None , handleInvalid = "error" ):
19601972 """
@@ -1979,6 +1991,20 @@ def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"):
19791991 def _create_model (self , java_model ):
19801992 return StringIndexerModel (java_model )
19811993
1994+ @since ("2.2.0" )
1995+ def setHandleInvalid (self , value ):
1996+ """
1997+ Sets the value of :py:attr:`handleInvalid`; documented options are 'skip', 'error' and 'keep'.
1998+ """
1999+ return self ._set (handleInvalid = value )
2000+
2001+ @since ("2.2.0" )
2002+ def getHandleInvalid (self ):
2003+ """
2004+ Gets the value of :py:attr:`handleInvalid` or its default value, as resolved by ``getOrDefault``.
2005+ """
2006+ return self .getOrDefault (self .handleInvalid )
2007+
19822008
19832009class StringIndexerModel (JavaModel , JavaMLReadable , JavaMLWritable ):
19842010 """
0 commit comments