[SPARK-11215][ML] Add multiple columns support to StringIndexer #19621
Changes from 8 commits
@@ -26,18 +26,19 @@ import org.apache.spark.annotation.Since | |
| import org.apache.spark.ml.{Estimator, Model, Transformer} | ||
| import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} | ||
| import org.apache.spark.ml.param._ | ||
| import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol} | ||
| import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.sql.{DataFrame, Dataset} | ||
| import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.util.VersionUtils.majorMinorVersion | ||
| import org.apache.spark.util.collection.OpenHashMap | ||
|
|
||
| /** | ||
| * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. | ||
| */ | ||
| private[feature] trait StringIndexerBase extends Params with HasHandleInvalid with HasInputCol | ||
| with HasOutputCol { | ||
| with HasOutputCol with HasInputCols with HasOutputCols { | ||
|
|
||
| /** | ||
| * Param for how to handle invalid data (unseen labels or NULL values). | ||
|
|
@@ -79,20 +80,49 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi | |
| @Since("2.3.0") | ||
| def getStringOrderType: String = $(stringOrderType) | ||
|
|
||
| private[feature] def getInOutCols: (Array[String], Array[String]) = { | ||
|
|
||
| require((isSet(inputCol) && isSet(outputCol) && !isSet(inputCols) && !isSet(outputCols)) || | ||
| (!isSet(inputCol) && !isSet(outputCol) && isSet(inputCols) && isSet(outputCols)), | ||
| "Only allow to set either inputCol/outputCol, or inputCols/outputCols" | ||
| ) | ||
|
|
||
| if (isSet(inputCol)) { | ||
| (Array($(inputCol)), Array($(outputCol))) | ||
| } else { | ||
| require($(inputCols).length == $(outputCols).length, | ||
|
Contributor: Should add a test case for this
Contributor (Author): test added.
+      ($(inputCols), $(outputCols))
+    }
+  }
+
   /** Validates and transforms the input schema. */
-  protected def validateAndTransformSchema(schema: StructType): StructType = {
-    val inputColName = $(inputCol)
-    val inputDataType = schema(inputColName).dataType
-    require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
-      s"The input column $inputColName must be either string type or numeric type, " +
-      s"but got $inputDataType.")
+  protected def validateAndTransformSchema(schema: StructType,
+      skipNonExistsCol: Boolean = false): StructType = {
+    val (inputColNames, outputColNames) = getInOutCols
     val inputFields = schema.fields
-    val outputColName = $(outputCol)
-    require(inputFields.forall(_.name != outputColName),
-      s"Output column $outputColName already exists.")
-    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
-    val outputFields = inputFields :+ attr.toStructField()
-    StructType(outputFields)
+    val outputFields = for (i <- 0 until inputColNames.length) yield {
+      val inputColName = inputColNames(i)
+      if (schema.fieldNames.contains(inputColName)) {
+        val inputDataType = schema(inputColName).dataType
+        require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
+          s"The input column $inputColName must be either string type or numeric type, " +
+          s"but got $inputDataType.")
+        val outputColName = outputColNames(i)
+        require(inputFields.forall(_.name != outputColName),
+          s"Output column $outputColName already exists.")
+        val attr = NominalAttribute.defaultAttr.withName($(outputCol))
+        attr.toStructField()
+      } else {
+        if (skipNonExistsCol) {
+          null
+        } else {
+          throw new SparkException(s"Input column ${inputColName} do not exist.")
+        }
+      }
+    }
+    StructType(inputFields ++ outputFields.filter(_ != null))
   }
 }
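As a reading aid, not part of the diff: a minimal usage sketch of the new API, with invented column names. Exactly one of the two param pairs may be set; mixing them trips the require in getInOutCols at fit/transform time.

    import org.apache.spark.ml.feature.StringIndexer

    // Index two columns in one pass using the params added by this PR.
    val indexer = new StringIndexer()
      .setInputCols(Array("city", "country"))
      .setOutputCols(Array("cityIndex", "countryIndex"))

    // Mixing the pairs, e.g. .setInputCol("city").setOutputCols(Array("cityIndex")),
    // fails with: "Only allow to set either inputCol/outputCol, or inputCols/outputCols"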
|
|
@@ -130,21 +160,51 @@ class StringIndexer @Since("1.4.0") ( | |
| @Since("1.4.0") | ||
| def setOutputCol(value: String): this.type = set(outputCol, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setInputCols(value: Array[String]): this.type = set(inputCols, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setOutputCols(value: Array[String]): this.type = set(outputCols, value) | ||
|
|
||
| @Since("2.0.0") | ||
| override def fit(dataset: Dataset[_]): StringIndexerModel = { | ||
| transformSchema(dataset.schema, logging = true) | ||
| val values = dataset.na.drop(Array($(inputCol))) | ||
| .select(col($(inputCol)).cast(StringType)) | ||
| .rdd.map(_.getString(0)) | ||
| val labels = $(stringOrderType) match { | ||
| case StringIndexer.frequencyDesc => values.countByValue().toSeq.sortBy(-_._2) | ||
| .map(_._1).toArray | ||
| case StringIndexer.frequencyAsc => values.countByValue().toSeq.sortBy(_._2) | ||
| .map(_._1).toArray | ||
| case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > _) | ||
| case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _) | ||
|
|
||
| val inputCols = getInOutCols._1 | ||
|
|
||
| val zeroState = Array.fill(inputCols.length)(new OpenHashMap[String, Long]()) | ||
|
|
||
| val countByValueArray = dataset.na.drop(inputCols) | ||
| .select(inputCols.map(col(_).cast(StringType)): _*) | ||
| .rdd.treeAggregate(zeroState)( | ||
| (state: Array[OpenHashMap[String, Long]], row: Row) => { | ||
| for (i <- 0 until inputCols.length) { | ||
| state(i).changeValue(row.getString(i), 1L, _ + 1) | ||
| } | ||
| state | ||
| }, | ||
| (state1: Array[OpenHashMap[String, Long]], state2: Array[OpenHashMap[String, Long]]) => { | ||
| for (i <- 0 until inputCols.length) { | ||
| state2(i).foreach { case (key: String, count: Long) => | ||
| state1(i).changeValue(key, count, _ + count) | ||
| } | ||
| } | ||
| state1 | ||
| } | ||
| ) | ||
| val labelsArray = countByValueArray.map { countByValue => | ||
| $(stringOrderType) match { | ||
| case StringIndexer.frequencyDesc => | ||
| countByValue.toSeq.sortBy(_._1).sortBy(-_._2).map(_._1).toArray | ||
| case StringIndexer.frequencyAsc => | ||
| countByValue.toSeq.sortBy(_._1).sortBy(_._2).map(_._1).toArray | ||
| case StringIndexer.alphabetDesc => countByValue.toSeq.map(_._1).sortWith(_ > _).toArray | ||
| case StringIndexer.alphabetAsc => countByValue.toSeq.map(_._1).sortWith(_ < _).toArray | ||
|
Member: For …
Contributor (Author): Yes, but will aggregate count bring apparent overhead? I don't want the code including too many …
Member: If the dataset is large, it might be. We can leave it as it is. If we find it is a bottleneck, we still can revisit it.
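An editor's aside on the double sortBy above: Scala's sortBy is stable, so sorting by label first and then by count gives a deterministic label order when frequencies tie, which the old single sortBy did not guarantee.

    val counts = Seq("b" -> 2L, "a" -> 2L, "c" -> 5L)
    counts.sortBy(_._1).sortBy(-_._2).map(_._1)
    // List(c, a, b): "c" wins on count; "a" and "b" tie on count and keep
    // the alphabetical order established by the first, stable sort.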
+      }
     }
-    copyValues(new StringIndexerModel(uid, labels).setParent(this))
+    copyValues(new StringIndexerModel(uid, labelsArray).setParent(this))
   }

   @Since("1.4.0")
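To see the shape of the treeAggregate in fit without a cluster, here is a small local sketch of the same per-column counting, using scala.collection.mutable.HashMap in place of Spark's internal OpenHashMap (data invented):

    import scala.collection.mutable

    val rows = Seq(Array("a", "x"), Array("a", "y"), Array("b", "x"))

    // seqOp: fold each row into one frequency map per column.
    val counts = rows.foldLeft(Array.fill(2)(mutable.HashMap.empty[String, Long])) {
      (state, row) =>
        for (i <- state.indices) {
          state(i)(row(i)) = state(i).getOrElse(row(i), 0L) + 1L
        }
        state
    }
    // counts(0) == Map("a" -> 2, "b" -> 1); counts(1) == Map("x" -> 2, "y" -> 1)

    // combOp in the real code merges two such arrays key by key across
    // partitions, exactly like the second function passed to treeAggregate.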
|
|
@@ -177,7 +237,8 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { | |
| /** | ||
| * Model fitted by [[StringIndexer]]. | ||
| * | ||
| * @param labels Ordered list of labels, corresponding to indices to be assigned. | ||
| * @param labelsArray Array of ordered list of labels, corresponding to indices to be assigned | ||
| * for each input column. | ||
| * | ||
| * @note During transformation, if the input column does not exist, | ||
| * `StringIndexerModel.transform` would return the input dataset unmodified. | ||
|
|
@@ -186,23 +247,36 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { | |
| @Since("1.4.0") | ||
| class StringIndexerModel ( | ||
| @Since("1.4.0") override val uid: String, | ||
| @Since("1.5.0") val labels: Array[String]) | ||
| @Since("2.3.0") val labelsArray: Array[Array[String]]) | ||
| extends Model[StringIndexerModel] with StringIndexerBase with MLWritable { | ||
|
|
||
| import StringIndexerModel._ | ||
|
|
||
| @Since("1.5.0") | ||
| def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) | ||
|
|
||
| private val labelToIndex: OpenHashMap[String, Double] = { | ||
| val n = labels.length | ||
| val map = new OpenHashMap[String, Double](n) | ||
| var i = 0 | ||
| while (i < n) { | ||
| map.update(labels(i), i) | ||
| i += 1 | ||
| def this(labels: Array[String]) = | ||
| this(Identifiable.randomUID("strIdx"), Array(labels)) | ||
|
|
||
| @Since("1.5.0") | ||
| def labels: Array[String] = { | ||
| require(labelsArray.length == 1) | ||
| labelsArray(0) | ||
| } | ||
|
|
||
| @Since("2.3.0") | ||
| def this(labelsArray: Array[Array[String]]) = | ||
| this(Identifiable.randomUID("strIdx"), labelsArray) | ||
|
|
||
| private val labelToIndexArray: Array[OpenHashMap[String, Double]] = { | ||
| for (labels <- labelsArray) yield { | ||
| val n = labels.length | ||
| val map = new OpenHashMap[String, Double](n) | ||
| var i = 0 | ||
| while (i < n) { | ||
| map.update(labels(i), i) | ||
| i += 1 | ||
| } | ||
| map | ||
| } | ||
| map | ||
| } | ||
|
|
||
| /** @group setParam */ | ||
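A backward-compatibility note, illustrated with a throwaway sketch: the old labels member survives as an accessor, but it is only valid when the model indexes a single column.

    // Old single-column constructor still works; labels delegates to labelsArray(0).
    val single = new StringIndexerModel(Array("a", "b", "c"))
    single.labels               // Array("a", "b", "c")

    // Multi-column model: calling labels would fail require(labelsArray.length == 1).
    val multi = new StringIndexerModel(Array(Array("a", "b"), Array("x", "y")))
    multi.labelsArray.length    // 2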
|
|
@@ -217,69 +291,100 @@ class StringIndexerModel ( | |
| @Since("1.4.0") | ||
| def setOutputCol(value: String): this.type = set(outputCol, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setInputCols(value: Array[String]): this.type = set(inputCols, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setOutputCols(value: Array[String]): this.type = set(outputCols, value) | ||
|
|
||
| @Since("2.0.0") | ||
| override def transform(dataset: Dataset[_]): DataFrame = { | ||
| if (!dataset.schema.fieldNames.contains($(inputCol))) { | ||
| logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " + | ||
| "Skip StringIndexerModel.") | ||
| return dataset.toDF | ||
| } | ||
| transformSchema(dataset.schema, logging = true) | ||
|
Member: We can skip …
Contributor (Author): updated.
-    val filteredLabels = getHandleInvalid match {
-      case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
-      case _ => labels
-    }
+    var (inputColNames, outputColNames) = getInOutCols

-    val metadata = NominalAttribute.defaultAttr
-      .withName($(outputCol)).withValues(filteredLabels).toMetadata()
+    val outputColumns = new Array[Column](outputColNames.length)

+    var filteredDataset = dataset
     // If we are skipping invalid records, filter them out.
-    val (filteredDataset, keepInvalid) = getHandleInvalid match {
-      case StringIndexer.SKIP_INVALID =>
-        val filterer = udf { label: String =>
-          labelToIndex.contains(label)
-        }
-        (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
-      case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
+    if (getHandleInvalid == StringIndexer.SKIP_INVALID) {
+      filteredDataset = dataset.na.drop(inputColNames.filter(
+        dataset.schema.fieldNames.contains(_)))
+      for (i <- 0 until inputColNames.length) {
+        val inputColName = inputColNames(i)
+        val labelToIndex = labelToIndexArray(i)
+        val filterer = udf { label: String =>
+          labelToIndex.contains(label)
+        }
+        filteredDataset = filteredDataset.where(filterer(dataset(inputColName)))
+      }
     }

-    val indexer = udf { label: String =>
-      if (label == null) {
-        if (keepInvalid) {
-          labels.length
-        } else {
-          throw new SparkException("StringIndexer encountered NULL value. To handle or skip " +
-            "NULLS, try setting StringIndexer.handleInvalid.")
-        }
-      } else {
-        if (labelToIndex.contains(label)) {
-          labelToIndex(label)
-        } else if (keepInvalid) {
-          labels.length
-        } else {
-          throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
-            s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
-        }
-      }
-    }.asNondeterministic()
+    for (i <- 0 until outputColNames.length) {
+      val inputColName = inputColNames(i)
+      val outputColName = outputColNames(i)
+      val labelToIndex = labelToIndexArray(i)
+      val labels = labelsArray(i)
+
+      if (!dataset.schema.fieldNames.contains(inputColName)) {
+        logInfo(s"Input column ${inputColName} does not exist during transformation. " +
+          "Skip this column StringIndexerModel transform.")
+        outputColNames(i) = null
+      } else {
+        val filteredLabels = getHandleInvalid match {
+          case StringIndexer.KEEP_INVALID => labelsArray(i) :+ "__unknown"
+          case _ => labelsArray(i)
+        }
+
+        val metadata = NominalAttribute.defaultAttr
+          .withName(outputColName).withValues(filteredLabels).toMetadata()
+
+        val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID)
+
+        val indexer = udf { label: String =>
+          if (label == null) {
+            if (keepInvalid) {
+              labels.length
+            } else {
+              throw new SparkException("StringIndexer encountered NULL value. To handle or skip " +
+                "NULLS, try setting StringIndexer.handleInvalid.")
+            }
+          } else {
+            if (labelToIndex.contains(label)) {
+              labelToIndex(label)
+            } else if (keepInvalid) {
+              labels.length
+            } else {
+              throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
+                s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
+            }
+          }
+        }.asNondeterministic()
+
+        outputColumns(i) = indexer(dataset(inputColName).cast(StringType))
+          .as(outputColName, metadata)
+      }
+    }

-    filteredDataset.select(col("*"),
-      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
+    val filteredOutputColNames = outputColNames.filter(_ != null)
+    val filteredOutputColumns = outputColumns.filter(_ != null)
+
+    if (filteredOutputColNames.length > 0) {
+      filteredDataset.withColumns(filteredOutputColNames, filteredOutputColumns)
+    } else {
+      filteredDataset.toDF()
+    }
   }

   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    if (schema.fieldNames.contains($(inputCol))) {
-      validateAndTransformSchema(schema)
-    } else {
-      // If the input column does not exist during transformation, we skip StringIndexerModel.
-      schema
-    }
+    validateAndTransformSchema(schema, skipNonExistsCol = true)
   }

   @Since("1.4.1")
   override def copy(extra: ParamMap): StringIndexerModel = {
-    val copied = new StringIndexerModel(uid, labels)
+    val copied = new StringIndexerModel(uid, labelsArray)
     copyValues(copied, extra).setParent(parent)
   }
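Putting the new transform together, a rough end-to-end sketch (invented data; assumes a SparkSession named spark). With handleInvalid = "keep", an unseen value in column i maps to index labelsArray(i).length, and "__unknown" is appended to that column's metadata values.

    val train = spark.createDataFrame(Seq(("a", "x"), ("b", "y"), ("b", "x")))
      .toDF("c1", "c2")
    val test = spark.createDataFrame(Seq(("d", "x"))).toDF("c1", "c2")

    val model = new StringIndexer()
      .setInputCols(Array("c1", "c2"))
      .setOutputCols(Array("c1Idx", "c2Idx"))
      .setHandleInvalid("keep")
      .fit(train)

    model.transform(test).show()
    // "d" was never seen in c1, so c1Idx == 2.0 (labelsArray(0).length);
    // "x" was seen in c2, so c2Idx == 0.0 (most frequent label first).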
|
|
@@ -293,11 +398,11 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { | |
| private[StringIndexerModel] | ||
| class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter { | ||
|
|
||
| private case class Data(labels: Array[String]) | ||
| private case class Data(labelsArray: Array[Array[String]]) | ||
|
|
||
| override protected def saveImpl(path: String): Unit = { | ||
| DefaultParamsWriter.saveMetadata(instance, path, sc) | ||
| val data = Data(instance.labels) | ||
| val data = Data(instance.labelsArray) | ||
| val dataPath = new Path(path, "data").toString | ||
| sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) | ||
| } | ||
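For readers, the shape of the persisted data after this change, sketched with a throwaway case class (assumes a SparkSession named spark; the printed schema is illustrative):

    case class Data(labelsArray: Array[Array[String]])

    spark.createDataFrame(Seq(Data(Array(Array("a", "b"), Array("x")))))
      .printSchema()
    // root
    //  |-- labelsArray: array (nullable = true)
    //  |    |-- element: array (containsNull = true)
    //  |    |    |-- element: string (containsNull = true)

A 2.2-era model stored a flat labels: Array[String] column instead, which is why the loader below branches on the Spark version recorded in the metadata.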
|
|
@@ -310,11 +415,22 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { | |
| override def load(path: String): StringIndexerModel = { | ||
| val metadata = DefaultParamsReader.loadMetadata(path, sc, className) | ||
| val dataPath = new Path(path, "data").toString | ||
| val data = sparkSession.read.parquet(dataPath) | ||
| .select("labels") | ||
| .head() | ||
| val labels = data.getAs[Seq[String]](0).toArray | ||
| val model = new StringIndexerModel(metadata.uid, labels) | ||
|
|
||
| val (majorVersion, minorVersion) = majorMinorVersion(metadata.sparkVersion) | ||
| val labelsArray = if (majorVersion < 2 || (majorVersion == 2 && minorVersion <= 2)) { | ||
| // Spark 2.2 and before | ||
| val data = sparkSession.read.parquet(dataPath) | ||
| .select("labels") | ||
| .head() | ||
| val labels = data.getAs[Seq[String]](0).toArray | ||
| Array(labels) | ||
| } else { | ||
| val data = sparkSession.read.parquet(dataPath) | ||
| .select("labelsArray") | ||
| .head() | ||
| data.getAs[Seq[Seq[String]]](0).map(_.toArray).toArray | ||
| } | ||
| val model = new StringIndexerModel(metadata.uid, labelsArray) | ||
| DefaultParamsReader.getAndSetParams(model, metadata) | ||
| model | ||
| } | ||
|
|
||
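The loader's version gate reduces to the following predicate (a standalone sketch; the real code obtains major/minor via VersionUtils.majorMinorVersion on metadata.sparkVersion):

    def isLegacyFormat(sparkVersion: String): Boolean = {
      val Array(major, minor) = sparkVersion.split("\\.").take(2).map(_.toInt)
      major < 2 || (major == 2 && minor <= 2)
    }

    isLegacyFormat("2.2.0")  // true  -> read the old "labels" column
    isLegacyFormat("2.3.0")  // false -> read the new "labelsArray" column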
Review comment on the require message in getInOutCols: Maybe match the language for the exception message here, e.g. "StringIndexer only supports setting either inputCol/outputCol or inputCols/outputCols".