[SPARK-20542][ML][SQL] Add an API to Bucketizer that can bin multiple columns #17819
mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
```diff
@@ -24,20 +24,21 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Model
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

 /**
- * `Bucketizer` maps a column of continuous features to a column of feature buckets.
+ * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
+ * `Bucketizer` can also map multiple columns at once.
  */
 @Since("1.4.0")
 final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol
-    with DefaultParamsWritable {
+    with HasInputCols with DefaultParamsWritable {

   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("bucketizer"))

@@ -96,9 +97,63 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
   def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
```
Contributor: We should make it clear that in the multi column case, the invalid handling is applied to all columns (so for …
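For illustration of what this comment is asking to document (the column names and splits below are invented, not taken from the patch): `handleInvalid` is a single shared param, so with `"skip"` a row is filtered out whenever any of the input columns holds an invalid value, not per column.

```scala
// Hypothetical sketch of multi-column handleInvalid semantics.
import org.apache.spark.ml.feature.Bucketizer

val bucketizer = new Bucketizer()
  .setInputCols(Array("f1", "f2"))
  .setOutputCols(Array("f1_bucket", "f2_bucket"))
  .setSplitsArray(Array(
    Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity),
    Array(Double.NegativeInfinity, 10.0, Double.PositiveInfinity)))
  // One setting for all columns: a row with NaN in either f1 or f2
  // is dropped entirely.
  .setHandleInvalid("skip")
```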
```diff
   setDefault(handleInvalid, Bucketizer.ERROR_INVALID)

+  /**
+   * Parameter for specifying multiple splits parameters. Each element in this array can be used to
+   * map continuous features into buckets.
+   *
+   * @group param
+   */
+  @Since("2.3.0")
+  val splitsArray: DoubleArrayArrayParam = new DoubleArrayArrayParam(this, "splitsArray",
+    "The array of split points for mapping continuous features into buckets for multiple " +
+      "columns. For each input column, with n+1 splits, there are n buckets. A bucket defined by " +
+      "splits x,y holds values in the range [x,y) except the last bucket, which also includes y. " +
+      "The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be " +
+      "explicitly provided to cover all Double values; otherwise, values outside the splits " +
+      "specified will be treated as errors.",
+    Bucketizer.checkSplitsArray)
+
+  /**
+   * Param for output column names.
+   * @group param
+   */
+  @Since("2.3.0")
+  final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols",
+    "output column names")
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getSplitsArray: Array[Array[Double]] = $(splitsArray)
+
+  /** @group getParam */
+  @Since("2.3.0")
+  final def getOutputCols: Array[String] = $(outputCols)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  /**
+   * Determines whether this `Bucketizer` is going to map multiple columns. Only if all necessary
+   * params for bucketizing multiple columns are set, we go for the path to map multiple columns.
+   * By default `Bucketizer` just maps a column of continuous features.
+   */
+  private[ml] def isBucketizeMultipleInputCols(): Boolean = {
+    isSet(inputCols) && isSet(splitsArray) && isSet(outputCols)
+  }
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)

     val (filteredDataset, keepInvalid) = {
       if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
         // "skip" NaN option is set, will filter out NaN values in the dataset

@@ -108,26 +163,53 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
       }
     }

-    val bucketizer: UserDefinedFunction = udf { (feature: Double) =>
-      Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
-    }.withName("bucketizer")
+    val seqOfSplits = if (isBucketizeMultipleInputCols()) {
+      $(splitsArray).toSeq
```
Contributor: I am interested in the difference between …
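As background for the splits being threaded through here, a worked example of the semantics documented on `splitsArray` (the values are illustrative): n+1 splits define n buckets, each covering [x, y) except the last, which also includes its upper bound.

```scala
// Illustrative only: one element of splitsArray and the buckets it defines.
val splits = Array(Double.NegativeInfinity, 0.0, 10.0, Double.PositiveInfinity)
// 4 splits => 3 buckets:
//   bucket 0: [-inf, 0.0)    e.g. -5.0 -> 0.0
//   bucket 1: [0.0, 10.0)    e.g.  5.0 -> 1.0
//   bucket 2: [10.0, +inf]   e.g. 15.0 -> 2.0
```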
```diff
+    } else {
+      Seq($(splits))
+    }
+
+    val bucketizers: Seq[UserDefinedFunction] = seqOfSplits.zipWithIndex.map { case (splits, idx) =>
+      udf { (feature: Double) =>
+        Bucketizer.binarySearchForBuckets(splits, feature, keepInvalid)
+      }.withName(s"bucketizer_$idx")
+    }

-    val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType))
-    val newField = prepOutputField(filteredDataset.schema)
-    filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
+    val (inputColumns, outputColumns) = if (isBucketizeMultipleInputCols()) {
+      ($(inputCols).toSeq, $(outputCols).toSeq)
+    } else {
+      (Seq($(inputCol)), Seq($(outputCol)))
+    }
+    val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) =>
+      bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType))
+    }
+    val newFields = outputColumns.zipWithIndex.map { case (outputCol, idx) =>
+      prepOutputField(seqOfSplits(idx), outputCol)
+    }
+    filteredDataset.withColumns(outputColumns, newCols, newFields.map(_.metadata))
   }

-  private def prepOutputField(schema: StructType): StructField = {
-    val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
-    val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true),
+  private def prepOutputField(splits: Array[Double], outputCol: String): StructField = {
+    val buckets = splits.sliding(2).map(bucket => bucket.mkString(", ")).toArray
+    val attr = new NominalAttribute(name = Some(outputCol), isOrdinal = Some(true),
       values = Some(buckets))
     attr.toStructField()
   }

   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkNumericType(schema, $(inputCol))
-    SchemaUtils.appendColumn(schema, prepOutputField(schema))
+    if (isBucketizeMultipleInputCols()) {
+      var transformedSchema = schema
+      $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) =>
+        SchemaUtils.checkNumericType(transformedSchema, inputCol)
+        transformedSchema = SchemaUtils.appendColumn(transformedSchema,
+          prepOutputField($(splitsArray)(idx), outputCol))
+      }
+      transformedSchema
+    } else {
+      SchemaUtils.checkNumericType(schema, $(inputCol))
+      SchemaUtils.appendColumn(schema, prepOutputField($(splits), $(outputCol)))
+    }
   }

   @Since("1.4.1")

@@ -163,6 +245,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
     }
   }

+  /**
+   * Check each splits in the splits array.
+   */
+  private[feature] def checkSplitsArray(splitsArray: Array[Array[Double]]): Boolean = {
+    splitsArray.forall(checkSplits(_))
+  }
+
   /**
    * Binary searching in several buckets to place each data point.
    * @param splits array of split points
```
Reviewer: No Scala example?
Author: Added a Scala example.
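The example the author added is not visible in this diff view; a minimal sketch of the new multi-column API (assuming an active `SparkSession` named `spark`; the data and column names are invented) could look like:

```scala
import org.apache.spark.ml.feature.Bucketizer

// Two continuous columns bucketized in a single transform.
val df = spark.createDataFrame(Seq(
  (-0.5, 1.0),
  (0.2, 25.0),
  (999.0, -4.1)
)).toDF("features1", "features2")

val bucketizer = new Bucketizer()
  .setInputCols(Array("features1", "features2"))
  .setOutputCols(Array("buckets1", "buckets2"))
  .setSplitsArray(Array(
    Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity),
    Array(Double.NegativeInfinity, 0.0, 10.0, Double.PositiveInfinity)))

// Appends both output columns; each value is the bucket index as a Double.
bucketizer.transform(df).show()
```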