@@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model}
2323import org .apache .spark .ml .attribute .NominalAttribute
2424import org .apache .spark .ml .param ._
2525import org .apache .spark .ml .param .shared ._
26- import org .apache .spark .ml .util .SchemaUtils
2726import org .apache .spark .sql .DataFrame
2827import org .apache .spark .sql .functions ._
29- import org .apache .spark .sql .types .{StringType , StructType }
28+ import org .apache .spark .sql .types .{NumericType , StringType , StructType }
3029import org .apache .spark .util .collection .OpenHashMap
3130
3231/**
@@ -37,7 +36,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
3736 /** Validates and transforms the input schema. */
3837 protected def validateAndTransformSchema (schema : StructType , paramMap : ParamMap ): StructType = {
3938 val map = extractParamMap(paramMap)
40- SchemaUtils .checkColumnType(schema, map(inputCol), StringType )
39+ val inputColName = map(inputCol)
40+ val inputDataType = schema(inputColName).dataType
41+ require(inputDataType == StringType || inputDataType.isInstanceOf [NumericType ],
42+ s " The input column $inputColName must be either string type or numeric type, " +
43+ s " but got $inputDataType. " )
4144 val inputFields = schema.fields
4245 val outputColName = map(outputCol)
4346 require(inputFields.forall(_.name != outputColName),
@@ -51,6 +54,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
5154/**
5255 * :: AlphaComponent ::
5356 * A label indexer that maps a string column of labels to an ML column of label indices.
57+ * If the input column is numeric, we cast it to string and index the string values.
5458 * The indices are in [0, numLabels), ordered by label frequencies.
5559 * So the most frequent label gets index 0.
5660 */
@@ -67,7 +71,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
6771
6872 override def fit (dataset : DataFrame , paramMap : ParamMap ): StringIndexerModel = {
6973 val map = extractParamMap(paramMap)
70- val counts = dataset.select(map(inputCol)).map(_.getString(0 )).countByValue()
74+ val counts = dataset.select(col(map(inputCol)).cast(StringType ))
75+ .map(_.getString(0 ))
76+ .countByValue()
7177 val labels = counts.toSeq.sortBy(- _._2).map(_._1).toArray
7278 val model = new StringIndexerModel (this , map, labels)
7379 Params .inheritValues(map, this , model)
@@ -119,7 +125,8 @@ class StringIndexerModel private[ml] (
119125 val outputColName = map(outputCol)
120126 val metadata = NominalAttribute .defaultAttr
121127 .withName(outputColName).withValues(labels).toMetadata()
122- dataset.select(col(" *" ), indexer(dataset(map(inputCol))).as(outputColName, metadata))
128+ dataset.select(col(" *" ),
129+ indexer(dataset(map(inputCol)).cast(StringType )).as(outputColName, metadata))
123130 }
124131
125132 override def transformSchema (schema : StructType , paramMap : ParamMap ): StructType = {
0 commit comments