mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -234,7 +234,7 @@ class StringIndexerModel (
val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
// If we are skipping invalid records, filter them out.
val (filteredDataset, keepInvalid) = getHandleInvalid match {
Contributor: Why does this line need to change?

Contributor (author): Thanks for picking this out! I changed this because I was matching on $(handleInvalid) in VectorAssembler, and that seems to be the recommended way of doing it. Should I include this in the current PR and add a note, or open a separate PR?

Contributor: OK, it doesn't matter; no need for a separate PR, I think. It's just a minor change.

Member: For the record, in general I would not bother making changes like this. The one exception I make is for IntelliJ style complaints, since those can be annoying for developers.

val (filteredDataset, keepInvalid) = $(handleInvalid) match {
case StringIndexer.SKIP_INVALID =>
val filterer = udf { label: String =>
labelToIndex.contains(label)
mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala (102 additions, 17 deletions)
@@ -17,14 +17,17 @@

package org.apache.spark.ml.feature

import java.util.NoSuchElementException

import scala.collection.mutable.ArrayBuilder
import scala.language.existentials

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -36,7 +39,8 @@ import org.apache.spark.sql.types._
*/
@Since("1.4.0")
class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {
extends Transformer with HasInputCols with HasOutputCol with HasHandleInvalid
with DefaultParamsWritable {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("vecAssembler"))
@@ -49,15 +53,73 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
@Since("1.6.0")
Contributor: Should be @Since("2.4.0").

def setHandleInvalid(value: String): this.type = set(handleInvalid, value)

/**
* Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
Member: It would be good to expand this doc to explain the behavior: how the various types of invalid values are treated (null, NaN, incorrect vector length) and how computationally expensive the different options can be.

Contributor (author): The behavior of the options is already included, the explanation of column length is included here, and run-time information is included in the VectorAssembler class documentation. Thanks for the suggestion, this is super important!

Contributor (author): Also, we only deal with nulls here. NaNs and incorrect-length vectors are passed through transparently. Do we need to test for those?

Member: I'd recommend we deal with NaNs now. This PR is already handling some NaN cases: Dataset.na.drop handles NaNs in NumericType columns (but not VectorUDT columns).

I'm OK with postponing incorrect vector lengths until later, or doing that now, since that work will be more separate.
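
A minimal sketch (not part of this diff) of the distinction described above; the SparkSession setup and column names are hypothetical:

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("na-drop-demo").getOrCreate()
import spark.implicits._

val df = Seq(
  (Double.NaN, Vectors.dense(1.0, 2.0)),   // NaN in a NumericType column
  (1.0, Vectors.dense(Double.NaN, 2.0))    // NaN inside a VectorUDT column
).toDF("x", "vec")

// Drops the first row (NaN in the numeric column "x") but keeps the second:
// na.drop inspects null/NaN in numeric columns, not values inside vectors.
df.na.drop(Seq("x", "vec")).show()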

* invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
* output).
* Default: "error"
* @group param
*/
@Since("1.6.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"Hhow to handle invalid data (NULL values). Options are 'skip' (filter out rows with " +
Contributor: Typo: "Hhow" -> "How".

"invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN " +
"in the * output).", ParamValidators.inArray(StringIndexer.supportedHandleInvalids))

setDefault(handleInvalid, StringIndexer.ERROR_INVALID)

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// Schema transformation.
val schema = dataset.schema
lazy val first = dataset.toDF.first()
val colsMissingNumAttrs = $(inputCols).filter { c =>
val field = schema(c)
field.dataType match {
case _: VectorUDT => AttributeGroup.fromStructField(field).numAttributes.isEmpty
case _ => false
}
}
if (dataset.isStreaming && colsMissingNumAttrs.nonEmpty) {
throw new RuntimeException(
s"""
|VectorAssembler cannot dynamically determine the size of vectors for streaming data.
|Consider applying VectorSizeHint to ${colsMissingNumAttrs.mkString("[", ", ", "]")}
|so that this transformer can be used to transform streaming inputs.
""".stripMargin)
}
val missingVectorSizes = colsMissingNumAttrs.map { c =>
$(handleInvalid) match {
case StringIndexer.ERROR_INVALID => c -> 0
case StringIndexer.SKIP_INVALID => c -> 0
case StringIndexer.KEEP_INVALID =>
try {
c -> dataset.select(c).na.drop().first.getAs[Vector](0).size
} catch {
case _: NoSuchElementException =>
throw new RuntimeException(
s"""
|VectorAssembler cannot determine the size of empty vectors. Consider applying
|VectorSizeHint to ${colsMissingNumAttrs.mkString("[", ", ", "]")} so that this
|transformer can be used to transform empty columns.
""".stripMargin)
}
}
}.toMap
val lengths = $(inputCols).map { c =>
val field = schema(c)
field.dataType match {
case _: NumericType | BooleanType => c -> 1 // DoubleType is also NumericType
case _: VectorUDT =>
c -> AttributeGroup.fromStructField(field).numAttributes.getOrElse(missingVectorSizes(c))
}
}.toMap
val attrs = $(inputCols).flatMap { c =>
val field = schema(c)
val index = schema.fieldIndex(c)
field.dataType match {
case DoubleType =>
val attr = Attribute.fromStructField(field)
@@ -85,18 +147,22 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
} else {
// Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
// from metadata, check the first row.
val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
val numAttrs = lengths(c)
Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i))
}
case otherType =>
throw new SparkException(s"VectorAssembler does not support the $otherType type")
}
}
val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()

val (filteredDataset, keepInvalid) = $(handleInvalid) match {
case StringIndexer.SKIP_INVALID => (dataset.na.drop("any", $(inputCols)), false)
Contributor: You can directly use dataset.na.drop($(inputCols)).

Contributor (author): Ah, good point! Although I do think keeping "any" might make it easier to read, that may not necessarily hold for experienced people :P
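
For reference, a minimal sketch (not part of this diff) showing that the two forms agree, since na.drop(cols) defaults to how = "any"; the DataFrame and column names are hypothetical:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq((Some(1.0), Some(2.0)), (None, Some(3.0))).toDF("x", "y")

// Both calls drop the row whose "x" is null.
val dropped1 = df.na.drop(Seq("x", "y"))
val dropped2 = df.na.drop("any", Seq("x", "y"))
assert(dropped1.collect().sameElements(dropped2.collect()))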

case StringIndexer.KEEP_INVALID => (dataset, true)
case StringIndexer.ERROR_INVALID => (dataset, false)
}
// Data transformation.
val assembleFunc = udf { r: Row =>
VectorAssembler.assemble(r.toSeq: _*)
VectorAssembler.assemble( $(inputCols).map(c => lengths(c)), keepInvalid)(r.toSeq: _*)
}.asNondeterministic()
val args = $(inputCols).map { c =>
schema(c).dataType match {
Expand All @@ -106,7 +172,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
}
}

dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
filteredDataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
}

@Since("1.4.0")
@@ -136,34 +202,53 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("1.6.0")
object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {

private[feature] val SKIP_INVALID: String = "skip"
private[feature] val ERROR_INVALID: String = "error"
private[feature] val KEEP_INVALID: String = "keep"
private[feature] val supportedHandleInvalids: Array[String] =
Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)

@Since("1.6.0")
override def load(path: String): VectorAssembler = super.load(path)

private[feature] def assemble(vv: Any*): Vector = {
private[feature] def assemble(lengths: Array[Int], keepInvalid: Boolean)(vv: Any*): Vector = {
val indices = ArrayBuilder.make[Int]
val values = ArrayBuilder.make[Double]
var cur = 0
var featureIndex = 0

var inputColumnIndex = 0
vv.foreach {
case v: Double =>
if (v != 0.0) {
indices += cur
indices += featureIndex
values += v
}
cur += 1
inputColumnIndex += 1
featureIndex += 1
case vec: Vector =>
vec.foreachActive { case (i, v) =>
if (v != 0.0) {
indices += cur + i
indices += featureIndex + i
values += v
}
}
cur += vec.size
inputColumnIndex += 1
featureIndex += vec.size
case null =>
// TODO: output Double.NaN?
throw new SparkException("Values to assemble cannot be null.")
keepInvalid match {
case false => throw new SparkException("Values to assemble cannot be null.")
case true =>
val length: Int = lengths(inputColumnIndex)
Array.range(0, length).foreach { case (i) =>
indices += featureIndex + i
values += Double.NaN
}
inputColumnIndex += 1
featureIndex += length
}
case o =>
throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
}
Vectors.sparse(cur, indices.result(), values.result()).compressed
Vectors.sparse(featureIndex, indices.result(), values.result()).compressed
}
}
mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -37,24 +37,26 @@ class VectorAssemblerSuite

test("assemble") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
assert(assemble(Array(1), true)(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
assert(assemble(Array(1, 1), true)(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
val dv = Vectors.dense(2.0, 0.0)
assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
assert(assemble(Array(1, 2, 1), true)(0.0, dv, 1.0) ===
Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0))
assert(assemble(0.0, dv, 1.0, sv) ===
assert(assemble(Array(1, 2, 1, 2), true)(0.0, dv, 1.0, sv) ===
Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0)))
for (v <- Seq(1, "a", null)) {
intercept[SparkException](assemble(v))
intercept[SparkException](assemble(1.0, v))
for (v <- Seq(1, "a")) {
intercept[SparkException](assemble(Array(1), true)(v))
intercept[SparkException](assemble(Array(1, 1), true)(1.0, v))
}
}

test("assemble should compress vectors") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
val v1 = assemble(Array(1, 1, 1, 4), true)(0.0, 0.0, 0.0, Vectors.dense(4.0))
assert(v1.isInstanceOf[SparseVector])
val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0)))
val sv = Vectors.sparse(1, Array(0), Array(4.0))
val v2 = assemble(Array(1, 1, 1, 1), true)(1.0, 2.0, 3.0, sv)
assert(v2.isInstanceOf[DenseVector])
}

@@ -147,4 +149,51 @@ class VectorAssemblerSuite
.filter(vectorUDF($"features") > 1)
.count() == 1)
}

test("assemble should keep nulls") {
Member: Make the test name more explicit: + " when keepInvalid = true".

import org.apache.spark.ml.feature.VectorAssembler.assemble
assert(assemble(Array(1, 1), true)(1.0, null) === Vectors.dense(1.0, Double.NaN))
assert(assemble(Array(1, 2), true)(1.0, null) === Vectors.dense(1.0, Double.NaN, Double.NaN))
assert(assemble(Array(1), true)(null) === Vectors.dense(Double.NaN))
assert(assemble(Array(2), true)(null) === Vectors.dense(Double.NaN, Double.NaN))
}

test("assemble should throw errors") {
Member: Similarly, make it more explicit: + " when keepInvalid = false".

import org.apache.spark.ml.feature.VectorAssembler.assemble
intercept[SparkException](assemble(Array(1, 1), false)(1.0, null) ===
Vectors.dense(1.0, Double.NaN))
Member: No need to compare with anything; just call assemble().

intercept[SparkException](assemble(Array(1, 2), false)(1.0, null) ===
Vectors.dense(1.0, Double.NaN, Double.NaN))
intercept[SparkException](assemble(Array(1), false)(null) === Vectors.dense(Double.NaN))
intercept[SparkException](assemble(Array(2), false)(null) ===
Vectors.dense(Double.NaN, Double.NaN))
}

test("Handle Invalid should behave properly") {
val df = Seq[(Long, Long, java.lang.Double, Vector, String, Vector, Long)](
Member: Also, if there are "trash" columns not used by VectorAssembler, maybe name them as such and add a few null values in them for better testing.

Contributor (author): Thanks, good idea! This helped me catch the drop.na() bug that might drop everything.
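
A minimal sketch (not part of this diff) of the failure mode being described: if the drop is not restricted to $(inputCols), nulls in unrelated columns also remove rows. The DataFrame and column names here are hypothetical:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// "trash" is not an assembler input column, but it contains a null.
val df = Seq(
  (1.0, 2.0, Some("keep")),
  (3.0, 4.0, None)
).toDF("x", "y", "trash")

// Restricting the drop to the input columns keeps both rows ...
assert(df.na.drop(Seq("x", "y")).count() == 2)
// ... whereas dropping over all columns would also remove the second row.
assert(df.na.drop().count() == 1)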

(1, 2, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 7L),
(2, 1, 0.0, null, "a", Vectors.sparse(2, Array(1), Array(3.0)), 6L),
(3, 3, null, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 8L),
(4, 4, null, null, "a", Vectors.sparse(2, Array(1), Array(3.0)), 9L)
).toDF("id1", "id2", "x", "y", "name", "z", "n")

val assembler = new VectorAssembler()
.setInputCols(Array("x", "y", "z", "n"))
.setOutputCol("features")

assert(assembler.setHandleInvalid("skip").transform(df).count() == 1)
assert(assembler.setHandleInvalid("keep").transform(df).count() == 4)
assert(assembler.setHandleInvalid("keep").transform(df.sort("id2")).count() == 4)
intercept[SparkException](assembler.setHandleInvalid("error").transform(df).cache())

// all numeric columns are null
assert(assembler.setHandleInvalid("keep").transform(df.filter("id1==3")).count() == 1)

// all vector columns are null
val df2 = df.filter("0 == id1 % 2")
assert(assembler.setHandleInvalid("skip").transform(df2).count() == 0)
intercept[RuntimeException](assembler.setHandleInvalid("keep").transform(df2))
intercept[SparkException](assembler.setHandleInvalid("error").transform(df2).collect())
}

}