Skip to content
17 changes: 12 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ private[ml] trait PredictorParams extends Params
* and put it in an RDD with strong types.
* Validate the output instances with the given function.
*/
protected def extractInstances(dataset: Dataset[_],
validateInstance: Instance => Unit): RDD[Instance] = {
protected def extractInstances(
dataset: Dataset[_],
validateInstance: Instance => Unit): RDD[Instance] = {
extractInstances(dataset).map { instance =>
validateInstance(instance)
instance
Expand Down Expand Up @@ -222,7 +223,11 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
protected def featuresDataType: DataType = new VectorUDT

override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = false, featuresDataType)
var outputSchema = validateAndTransformSchema(schema, fitting = false, featuresDataType)
if ($(predictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateNumeric(outputSchema, $(predictionCol))
}
outputSchema
}

/**
Expand All @@ -244,10 +249,12 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
}

protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Any) =>
val outputSchema = transformSchema(dataset.schema, logging = true)
val predictUDF = udf { features: Any =>
predict(features.asInstanceOf[FeaturesType])
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))),
outputSchema($(predictionCol)).metadata)
}

/**
Expand Down
5 changes: 3 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,10 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
}

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val outputSchema = transformSchema(dataset.schema, logging = true)
val transformUDF = udf(this.createTransformFunc, outputDataType)
dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))),
outputSchema($(outputCol)).metadata)
}

override def copy(extra: ParamMap): T = defaultCopy(extra)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ private[spark] trait ClassifierParams
* and put it in an RDD with strong types.
* Validates the label on the classifier is a valid integer in the range [0, numClasses).
*/
protected def extractInstances(dataset: Dataset[_],
numClasses: Int): RDD[Instance] = {
protected def extractInstances(
dataset: Dataset[_],
numClasses: Int): RDD[Instance] = {
val validateInstance = (instance: Instance) => {
val label = instance.label
require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
Expand Down Expand Up @@ -183,6 +184,19 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
/** Number of classes (values which the label can take). */
def numClasses: Int

override def transformSchema(schema: StructType): StructType = {
var outputSchema = super.transformSchema(schema)
if ($(predictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateNumValues(schema,
$(predictionCol), numClasses)
}
if ($(rawPredictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
$(rawPredictionCol), numClasses)
}
outputSchema
}

/**
* Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
* parameters:
Expand All @@ -193,29 +207,31 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
* @return transformed dataset
*/
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val outputSchema = transformSchema(dataset.schema, logging = true)

// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
var outputData = dataset
var numColsOutput = 0
if (getRawPredictionCol != "") {
val predictRawUDF = udf { (features: Any) =>
val predictRawUDF = udf { features: Any =>
predictRaw(features.asInstanceOf[FeaturesType])
}
outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)),
outputSchema($(rawPredictionCol)).metadata)
numColsOutput += 1
}
if (getPredictionCol != "") {
val predUDF = if (getRawPredictionCol != "") {
val predCol = if (getRawPredictionCol != "") {
udf(raw2prediction _).apply(col(getRawPredictionCol))
} else {
val predictUDF = udf { (features: Any) =>
val predictUDF = udf { features: Any =>
predict(features.asInstanceOf[FeaturesType])
}
predictUDF(col(getFeaturesCol))
}
outputData = outputData.withColumn(getPredictionCol, predUDF)
outputData = outputData.withColumn(getPredictionCol, predCol,
outputSchema($(predictionCol)).metadata)
numColsOutput += 1
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.StructType

/**
* Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
Expand Down Expand Up @@ -202,13 +203,23 @@ class DecisionTreeClassificationModel private[ml] (
rootNode.predictImpl(features).prediction
}

@Since("3.0.0")
override def transformSchema(schema: StructType): StructType = {
var outputSchema = super.transformSchema(schema)
if ($(leafCol).nonEmpty) {
outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
}
outputSchema
}

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val outputSchema = transformSchema(dataset.schema, logging = true)

val outputData = super.transform(dataset)
if ($(leafCol).nonEmpty) {
val leafUDF = udf { features: Vector => predictLeaf(features) }
outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
outputSchema($(leafCol)).metadata)
} else {
outputData
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType

/**
* Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting)
Expand Down Expand Up @@ -291,13 +292,23 @@ class GBTClassificationModel private[ml](
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
var outputSchema = super.transformSchema(schema)
if ($(leafCol).nonEmpty) {
outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
}
outputSchema
}

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val outputSchema = transformSchema(dataset.schema, logging = true)

val outputData = super.transform(dataset)
if ($(leafCol).nonEmpty) {
val leafUDF = udf { features: Vector => predictLeaf(features) }
outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
outputSchema($(leafCol)).metadata)
} else {
outputData
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,23 @@ final class OneVsRestModel private[ml] (

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
var outputSchema = validateAndTransformSchema(schema, fitting = false,
getClassifier.featuresDataType)
if ($(predictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateNumValues(outputSchema,
$(predictionCol), numClasses)
}
if ($(rawPredictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
$(rawPredictionCol), numClasses)
}
outputSchema
}

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
// Check schema
transformSchema(dataset.schema, logging = true)
val outputSchema = transformSchema(dataset.schema, logging = true)

if (getPredictionCol.isEmpty && getRawPredictionCol.isEmpty) {
logWarning(s"$uid: OneVsRestModel.transform() does nothing" +
Expand Down Expand Up @@ -230,6 +240,7 @@ final class OneVsRestModel private[ml] (

predictionColNames :+= getRawPredictionCol
predictionColumns :+= rawPredictionUDF(col(accColName))
.as($(rawPredictionCol), outputSchema($(rawPredictionCol)).metadata)
}

if (getPredictionCol.nonEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ abstract class ProbabilisticClassificationModel[
set(thresholds, value).asInstanceOf[M]
}

override def transformSchema(schema: StructType): StructType = {
var outputSchema = super.transformSchema(schema)
if ($(probabilityCol).nonEmpty) {
outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
$(probabilityCol), numClasses)
}
outputSchema
}

/**
* Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
* parameters:
Expand All @@ -101,7 +110,7 @@ abstract class ProbabilisticClassificationModel[
* @return transformed dataset
*/
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val outputSchema = transformSchema(dataset.schema, logging = true)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".transform() called with non-matching numClasses and thresholds.length." +
Expand All @@ -113,36 +122,39 @@ abstract class ProbabilisticClassificationModel[
var outputData = dataset
var numColsOutput = 0
if ($(rawPredictionCol).nonEmpty) {
val predictRawUDF = udf { (features: Any) =>
val predictRawUDF = udf { features: Any =>
predictRaw(features.asInstanceOf[FeaturesType])
}
outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)),
outputSchema($(rawPredictionCol)).metadata)
numColsOutput += 1
}
if ($(probabilityCol).nonEmpty) {
val probUDF = if ($(rawPredictionCol).nonEmpty) {
val probCol = if ($(rawPredictionCol).nonEmpty) {
udf(raw2probability _).apply(col($(rawPredictionCol)))
} else {
val probabilityUDF = udf { (features: Any) =>
val probabilityUDF = udf { features: Any =>
predictProbability(features.asInstanceOf[FeaturesType])
}
probabilityUDF(col($(featuresCol)))
}
outputData = outputData.withColumn($(probabilityCol), probUDF)
outputData = outputData.withColumn($(probabilityCol), probCol,
outputSchema($(probabilityCol)).metadata)
numColsOutput += 1
}
if ($(predictionCol).nonEmpty) {
val predUDF = if ($(rawPredictionCol).nonEmpty) {
val predCol = if ($(rawPredictionCol).nonEmpty) {
udf(raw2prediction _).apply(col($(rawPredictionCol)))
} else if ($(probabilityCol).nonEmpty) {
udf(probability2prediction _).apply(col($(probabilityCol)))
} else {
val predictUDF = udf { (features: Any) =>
val predictUDF = udf { features: Any =>
predict(features.asInstanceOf[FeaturesType])
}
predictUDF(col($(featuresCol)))
}
outputData = outputData.withColumn($(predictionCol), predUDF)
outputData = outputData.withColumn($(predictionCol), predCol,
outputSchema($(predictionCol)).metadata)
numColsOutput += 1
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestMo
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.StructType

/**
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
Expand Down Expand Up @@ -210,13 +211,23 @@ class RandomForestClassificationModel private[ml] (
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
var outputSchema = super.transformSchema(schema)
if ($(leafCol).nonEmpty) {
outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
}
outputSchema
}

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val outputSchema = transformSchema(dataset.schema, logging = true)

val outputData = super.transform(dataset)
if ($(leafCol).nonEmpty) {
val leafUDF = udf { features: Vector => predictLeaf(features) }
outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
outputSchema($(leafCol)).metadata)
} else {
outputData
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,21 @@ class BisectingKMeansModel private[ml] (

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val outputSchema = transformSchema(dataset.schema, logging = true)
val predictUDF = udf((vector: Vector) => predict(vector))
dataset.withColumn($(predictionCol),
predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)),
outputSchema($(predictionCol)).metadata)
}

@Since("2.0.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
var outputSchema = validateAndTransformSchema(schema)
if ($(predictionCol).nonEmpty) {
outputSchema = SchemaUtils.updateNumValues(outputSchema,
$(predictionCol), parentModel.k)
}
outputSchema
}

@Since("3.0.0")
Expand Down
Loading