@@ -33,6 +33,13 @@
import org.apache.spark.sql.types.StructType;
// $example off$

Contributor:

No Scala example?
@viirya (Member, Author) commented on Oct 11, 2017:

Added a Scala example.
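For illustration, a sketch of what a multi-column Scala example might look like, mirroring the Java code below (column names and values are illustrative, not necessarily the exact example added in the PR; assumes an active SparkSession with its implicits in scope):

import org.apache.spark.ml.feature.Bucketizer

// Assumes `spark` is an active SparkSession.
import spark.implicits._

val splitsArray = Array(
  Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity),
  Array(Double.NegativeInfinity, -0.3, 0.0, 0.3, Double.PositiveInfinity))

val dataFrame = Seq(
  (-999.9, -999.9), (-0.5, -0.2), (-0.3, -0.1),
  (0.0, 0.0), (0.2, 0.4), (999.9, 999.9)
).toDF("features1", "features2")

val bucketizer = new Bucketizer()
  .setInputCols(Array("features1", "features2"))
  .setOutputCols(Array("bucketedFeatures1", "bucketedFeatures2"))
  .setSplitsArray(splitsArray)

// Transform the original data into bucket indices for both columns.
bucketizer.transform(dataFrame).show()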

/**
* An example for Bucketizer.
* Run with
* <pre>
* bin/run-example ml.JavaBucketizerExample
* </pre>
*/
public class JavaBucketizerExample {
public static void main(String[] args) {
SparkSession spark = SparkSession
@@ -68,6 +75,40 @@ public static void main(String[] args) {
bucketedData.show();
// $example off$

// $example on$
// Bucketize multiple columns in one pass.
double[][] splitsArray = {
{Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY},
{Double.NEGATIVE_INFINITY, -0.3, 0.0, 0.3, Double.POSITIVE_INFINITY}
};

List<Row> data2 = Arrays.asList(
RowFactory.create(-999.9, -999.9),
RowFactory.create(-0.5, -0.2),
RowFactory.create(-0.3, -0.1),
RowFactory.create(0.0, 0.0),
RowFactory.create(0.2, 0.4),
RowFactory.create(999.9, 999.9)
);
StructType schema2 = new StructType(new StructField[]{
new StructField("features1", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("features2", DataTypes.DoubleType, false, Metadata.empty())
});
Dataset<Row> dataFrame2 = spark.createDataFrame(data2, schema2);

Bucketizer bucketizer2 = new Bucketizer()
.setInputCols(new String[] {"features1", "features2"})
.setOutputCols(new String[] {"bucketedFeatures1", "bucketedFeatures2"})
.setSplitsArray(splitsArray);
// Transform original data into its bucket index.
Dataset<Row> bucketedData2 = bucketizer2.transform(dataFrame2);

System.out.println("Bucketizer output with [" +
(bucketizer2.getSplitsArray()[0].length-1) + ", " +
(bucketizer2.getSplitsArray()[1].length-1) + "] buckets for each input column");
bucketedData2.show();
// $example off$

spark.stop();
}
}
117 changes: 103 additions & 14 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -24,20 +24,21 @@ import org.apache.spark.annotation.Since
import org.apache.spark.ml.Model
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

/**
* `Bucketizer` maps a column of continuous features to a column of feature buckets.
* `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
* `Bucketizer` can also map multiple columns at once.
*/
@Since("1.4.0")
final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol
with DefaultParamsWritable {
with HasInputCols with DefaultParamsWritable {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("bucketizer"))
@@ -96,9 +97,63 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
Contributor:

We should make it clear that in the multi-column case, the invalid handling is applied to all columns: for "error" it will throw an error if any invalid values are found in any column, for "skip" it will skip rows with any invalid values in any column, and so on.
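For example (a sketch of the semantics the reviewer describes, with illustrative column names; not a quote of the PR's documentation): with handleInvalid set to "skip", a row is dropped if any of the input columns contains an invalid value such as NaN.

// Sketch: "skip" drops rows that contain NaN in either input column;
// "error" (the default) would throw instead.
val bucketizer = new Bucketizer()
  .setInputCols(Array("features1", "features2"))
  .setOutputCols(Array("bucketed1", "bucketed2"))
  .setSplitsArray(Array(
    Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity),
    Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity)))
  .setHandleInvalid("skip")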

setDefault(handleInvalid, Bucketizer.ERROR_INVALID)

/**
* Parameter for specifying multiple splits parameters. Each element in this array can be used to
* map continuous features into buckets.
*
* @group param
*/
@Since("2.3.0")
val splitsArray: DoubleArrayArrayParam = new DoubleArrayArrayParam(this, "splitsArray",
"The array of split points for mapping continuous features into buckets for multiple " +
"columns. For each input column, with n+1 splits, there are n buckets. A bucket defined by " +
"splits x,y holds values in the range [x,y) except the last bucket, which also includes y. " +
"The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be " +
"explicitly provided to cover all Double values; otherwise, values outside the splits " +
"specified will be treated as errors.",
Bucketizer.checkSplitsArray)
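To make the splits semantics concrete, a small illustration (values chosen for this note, not taken from the PR):

// n+1 splits define n buckets; a value maps to the index of its [x, y) interval,
// and the last bucket also includes its upper bound.
val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
// 5 splits -> 4 buckets:  -0.7 -> 0.0,  -0.5 -> 1.0,  0.2 -> 2.0,  0.5 -> 3.0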

/**
* Param for output column names.
* @group param
*/
@Since("2.3.0")
final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols",
Contributor:

Why are we making this final (and not others)? (Also the getOutputCols?)

Contributor:

I guess similarly to shared params? I think it makes sense to add a shared param since this, Imputer and others will use it.

Member Author:

Ah, I think the final is copied from the previous multiple-bucketizer trait. I'll remove it.

Member Author:

I will create HasOutputCols.
"output column names")

/** @group getParam */
@Since("2.3.0")
def getSplitsArray: Array[Array[Double]] = $(splitsArray)

/** @group getParam */
@Since("2.3.0")
final def getOutputCols: Array[String] = $(outputCols)

/** @group setParam */
@Since("2.3.0")
def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value)

/** @group setParam */
@Since("2.3.0")
def setInputCols(value: Array[String]): this.type = set(inputCols, value)

/** @group setParam */
@Since("2.3.0")
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

/**
 * Determines whether this `Bucketizer` is going to map multiple columns. The multi-column path
 * is taken only when all params required for bucketizing multiple columns are set. By default
 * `Bucketizer` maps a single column of continuous features.
 */
private[ml] def isBucketizeMultipleInputCols(): Boolean = {
isSet(inputCols) && isSet(splitsArray) && isSet(outputCols)
}
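For illustration (hypothetical configurations): setting only the single-column params keeps the original behavior, while setting inputCols, splitsArray and outputCols together switches to the multi-column path.

// Single-column configuration -> isBucketizeMultipleInputCols() returns false.
new Bucketizer()
  .setInputCol("feature")
  .setOutputCol("bucketed")
  .setSplits(Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity))

// Multi-column configuration -> isBucketizeMultipleInputCols() returns true.
new Bucketizer()
  .setInputCols(Array("f1", "f2"))
  .setOutputCols(Array("out1", "out2"))
  .setSplitsArray(Array(
    Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity),
    Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity)))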

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)

val (filteredDataset, keepInvalid) = {
if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
// "skip" NaN option is set, will filter out NaN values in the dataset
@@ -108,26 +163,53 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
}
}

val bucketizer: UserDefinedFunction = udf { (feature: Double) =>
Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
}.withName("bucketizer")
val seqOfSplits = if (isBucketizeMultipleInputCols()) {
$(splitsArray).toSeq
@tengpeng (Contributor) commented on Nov 24, 2017:

I am interested in the difference between .toSeq and Seq().
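For reference, the difference in this context (a small illustration, not code from the PR): `.toSeq` converts the existing outer `Array[Array[Double]]` into a `Seq` of its elements, while `Seq(...)` wraps the single splits array in a one-element `Seq`.

val splits: Array[Double] = Array(0.0, 0.5, 1.0)
val splitsArray: Array[Array[Double]] = Array(Array(0.0, 0.5, 1.0), Array(0.0, 0.2, 1.0))

splitsArray.toSeq  // Seq[Array[Double]] with one element per input column
Seq(splits)        // Seq[Array[Double]] with exactly one element (the single-column case)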

} else {
Seq($(splits))
}

val bucketizers: Seq[UserDefinedFunction] = seqOfSplits.zipWithIndex.map { case (splits, idx) =>
udf { (feature: Double) =>
Bucketizer.binarySearchForBuckets(splits, feature, keepInvalid)
}.withName(s"bucketizer_$idx")
}

val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType))
val newField = prepOutputField(filteredDataset.schema)
filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
val (inputColumns, outputColumns) = if (isBucketizeMultipleInputCols()) {
($(inputCols).toSeq, $(outputCols).toSeq)
} else {
(Seq($(inputCol)), Seq($(outputCol)))
}
val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) =>
bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType))
}
val newFields = outputColumns.zipWithIndex.map { case (outputCol, idx) =>
Contributor:

Have we not done this already in transformSchema? Can we just re-use the result of that?
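A sketch of what the reviewer's suggestion might look like (a hypothetical refactor, not what the PR does): `transformSchema` already appends the output fields, so their metadata could be looked up from its result rather than rebuilt with `prepOutputField`.

// Hypothetical: reuse the schema produced by transformSchema in transform().
val transformedSchema = transformSchema(dataset.schema)
val newFields = outputColumns.map(outputCol => transformedSchema(outputCol))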

prepOutputField(seqOfSplits(idx), outputCol)
}
filteredDataset.withColumns(outputColumns, newCols, newFields.map(_.metadata))
}

private def prepOutputField(schema: StructType): StructField = {
val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true),
private def prepOutputField(splits: Array[Double], outputCol: String): StructField = {
val buckets = splits.sliding(2).map(bucket => bucket.mkString(", ")).toArray
val attr = new NominalAttribute(name = Some(outputCol), isOrdinal = Some(true),
values = Some(buckets))
attr.toStructField()
}

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
SchemaUtils.checkNumericType(schema, $(inputCol))
SchemaUtils.appendColumn(schema, prepOutputField(schema))
if (isBucketizeMultipleInputCols()) {
var transformedSchema = schema
$(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) =>
SchemaUtils.checkNumericType(transformedSchema, inputCol)
transformedSchema = SchemaUtils.appendColumn(transformedSchema,
prepOutputField($(splitsArray)(idx), outputCol))
}
transformedSchema
} else {
SchemaUtils.checkNumericType(schema, $(inputCol))
SchemaUtils.appendColumn(schema, prepOutputField($(splits), $(outputCol)))
}
}

@Since("1.4.1")
@@ -163,6 +245,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
}
}

/**
* Check each set of splits in the splits array.
*/
private[feature] def checkSplitsArray(splitsArray: Array[Array[Double]]): Boolean = {
splitsArray.forall(checkSplits(_))
}

/**
* Binary searching in several buckets to place each data point.
* @param splits array of split points
39 changes: 39 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -490,6 +490,45 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
}
}

/**
* :: DeveloperApi ::
* Specialized version of `Param[Array[Array[Double]]]` for Java.
*/
@DeveloperApi
class DoubleArrayArrayParam(
parent: Params,
name: String,
doc: String,
isValid: Array[Array[Double]] => Boolean)
extends Param[Array[Array[Double]]](parent, name, doc, isValid) {

def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)

/** Creates a param pair with a `java.util.List` of values (for Java and Python). */
def w(value: java.util.List[java.util.List[java.lang.Double]]): ParamPair[Array[Array[Double]]] =
w(value.asScala.map(_.asScala.map(_.asInstanceOf[Double]).toArray).toArray)

override def jsonEncode(value: Array[Array[Double]]): String = {
import org.json4s.JsonDSL._
compact(render(value.toSeq.map(_.toSeq.map(DoubleParam.jValueEncode))))
}

override def jsonDecode(json: String): Array[Array[Double]] = {
parse(json) match {
case JArray(values) =>
values.map {
case JArray(values) =>
values.map(DoubleParam.jValueDecode).toArray
case _ =>
throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].")
}.toArray
case _ =>
throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].")
}
}
}
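A quick illustration of the JSON round-trip (illustrative values; the exact string formatting is approximate):

// Hypothetical usage: encode/decode a splits array via the Bucketizer's own param.
val bucketizer = new org.apache.spark.ml.feature.Bucketizer()
val param = bucketizer.splitsArray

val json = param.jsonEncode(Array(Array(0.0, 1.0), Array(-1.0, 0.5)))
// json is roughly: [[0.0,1.0],[-1.0,0.5]]
val decoded = param.jsonDecode(json)  // Array(Array(0.0, 1.0), Array(-1.0, 0.5))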

/**
* :: DeveloperApi ::
* Specialized version of `Param[Array[Int]]` for Java.
@@ -61,4 +61,39 @@ public void bucketizerTest() {
Assert.assertTrue((index >= 0) && (index <= 1));
}
}

@Test
public void bucketizerMultipleColumnsTest() {
double[][] splitsArray = {
{-0.5, 0.0, 0.5},
{-0.5, 0.0, 0.2, 0.5}
};

StructType schema = new StructType(new StructField[]{
new StructField("feature1", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("feature2", DataTypes.DoubleType, false, Metadata.empty()),
});
Dataset<Row> dataset = spark.createDataFrame(
Arrays.asList(
RowFactory.create(-0.5, -0.5),
RowFactory.create(-0.3, -0.3),
RowFactory.create(0.0, 0.0),
RowFactory.create(0.2, 0.3)),
schema);

Bucketizer bucketizer = new Bucketizer()
.setInputCols(new String[] {"feature1", "feature2"})
.setOutputCols(new String[] {"result1", "result2"})
.setSplitsArray(splitsArray);

List<Row> result = bucketizer.transform(dataset).select("result1", "result2").collectAsList();

for (Row r : result) {
double index1 = r.getDouble(0);
Assert.assertTrue((index1 >= 0) && (index1 <= 1));

double index2 = r.getDouble(1);
Assert.assertTrue((index2 >= 0) && (index2 <= 2));
}
}
}