@@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model}
2323import org .apache .spark .ml .attribute .NominalAttribute
2424import org .apache .spark .ml .param ._
2525import org .apache .spark .ml .param .shared ._
26- import org .apache .spark .ml .util .SchemaUtils
2726import org .apache .spark .sql .DataFrame
2827import org .apache .spark .sql .functions ._
29- import org .apache .spark .sql .types .{StringType , StructType }
28+ import org .apache .spark .sql .types .{NumericType , StringType , StructType }
3029import org .apache .spark .util .collection .OpenHashMap
3130
3231/**
@@ -37,7 +36,10 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
3736 /** Validates and transforms the input schema. */
3837 protected def validateAndTransformSchema (schema : StructType , paramMap : ParamMap ): StructType = {
3938 val map = extractParamMap(paramMap)
40- SchemaUtils .checkColumnType(schema, map(inputCol), StringType )
39+ val inputColName = map(inputCol)
40+ val inputDataType = schema(inputColName).dataType
41+ require(inputDataType == StringType || inputDataType.isInstanceOf [NumericType ],
42+ s " The input column $inputColName must be either string type or numeric type. " )
4143 val inputFields = schema.fields
4244 val outputColName = map(outputCol)
4345 require(inputFields.forall(_.name != outputColName),
@@ -51,6 +53,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
5153/**
5254 * :: AlphaComponent ::
5355 * A label indexer that maps a string column of labels to an ML column of label indices.
56+ * If the input column is numeric, we cast it to string and index the string values.
5457 * The indices are in [0, numLabels), ordered by label frequencies.
5558 * So the most frequent label gets index 0.
5659 */
@@ -67,7 +70,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase
6770
6871 override def fit (dataset : DataFrame , paramMap : ParamMap ): StringIndexerModel = {
6972 val map = extractParamMap(paramMap)
70- val counts = dataset.select(map(inputCol)).map(_.getString(0 )).countByValue()
73+ val counts = dataset.select(col(map(inputCol)).cast(StringType ))
74+ .map(_.getString(0 ))
75+ .countByValue()
7176 val labels = counts.toSeq.sortBy(- _._2).map(_._1).toArray
7277 val model = new StringIndexerModel (this , map, labels)
7378 Params .inheritValues(map, this , model)
@@ -119,7 +124,8 @@ class StringIndexerModel private[ml] (
119124 val outputColName = map(outputCol)
120125 val metadata = NominalAttribute .defaultAttr
121126 .withName(outputColName).withValues(labels).toMetadata()
122- dataset.select(col(" *" ), indexer(dataset(map(inputCol))).as(outputColName, metadata))
127+ dataset.select(col(" *" ),
128+ indexer(dataset(map(inputCol)).cast(StringType )).as(outputColName, metadata))
123129 }
124130
125131 override def transformSchema (schema : StructType , paramMap : ParamMap ): StructType = {
0 commit comments