diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 9eac8ed22a3f6..98dd692cbe55d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -88,8 +88,9 @@ private[ml] trait PredictorParams extends Params
* and put it in an RDD with strong types.
* Validate the output instances with the given function.
*/
- protected def extractInstances(dataset: Dataset[_],
- validateInstance: Instance => Unit): RDD[Instance] = {
+ protected def extractInstances(
+ dataset: Dataset[_],
+ validateInstance: Instance => Unit): RDD[Instance] = {
extractInstances(dataset).map { instance =>
validateInstance(instance)
instance
@@ -222,7 +223,11 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
protected def featuresDataType: DataType = new VectorUDT
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema, fitting = false, featuresDataType)
+ var outputSchema = validateAndTransformSchema(schema, fitting = false, featuresDataType)
+ if ($(predictionCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateNumeric(outputSchema, $(predictionCol))
+ }
+ outputSchema
}
/**
@@ -244,10 +249,12 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
}
protected def transformImpl(dataset: Dataset[_]): DataFrame = {
- val predictUDF = udf { (features: Any) =>
+ val outputSchema = transformSchema(dataset.schema, logging = true)
+ val predictUDF = udf { features: Any =>
predict(features.asInstanceOf[FeaturesType])
}
- dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))),
+ outputSchema($(predictionCol)).metadata)
}
/**
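The changes in Predictor.scala (and the UnaryTransformer change that follows) all lean on the three-argument `Dataset.withColumn(colName, col, metadata)` overload, which attaches ML attribute metadata to the new column so downstream stages can recover it from the schema alone. A minimal sketch of the mechanism, assuming a local SparkSession; the column names here are illustrative:

    import org.apache.spark.ml.attribute.NumericAttribute
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    val df = Seq(1.0, 2.0, 3.0).toDF("x")
    // Build column-level metadata the same way SchemaUtils.updateNumeric does.
    val meta = NumericAttribute.defaultAttr.withName("y").toMetadata()
    // The three-argument withColumn stores the metadata in the output schema.
    val out = df.withColumn("y", col("x") * 2, meta)
    // The ml_attr entry written above is now part of the schema.
    println(out.schema("y").metadata)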
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index a3a2b55adc25d..7874fc29db6c8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -117,9 +117,10 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
}
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val transformUDF = udf(this.createTransformFunc, outputDataType)
- dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
+ dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))),
+ outputSchema($(outputCol)).metadata)
}
override def copy(extra: ParamMap): T = defaultCopy(extra)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 3bff236677e6b..be552e6be0b50 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -48,8 +48,9 @@ private[spark] trait ClassifierParams
* and put it in an RDD with strong types.
 * Validates that the label given to the classifier is a valid integer in the range [0, numClasses).
*/
- protected def extractInstances(dataset: Dataset[_],
- numClasses: Int): RDD[Instance] = {
+ protected def extractInstances(
+ dataset: Dataset[_],
+ numClasses: Int): RDD[Instance] = {
val validateInstance = (instance: Instance) => {
val label = instance.label
require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
@@ -183,6 +184,19 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
/** Number of classes (values which the label can take). */
def numClasses: Int
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(predictionCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateNumValues(outputSchema,
+ $(predictionCol), numClasses)
+ }
+ if ($(rawPredictionCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(rawPredictionCol), numClasses)
+ }
+ outputSchema
+ }
+
/**
* Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
* parameters:
@@ -193,29 +207,31 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
* @return transformed dataset
*/
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
var outputData = dataset
var numColsOutput = 0
if (getRawPredictionCol != "") {
- val predictRawUDF = udf { (features: Any) =>
+ val predictRawUDF = udf { features: Any =>
predictRaw(features.asInstanceOf[FeaturesType])
}
- outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
+ outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)),
+ outputSchema($(rawPredictionCol)).metadata)
numColsOutput += 1
}
if (getPredictionCol != "") {
- val predUDF = if (getRawPredictionCol != "") {
+ val predCol = if (getRawPredictionCol != "") {
udf(raw2prediction _).apply(col(getRawPredictionCol))
} else {
- val predictUDF = udf { (features: Any) =>
+ val predictUDF = udf { features: Any =>
predict(features.asInstanceOf[FeaturesType])
}
predictUDF(col(getFeaturesCol))
}
- outputData = outputData.withColumn(getPredictionCol, predUDF)
+ outputData = outputData.withColumn(getPredictionCol, predCol,
+ outputSchema($(predictionCol)).metadata)
numColsOutput += 1
}
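With numClasses recorded on both output columns, a consumer can read it back from the transformed DataFrame instead of holding a reference to the model. A rough illustration with LogisticRegression (picked arbitrarily; any ClassificationModel goes through this transformSchema), assuming an active SparkSession bound to `spark` and this patch applied:

    import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.linalg.Vectors
    import spark.implicits._

    val train = Seq((0.0, Vectors.dense(0.0)), (1.0, Vectors.dense(1.0)))
      .toDF("label", "features")
    val model = new LogisticRegression().fit(train)
    val out = model.transform(train)

    // prediction now carries a NominalAttribute with numValues = numClasses ...
    val pred = Attribute.fromStructField(out.schema("prediction"))
      .asInstanceOf[NominalAttribute]
    assert(pred.numValues.contains(model.numClasses))
    // ... and rawPrediction an AttributeGroup sized to numClasses.
    assert(AttributeGroup.fromStructField(out.schema("rawPrediction")).size
      == model.numClasses)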
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index e02109375373e..d10f684f0dcf1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType
/**
* Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -202,13 +203,23 @@ class DecisionTreeClassificationModel private[ml] (
rootNode.predictImpl(features).prediction
}
+ @Since("3.0.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+ }
+ outputSchema
+ }
+
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val outputData = super.transform(dataset)
if ($(leafCol).nonEmpty) {
val leafUDF = udf { features: Vector => predictLeaf(features) }
- outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
+ outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+ outputSchema($(leafCol)).metadata)
} else {
outputData
}
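The leaf column metadata comes from `getLeafField` (added to treeModels.scala below), which tags the column as a NominalAttribute with one value per leaf node. A sketch of what that exposes, under the same assumptions as the previous snippet:

    import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
    import org.apache.spark.ml.classification.DecisionTreeClassifier
    import org.apache.spark.ml.linalg.Vectors
    import spark.implicits._

    val train = Seq((0.0, Vectors.dense(0.0)), (1.0, Vectors.dense(1.0)))
      .toDF("label", "features")
    val dtModel = new DecisionTreeClassifier().setLeafCol("predictedLeafId").fit(train)
    val leaves = dtModel.transform(train)

    // The leaf id column advertises how many leaves the fitted tree has.
    val leafAttr = Attribute.fromStructField(leaves.schema("predictedLeafId"))
      .asInstanceOf[NominalAttribute]
    println(leafAttr.numValues)  // Some(<number of leaf nodes>)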
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index e1f5338f34899..6e54e0f15b85c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
/**
* Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting)
@@ -291,13 +292,23 @@ class GBTClassificationModel private[ml](
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
+ @Since("1.6.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+ }
+ outputSchema
+ }
+
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val outputData = super.transform(dataset)
if ($(leafCol).nonEmpty) {
val leafUDF = udf { features: Vector => predictLeaf(features) }
- outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
+ outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+ outputSchema($(leafCol)).metadata)
} else {
outputData
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 51a624795cdd4..bbf8e8fc90ad5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -161,13 +161,23 @@ final class OneVsRestModel private[ml] (
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
+ var outputSchema = validateAndTransformSchema(schema, fitting = false,
+ getClassifier.featuresDataType)
+ if ($(predictionCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateNumValues(outputSchema,
+ $(predictionCol), numClasses)
+ }
+ if ($(rawPredictionCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(rawPredictionCol), numClasses)
+ }
+ outputSchema
}
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
// Check schema
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
if (getPredictionCol.isEmpty && getRawPredictionCol.isEmpty) {
logWarning(s"$uid: OneVsRestModel.transform() does nothing" +
@@ -230,6 +240,7 @@ final class OneVsRestModel private[ml] (
predictionColNames :+= getRawPredictionCol
predictionColumns :+= rawPredictionUDF(col(accColName))
+ .as($(rawPredictionCol), outputSchema($(rawPredictionCol)).metadata)
}
if (getPredictionCol.nonEmpty) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 2171ac335e7b8..2e4d69330d132 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -90,6 +90,15 @@ abstract class ProbabilisticClassificationModel[
set(thresholds, value).asInstanceOf[M]
}
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(probabilityCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(probabilityCol), numClasses)
+ }
+ outputSchema
+ }
+
/**
* Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
* parameters:
@@ -101,7 +110,7 @@ abstract class ProbabilisticClassificationModel[
* @return transformed dataset
*/
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".transform() called with non-matching numClasses and thresholds.length." +
@@ -113,36 +122,39 @@ abstract class ProbabilisticClassificationModel[
var outputData = dataset
var numColsOutput = 0
if ($(rawPredictionCol).nonEmpty) {
- val predictRawUDF = udf { (features: Any) =>
+ val predictRawUDF = udf { features: Any =>
predictRaw(features.asInstanceOf[FeaturesType])
}
- outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
+ outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)),
+ outputSchema($(rawPredictionCol)).metadata)
numColsOutput += 1
}
if ($(probabilityCol).nonEmpty) {
- val probUDF = if ($(rawPredictionCol).nonEmpty) {
+ val probCol = if ($(rawPredictionCol).nonEmpty) {
udf(raw2probability _).apply(col($(rawPredictionCol)))
} else {
- val probabilityUDF = udf { (features: Any) =>
+ val probabilityUDF = udf { features: Any =>
predictProbability(features.asInstanceOf[FeaturesType])
}
probabilityUDF(col($(featuresCol)))
}
- outputData = outputData.withColumn($(probabilityCol), probUDF)
+ outputData = outputData.withColumn($(probabilityCol), probCol,
+ outputSchema($(probabilityCol)).metadata)
numColsOutput += 1
}
if ($(predictionCol).nonEmpty) {
- val predUDF = if ($(rawPredictionCol).nonEmpty) {
+ val predCol = if ($(rawPredictionCol).nonEmpty) {
udf(raw2prediction _).apply(col($(rawPredictionCol)))
} else if ($(probabilityCol).nonEmpty) {
udf(probability2prediction _).apply(col($(probabilityCol)))
} else {
- val predictUDF = udf { (features: Any) =>
+ val predictUDF = udf { features: Any =>
predict(features.asInstanceOf[FeaturesType])
}
predictUDF(col($(featuresCol)))
}
- outputData = outputData.withColumn($(predictionCol), predUDF)
+ outputData = outputData.withColumn($(predictionCol), predCol,
+ outputSchema($(predictionCol)).metadata)
numColsOutput += 1
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index bc28d783ed962..f88fc2a6a0914 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestMo
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType
/**
* Random Forest learning algorithm for
@@ -210,13 +211,23 @@ class RandomForestClassificationModel private[ml] (
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
+ @Since("1.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+ }
+ outputSchema
+ }
+
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val outputData = super.transform(dataset)
if ($(leafCol).nonEmpty) {
val leafUDF = udf { features: Vector => predictLeaf(features) }
- outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
+ outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+ outputSchema($(leafCol)).metadata)
} else {
outputData
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 5f2316fa7ce18..79760d69489c6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -110,15 +110,21 @@ class BisectingKMeansModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val predictUDF = udf((vector: Vector) => predict(vector))
dataset.withColumn($(predictionCol),
- predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
+ predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)),
+ outputSchema($(predictionCol)).metadata)
}
@Since("2.0.0")
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ var outputSchema = validateAndTransformSchema(schema)
+ if ($(predictionCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateNumValues(outputSchema,
+ $(predictionCol), parentModel.k)
+ }
+ outputSchema
}
@Since("3.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 916f326ab5615..9d00d6a8bcbe4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -112,7 +112,7 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol))
var outputData = dataset
@@ -120,17 +120,20 @@ class GaussianMixtureModel private[ml] (
if ($(probabilityCol).nonEmpty) {
val probUDF = udf((vector: Vector) => predictProbability(vector))
- outputData = outputData.withColumn($(probabilityCol), probUDF(vectorCol))
+ outputData = outputData.withColumn($(probabilityCol), probUDF(vectorCol),
+ outputSchema($(probabilityCol)).metadata)
numColsOutput += 1
}
if ($(predictionCol).nonEmpty) {
if ($(probabilityCol).nonEmpty) {
val predUDF = udf((vector: Vector) => vector.argmax)
- outputData = outputData.withColumn($(predictionCol), predUDF(col($(probabilityCol))))
+ outputData = outputData.withColumn($(predictionCol), predUDF(col($(probabilityCol))),
+ outputSchema($(predictionCol)).metadata)
} else {
val predUDF = udf((vector: Vector) => predict(vector))
- outputData = outputData.withColumn($(predictionCol), predUDF(vectorCol))
+ outputData = outputData.withColumn($(predictionCol), predUDF(vectorCol),
+ outputSchema($(predictionCol)).metadata)
}
numColsOutput += 1
}
@@ -144,7 +147,16 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0")
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ var outputSchema = validateAndTransformSchema(schema)
+ if ($(predictionCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateNumValues(outputSchema,
+ $(predictionCol), weights.length)
+ }
+ if ($(probabilityCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(probabilityCol), weights.length)
+ }
+ outputSchema
}
@Since("3.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index caeded400f9aa..5cbba6c77f9fb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -127,17 +127,23 @@ class KMeansModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val predictUDF = udf((vector: Vector) => predict(vector))
dataset.withColumn($(predictionCol),
- predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
+ predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)),
+ outputSchema($(predictionCol)).metadata)
}
@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ var outputSchema = validateAndTransformSchema(schema)
+ if ($(predictionCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateNumValues(outputSchema,
+ $(predictionCol), parentModel.k)
+ }
+ outputSchema
}
@Since("3.0.0")
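The clustering models follow the same recipe: the prediction column becomes a NominalAttribute carrying k, and vector-valued columns (probability, topicDistribution) become sized AttributeGroups. A sketch for KMeans with placeholder data, assuming `spark` is in scope:

    import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
    import org.apache.spark.ml.clustering.KMeans
    import org.apache.spark.ml.linalg.Vectors
    import spark.implicits._

    val points = Seq(Vectors.dense(0.0), Vectors.dense(1.0), Vectors.dense(9.0))
      .map(Tuple1.apply).toDF("features")
    val kmModel = new KMeans().setK(2).setSeed(1L).fit(points)
    val clustered = kmModel.transform(points)

    // k is recoverable from the column metadata alone.
    val clusterAttr = Attribute.fromStructField(clustered.schema("prediction"))
      .asInstanceOf[NominalAttribute]
    assert(clusterAttr.numValues.contains(2))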
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 9b0005b3747dc..e30be8c20dcc3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -459,12 +459,13 @@ abstract class LDAModel private[ml] (
*/
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val func = getTopicDistributionMethod
val transformer = udf(func)
dataset.withColumn($(topicDistributionCol),
- transformer(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
+ transformer(DatasetUtils.columnToVector(dataset, getFeaturesCol)),
+ outputSchema($(topicDistributionCol)).metadata)
}
/**
@@ -504,7 +505,12 @@ abstract class LDAModel private[ml] (
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ var outputSchema = validateAndTransformSchema(schema)
+ if ($(topicDistributionCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(topicDistributionCol), oldLocalModel.k)
+ }
+ outputSchema
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 07a4f91443bc5..381ab6eb7355b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuilder
import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.attribute.BinaryAttribute
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -193,7 +193,12 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
case DoubleType =>
BinaryAttribute.defaultAttr.withName(outputColName).toStructField()
case _: VectorUDT =>
- StructField(outputColName, new VectorUDT)
+ val size = AttributeGroup.fromStructField(schema(inputColName)).size
+ if (size < 0) {
+ StructField(outputColName, new VectorUDT)
+ } else {
+ new AttributeGroup(outputColName, numAttributes = size).toStructField()
+ }
case _ =>
throw new IllegalArgumentException(s"Data type $inputType is not supported.")
}
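The `size < 0` branch matters because `AttributeGroup.fromStructField(...).size` is -1 for a vector column that carries no size information; Binarizer only promises an output size when the input declares one. The contract in isolation, no SparkSession needed:

    import org.apache.spark.ml.attribute.AttributeGroup
    import org.apache.spark.ml.linalg.VectorUDT
    import org.apache.spark.sql.types.StructField

    // A vector field with no ML metadata: the group size is unknown (-1).
    val bare = StructField("features", new VectorUDT)
    assert(AttributeGroup.fromStructField(bare).size == -1)

    // A field written through an explicitly sized group round-trips its size.
    val sized = new AttributeGroup("features", 4).toStructField()
    assert(AttributeGroup.fromStructField(sized).size == 4)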
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 9103e4feac454..76f4f944f11d5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -289,8 +289,7 @@ final class ChiSqSelectorModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
val newField = prepOutputField(schema)
- val outputFields = schema.fields :+ newField
- StructType(outputFields)
+ SchemaUtils.appendColumn(schema, newField)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index c58d44d492342..7ba6f640b1e49 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -300,7 +300,7 @@ class CountVectorizerModel(
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
if (broadcastDict.isEmpty) {
val dict = vocabulary.zipWithIndex.toMap
broadcastDict = Some(dataset.sparkSession.sparkContext.broadcast(dict))
@@ -326,14 +326,19 @@ class CountVectorizerModel(
Vectors.sparse(dictBr.value.size, effectiveCounts)
}
- val attrs = vocabulary.map(_ => new NumericAttribute).asInstanceOf[Array[Attribute]]
- val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
- dataset.withColumn($(outputCol), vectorizer(col($(inputCol))), metadata)
+ dataset.withColumn($(outputCol), vectorizer(col($(inputCol))),
+ outputSchema($(outputCol)).metadata)
}
@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ var outputSchema = validateAndTransformSchema(schema)
+ if ($(outputCol).nonEmpty) {
+ val attrs: Array[Attribute] = vocabulary.map(_ => new NumericAttribute)
+ val field = new AttributeGroup($(outputCol), attrs).toStructField()
+ outputSchema = SchemaUtils.updateField(outputSchema, field)
+ }
+ outputSchema
}
@Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
index e2167f01281da..d057e5a62e507 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
@@ -21,10 +21,11 @@ import org.jtransforms.dct._
import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.BooleanParam
import org.apache.spark.ml.util._
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._
/**
* A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero
@@ -75,6 +76,18 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: String)
override protected def outputDataType: DataType = new VectorUDT
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(inputCol).nonEmpty && $(outputCol).nonEmpty) {
+ val size = AttributeGroup.fromStructField(schema($(inputCol))).size
+ if (size >= 0) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(outputCol), size)
+ }
+ }
+ outputSchema
+ }
+
@Since("3.0.0")
override def toString: String = {
s"DCT: uid=$uid, inverse=$inverse"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index 227c13d60fd8f..3b328f2fd8cee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -21,10 +21,10 @@ import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.Param
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature.{ElementwiseProduct => OldElementwiseProduct}
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._
/**
* Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a
@@ -82,6 +82,15 @@ class ElementwiseProduct @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
override protected def outputDataType: DataType = new VectorUDT()
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(outputCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(outputCol), $(scalingVec).size)
+ }
+ outputSchema
+ }
+
@Since("3.0.0")
override def toString: String = {
s"ElementwiseProduct: uid=$uid" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 5f4103abcf50f..e6f124ef7d666 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -131,7 +131,7 @@ class IDFModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val func = { vector: Vector =>
vector match {
@@ -149,12 +149,18 @@ class IDFModel private[ml] (
}
val transformer = udf(func)
- dataset.withColumn($(outputCol), transformer(col($(inputCol))))
+ dataset.withColumn($(outputCol), transformer(col($(inputCol))),
+ outputSchema($(outputCol)).metadata)
}
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ var outputSchema = validateAndTransformSchema(schema)
+ if ($(outputCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(outputCol), idf.size)
+ }
+ outputSchema
}
@Since("1.4.1")
@@ -180,7 +186,7 @@ class IDFModel private[ml] (
@Since("3.0.0")
override def toString: String = {
- s"IDFModel: uid=$uid, numDocs=$numDocs"
+ s"IDFModel: uid=$uid, numDocs=$numDocs, numFeatures=${idf.size}"
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 6bab70e502ed7..2d48a5f9f4915 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -117,19 +117,25 @@ class MaxAbsScalerModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val scale = maxAbs.toArray.map { v => if (v == 0) 1.0 else 1 / v }
val func = StandardScalerModel.getTransformFunc(
Array.empty, scale, false, true)
val transformer = udf(func)
- dataset.withColumn($(outputCol), transformer(col($(inputCol))))
+ dataset.withColumn($(outputCol), transformer(col($(inputCol))),
+ outputSchema($(outputCol)).metadata)
}
@Since("2.0.0")
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ var outputSchema = validateAndTransformSchema(schema)
+ if ($(outputCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(outputCol), maxAbs.size)
+ }
+ outputSchema
}
@Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index e381a0435e9eb..c84892c974b90 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -174,7 +174,7 @@ class MinMaxScalerModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val numFeatures = originalMax.size
val scale = $(max) - $(min)
@@ -210,12 +210,18 @@ class MinMaxScalerModel private[ml] (
Vectors.dense(values).compressed
}
- dataset.withColumn($(outputCol), transformer(col($(inputCol))))
+ dataset.withColumn($(outputCol), transformer(col($(inputCol))),
+ outputSchema($(outputCol)).metadata)
}
@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ var outputSchema = validateAndTransformSchema(schema)
+ if ($(outputCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(outputCol), originalMin.size)
+ }
+ outputSchema
}
@Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index d129c2b2c2dc1..4c7583b8381dc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -19,12 +19,13 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.{DoubleParam, ParamValidators}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._
/**
* Normalize a vector to have unit norm using the given p-norm.
@@ -66,6 +67,19 @@ class Normalizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
override protected def outputDataType: DataType = new VectorUDT()
+ @Since("1.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(inputCol).nonEmpty && $(outputCol).nonEmpty) {
+ val size = AttributeGroup.fromStructField(schema($(inputCol))).size
+ if (size >= 0) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(outputCol), size)
+ }
+ }
+ outputSchema
+ }
+
@Since("3.0.0")
override def toString: String = {
s"Normalizer: uid=$uid, p=${$(p)}"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 27a3854d39b47..9eeb4c8ca2506 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -29,7 +29,7 @@ import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{DenseMatrix => OldDenseMatrix, Vectors => OldVectors}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.StructType
import org.apache.spark.util.VersionUtils.majorVersion
/**
@@ -52,10 +52,8 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
require(!schema.fieldNames.contains($(outputCol)),
s"Output column ${$(outputCol)} already exists.")
- val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
- StructType(outputFields)
+ SchemaUtils.updateAttributeGroupSize(schema, $(outputCol), $(k))
}
-
}
/**
@@ -145,16 +143,22 @@ class PCAModel private[ml] (
*/
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val transposed = pc.transpose
val transformer = udf { vector: Vector => transposed.multiply(vector) }
- dataset.withColumn($(outputCol), transformer(col($(inputCol))))
+ dataset.withColumn($(outputCol), transformer(col($(inputCol))),
+ outputSchema($(outputCol)).metadata)
}
@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ var outputSchema = validateAndTransformSchema(schema)
+ if ($(outputCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(outputCol), $(k))
+ }
+ outputSchema
}
@Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
index 1b9b8082931a5..f02ef0dfc70f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
@@ -227,7 +227,7 @@ class RobustScalerModel private[ml] (
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val shift = if ($(withCentering)) median.toArray else Array.emptyDoubleArray
val scale = if ($(withScaling)) {
@@ -238,11 +238,17 @@ class RobustScalerModel private[ml] (
shift, scale, $(withCentering), $(withScaling))
val transformer = udf(func)
- dataset.withColumn($(outputCol), transformer(col($(inputCol))))
+ dataset.withColumn($(outputCol), transformer(col($(inputCol))),
+ outputSchema($(outputCol)).metadata)
}
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ var outputSchema = validateAndTransformSchema(schema)
+ if ($(outputCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(outputCol), median.size)
+ }
+ outputSchema
}
override def copy(extra: ParamMap): RobustScalerModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 8d4d4197e7db2..c6b1b29a6d9bc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -157,7 +157,7 @@ class StandardScalerModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val shift = if ($(withMean)) mean.toArray else Array.emptyDoubleArray
val scale = if ($(withStd)) {
std.toArray.map { v => if (v == 0) 0.0 else 1.0 / v }
@@ -166,12 +166,18 @@ class StandardScalerModel private[ml] (
val func = getTransformFunc(shift, scale, $(withMean), $(withStd))
val transformer = udf(func)
- dataset.withColumn($(outputCol), transformer(col($(inputCol))))
+ dataset.withColumn($(outputCol), transformer(col($(inputCol))),
+ outputSchema($(outputCol)).metadata)
}
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ var outputSchema = validateAndTransformSchema(schema)
+ if ($(outputCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(outputCol), mean.size)
+ }
+ outputSchema
}
@Since("1.4.1")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index b84b8af4e8a94..45bb4b8e6e65d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -153,8 +153,7 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
}
val numFeaturesSelected = $(indices).length + $(names).length
val outputAttr = new AttributeGroup($(outputCol), numFeaturesSelected)
- val outputFields = schema.fields :+ outputAttr.toStructField()
- StructType(outputFields)
+ SchemaUtils.appendColumn(schema, outputAttr.toStructField)
}
@Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 81dde0315c190..bbfcbfbe038ef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -288,15 +288,16 @@ class Word2VecModel private[ml] (
*/
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val vectors = wordVectors.getVectors
.mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
.map(identity) // mapValues doesn't return a serializable map (SI-7005)
val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors)
val d = $(vectorSize)
+ val emptyVec = Vectors.sparse(d, Array.emptyIntArray, Array.emptyDoubleArray)
val word2Vec = udf { sentence: Seq[String] =>
if (sentence.isEmpty) {
- Vectors.sparse(d, Array.empty[Int], Array.empty[Double])
+ emptyVec
} else {
val sum = Vectors.zeros(d)
sentence.foreach { word =>
@@ -308,12 +309,18 @@ class Word2VecModel private[ml] (
sum
}
}
- dataset.withColumn($(outputCol), word2Vec(col($(inputCol))))
+ dataset.withColumn($(outputCol), word2Vec(col($(inputCol))),
+ outputSchema($(outputCol)).metadata)
}
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ var outputSchema = validateAndTransformSchema(schema)
+ if ($(outputCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(outputCol), $(vectorSize))
+ }
+ outputSchema
}
@Since("1.4.1")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index faf77252cb738..7079ac89dcc93 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -349,7 +349,7 @@ class AFTSurvivalRegressionModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
var predictionColNames = Seq.empty[String]
var predictionColumns = Seq.empty[Column]
@@ -358,12 +358,14 @@ class AFTSurvivalRegressionModel private[ml] (
val predictUDF = udf { features: Vector => predict(features) }
predictionColNames :+= $(predictionCol)
predictionColumns :+= predictUDF(col($(featuresCol)))
+ .as($(predictionCol), outputSchema($(predictionCol)).metadata)
}
if (hasQuantilesCol) {
val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}
predictionColNames :+= $(quantilesCol)
predictionColumns :+= predictQuantilesUDF(col($(featuresCol)))
+ .as($(quantilesCol), outputSchema($(quantilesCol)).metadata)
}
if (predictionColNames.nonEmpty) {
@@ -377,7 +379,15 @@ class AFTSurvivalRegressionModel private[ml] (
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema, fitting = false)
+ var outputSchema = validateAndTransformSchema(schema, fitting = false)
+ if ($(predictionCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateNumeric(outputSchema, $(predictionCol))
+ }
+ if (isDefined(quantilesCol) && $(quantilesCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+ $(quantilesCol), $(quantileProbabilities).length)
+ }
+ outputSchema
}
@Since("1.6.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 4a97997a1deb8..447e6f90a44e7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -36,7 +36,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions._
-
+import org.apache.spark.sql.types.StructType
/**
* Decision tree
@@ -202,9 +202,21 @@ class DecisionTreeRegressionModel private[ml] (
rootNode.predictImpl(features).impurityStats.calculate()
}
+ @Since("1.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if (isDefined(varianceCol) && $(varianceCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateNumeric(outputSchema, $(varianceCol))
+ }
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+ }
+ outputSchema
+ }
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
var predictionColNames = Seq.empty[String]
var predictionColumns = Seq.empty[Column]
@@ -213,18 +225,21 @@ class DecisionTreeRegressionModel private[ml] (
val predictUDF = udf { features: Vector => predict(features) }
predictionColNames :+= $(predictionCol)
predictionColumns :+= predictUDF(col($(featuresCol)))
+ .as($(predictionCol), outputSchema($(predictionCol)).metadata)
}
if (isDefined(varianceCol) && $(varianceCol).nonEmpty) {
val predictVarianceUDF = udf { features: Vector => predictVariance(features) }
predictionColNames :+= $(varianceCol)
predictionColumns :+= predictVarianceUDF(col($(featuresCol)))
+ .as($(varianceCol), outputSchema($(varianceCol)).metadata)
}
if ($(leafCol).nonEmpty) {
val leafUDF = udf { features: Vector => predictLeaf(features) }
predictionColNames :+= $(leafCol)
predictionColumns :+= leafUDF(col($(featuresCol)))
+ .as($(leafCol), outputSchema($(leafCol)).metadata)
}
if (predictionColNames.nonEmpty) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 700f7a2075a91..eb0f2362af570 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -35,6 +35,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
/**
* Gradient-Boosted Trees (GBTs)
@@ -255,8 +256,17 @@ class GBTRegressionModel private[ml](
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
+ @Since("1.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+ }
+ outputSchema
+ }
+
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
var predictionColNames = Seq.empty[String]
var predictionColumns = Seq.empty[Column]
@@ -267,12 +277,14 @@ class GBTRegressionModel private[ml](
val predictUDF = udf { features: Vector => bcastModel.value.predict(features) }
predictionColNames :+= $(predictionCol)
predictionColumns :+= predictUDF(col($(featuresCol)))
+ .as($(predictionCol), outputSchema($(predictionCol)).metadata)
}
if ($(leafCol).nonEmpty) {
val leafUDF = udf { features: Vector => bcastModel.value.predictLeaf(features) }
predictionColNames :+= $(leafCol)
predictionColumns :+= leafUDF(col($(featuresCol)))
+ .as($(leafCol), outputSchema($(leafCol)).metadata)
}
if (predictionColNames.nonEmpty) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 53b29102f01be..f24eeff682110 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -27,7 +27,7 @@ import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.PredictorParams
-import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.feature.{Instance, OffsetInstance}
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.ml.optim._
@@ -213,7 +213,9 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
}
if (hasLinkPredictionCol) {
- SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType)
+ val attr = NumericAttribute.defaultAttr
+ .withName($(linkPredictionCol))
+ SchemaUtils.appendColumn(newSchema, attr.toStructField())
} else {
newSchema
}
@@ -1043,6 +1045,8 @@ class GeneralizedLinearRegressionModel private[ml] (
}
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
+ val outputSchema = transformSchema(dataset.schema, logging = true)
+
val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)
var outputData = dataset
var numColsOutput = 0
@@ -1050,17 +1054,20 @@ class GeneralizedLinearRegressionModel private[ml] (
if (hasLinkPredictionCol) {
val predLinkUDF = udf((features: Vector, offset: Double) => predictLink(features, offset))
outputData = outputData
- .withColumn($(linkPredictionCol), predLinkUDF(col($(featuresCol)), offset))
+ .withColumn($(linkPredictionCol), predLinkUDF(col($(featuresCol)), offset),
+ outputSchema($(linkPredictionCol)).metadata)
numColsOutput += 1
}
if ($(predictionCol).nonEmpty) {
if (hasLinkPredictionCol) {
val predUDF = udf((eta: Double) => familyAndLink.fitted(eta))
- outputData = outputData.withColumn($(predictionCol), predUDF(col($(linkPredictionCol))))
+ outputData = outputData.withColumn($(predictionCol), predUDF(col($(linkPredictionCol))),
+ outputSchema($(predictionCol)).metadata)
} else {
val predUDF = udf((features: Vector, offset: Double) => predict(features, offset))
- outputData = outputData.withColumn($(predictionCol), predUDF(col($(featuresCol)), offset))
+ outputData = outputData.withColumn($(predictionCol), predUDF(col($(featuresCol)), offset),
+ outputSchema($(predictionCol)).metadata)
}
numColsOutput += 1
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 47f9e4bfb8333..d12e5daabebbf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -240,7 +240,7 @@ class IsotonicRegressionModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
val predict = dataset.schema($(featuresCol)).dataType match {
case DoubleType =>
udf { feature: Double => oldModel.predict(feature) }
@@ -248,12 +248,17 @@ class IsotonicRegressionModel private[ml] (
val idx = $(featureIndex)
udf { features: Vector => oldModel.predict(features(idx)) }
}
- dataset.withColumn($(predictionCol), predict(col($(featuresCol))))
+ dataset.withColumn($(predictionCol), predict(col($(featuresCol))),
+ outputSchema($(predictionCol)).metadata)
}
@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema, fitting = false)
+ var outputSchema = validateAndTransformSchema(schema, fitting = false)
+ if ($(predictionCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateNumeric(outputSchema, $(predictionCol))
+ }
+ outputSchema
}
@Since("1.6.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index c3afab57a49c7..fa4dbbb47079f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -33,6 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType
/**
* Random Forest
@@ -192,8 +193,17 @@ class RandomForestRegressionModel private[ml] (
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
+ @Since("1.4.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(leafCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+ }
+ outputSchema
+ }
+
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema, logging = true)
+ val outputSchema = transformSchema(dataset.schema, logging = true)
var predictionColNames = Seq.empty[String]
var predictionColumns = Seq.empty[Column]
@@ -204,12 +214,14 @@ class RandomForestRegressionModel private[ml] (
val predictUDF = udf { features: Vector => bcastModel.value.predict(features) }
predictionColNames :+= $(predictionCol)
predictionColumns :+= predictUDF(col($(featuresCol)))
+ .as($(predictionCol), outputSchema($(predictionCol)).metadata)
}
if ($(leafCol).nonEmpty) {
val leafUDF = udf { features: Vector => bcastModel.value.predictLeaf(features) }
predictionColNames :+= $(leafCol)
predictionColumns :+= leafUDF(col($(featuresCol)))
+ .as($(leafCol), outputSchema($(leafCol)).metadata)
}
if (predictionColNames.nonEmpty) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 10895d4fd11d9..3009b733d4fb7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.{Param, Params}
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite.NodeData
@@ -89,6 +90,18 @@ private[spark] trait DecisionTreeModel {
}
}
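+ /** Number of leaf nodes; used to size the leaf column's nominal attribute. */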
+ private[ml] lazy val numLeave: Int =
+ leafIterator(rootNode).size
+
+ private[ml] lazy val leafAttr = {
+ NominalAttribute.defaultAttr
+ .withNumValues(numLeave)
+ }
+
+ private[ml] def getLeafField(leafCol: String) = {
+ leafAttr.withName(leafCol).toStructField()
+ }
+
@transient private lazy val leafIndices: Map[LeafNode, Int] = {
leafIterator(rootNode).zipWithIndex.toMap
}
@@ -146,6 +159,10 @@ private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
val indices = trees.map(_.predictLeaf(features))
Vectors.dense(indices)
}
+
+ private[ml] def getLeafField(leafCol: String) = {
+ new AttributeGroup(leafCol, attrs = trees.map(_.leafAttr)).toStructField()
+ }
}
private[ml] object TreeEnsembleModel {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index c3894ebdd1785..752069daf8910 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.util
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.types._
@@ -106,6 +107,91 @@ private[spark] object SchemaUtils {
StructType(schema.fields :+ col)
}
+ /**
+ * Update the size of an ML Vector column. If the column does not exist, append it.
+ * @param schema input schema
+ * @param colName column name
+ * @param size number of features
+ * @return new schema
+ */
+ def updateAttributeGroupSize(
+ schema: StructType,
+ colName: String,
+ size: Int): StructType = {
+ require(size > 0, s"Vector size must be positive but got $size.")
+ val attrGroup = new AttributeGroup(colName, size)
+ val field = attrGroup.toStructField
+ updateField(schema, field, true)
+ }
+
+ /**
+ * Update the number of values of a column. If the column does not exist, append it.
+ * @param schema input schema
+ * @param colName column name
+ * @param numValues number of values.
+ * @return new schema
+ */
+ def updateNumValues(
+ schema: StructType,
+ colName: String,
+ numValues: Int): StructType = {
+ val attr = NominalAttribute.defaultAttr
+ .withName(colName)
+ .withNumValues(numValues)
+ val field = attr.toStructField
+ updateField(schema, field, true)
+ }
+
+ /**
+ * Update the numeric metadata of a column. If the column does not exist, append it.
+ * @param schema input schema
+ * @param colName column name
+ * @return new schema
+ */
+ def updateNumeric(
+ schema: StructType,
+ colName: String): StructType = {
+ val attr = NumericAttribute.defaultAttr
+ .withName(colName)
+ val field = attr.toStructField
+ updateField(schema, field, true)
+ }
+
+ /**
+ * Update the metadata of a column. If the column does not exist, append it.
+ * @param schema input schema
+ * @param field struct field
+ * @param overwriteMetadata whether to overwrite the metadata. If true, `field` replaces the
+ * existing column outright. If false, the metadata of `field` and of the
+ * existing column are merged, with the existing column's values winning
+ * on duplicate keys.
+ * @return new schema
+ */
+ def updateField(
+ schema: StructType,
+ field: StructField,
+ overwriteMetadata: Boolean = true): StructType = {
+ if (schema.fieldNames.contains(field.name)) {
+ val newFields = schema.fields.map { f =>
+ if (f.name == field.name) {
+ if (overwriteMetadata) {
+ field
+ } else {
+ val newMeta = new MetadataBuilder()
+ .withMetadata(field.metadata)
+ .withMetadata(f.metadata)
+ .build()
+ StructField(field.name, field.dataType, field.nullable, newMeta)
+ }
+ } else {
+ f
+ }
+ }
+ StructType(newFields)
+ } else {
+ appendColumn(schema, field)
+ }
+ }
+
/**
* Check whether the given column in the schema is one of the supporting vector type: Vector,
* Array[Float]. Array[Double]
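
To make the merge semantics of `updateField` concrete, a small sketch (toy column name; `SchemaUtils` is `private[spark]`, so this assumes code living under the `org.apache.spark` package):

import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.sql.types.StructType

val existing = StructType(Seq(
  NominalAttribute.defaultAttr.withName("pred").withNumValues(3).toStructField()))
val incoming = NominalAttribute.defaultAttr.withName("pred").withNumValues(5).toStructField()

// Default (overwriteMetadata = true): the incoming field replaces the existing
// one, so the result reports 5 values.
val replaced = SchemaUtils.updateField(existing, incoming)

// overwriteMetadata = false: the metadata maps are merged, and because the
// schema-side metadata is applied last in the MetadataBuilder chain, the
// existing "ml_attr" entry (3 values) wins on conflict.
val merged = SchemaUtils.updateField(existing, incoming, overwriteMetadata = false)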
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 3ebf8a83a892c..fd5af5b954150 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -249,6 +249,13 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
val newTree = dt.fit(newData)
+ newTree.setLeafCol("predictedLeafId")
+
+ val transformed = newTree.transform(newData)
+ checkNominalOnDF(transformed, "prediction", newTree.numClasses)
+ checkNominalOnDF(transformed, "predictedLeafId", newTree.numLeave)
+ checkVectorSizeOnDF(transformed, "rawPrediction", newTree.numClasses)
+ checkVectorSizeOnDF(transformed, "probability", newTree.numClasses)
MLTestingUtils.checkCopyAndUids(dt, newTree)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index fdca71f8911c6..ffd4b5e6d3055 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -473,6 +473,13 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
.setCheckpointInterval(5)
.setSeed(123)
val model = gbt.fit(df)
+ model.setLeafCol("predictedLeafId")
+
+ val transformed = model.transform(df)
+ checkNominalOnDF(transformed, "prediction", model.numClasses)
+ checkVectorSizeOnDF(transformed, "predictedLeafId", model.trees.length)
+ checkVectorSizeOnDF(transformed, "rawPrediction", model.numClasses)
+ checkVectorSizeOnDF(transformed, "probability", model.numClasses)
model.trees.foreach (i => {
assert(i.getMaxDepth === model.getMaxDepth)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index dc38f17d296f2..b23b4f4ac0d26 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -112,8 +112,13 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest {
assert(lsvc.getFeaturesCol === "features")
assert(lsvc.getPredictionCol === "prediction")
assert(lsvc.getRawPredictionCol === "rawPrediction")
+
val model = lsvc.setMaxIter(5).fit(smallBinaryDataset)
- model.transform(smallBinaryDataset)
+ val transformed = model.transform(smallBinaryDataset)
+ checkNominalOnDF(transformed, "prediction", model.numClasses)
+ checkVectorSizeOnDF(transformed, "rawPrediction", model.numClasses)
+
+ transformed
.select("label", "prediction", "rawPrediction")
.collect()
assert(model.getThreshold === 0.0)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 60c9cce6a4879..38bdfded9693e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -155,8 +155,14 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
assert(!lr.isDefined(lr.weightCol))
assert(lr.getFitIntercept)
assert(lr.getStandardization)
+
val model = lr.fit(smallBinaryDataset)
- model.transform(smallBinaryDataset)
+ val transformed = model.transform(smallBinaryDataset)
+ checkNominalOnDF(transformed, "prediction", model.numClasses)
+ checkVectorSizeOnDF(transformed, "rawPrediction", model.numClasses)
+ checkVectorSizeOnDF(transformed, "probability", model.numClasses)
+
+ transformed
.select("label", "probability", "prediction", "rawPrediction")
.collect()
assert(model.getThreshold === 0.5)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index adffd83ab1bd1..024a3870d8bca 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -81,6 +81,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
assert(ovaModel.numClasses === numClasses)
val transformedDataset = ovaModel.transform(dataset)
+ checkNominalOnDF(transformedDataset, "prediction", ovaModel.numClasses)
+ checkVectorSizeOnDF(transformedDataset, "rawPrediction", ovaModel.numClasses)
// check for label metadata in prediction col
val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 5958bfcf5ea6d..379d3bd128a5d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -242,6 +242,13 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
val model = rf.fit(df)
+ model.setLeafCol("predictedLeafId")
+
+ val transformed = model.transform(df)
+ checkNominalOnDF(transformed, "prediction", model.numClasses)
+ checkVectorSizeOnDF(transformed, "predictedLeafId", model.trees.length)
+ checkVectorSizeOnDF(transformed, "rawPrediction", model.numClasses)
+ checkVectorSizeOnDF(transformed, "probability", model.numClasses)
model.trees.foreach (i => {
assert(i.getMaxDepth === model.getMaxDepth)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 5288595d2e239..7ac7b64adfdab 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -174,6 +174,8 @@ class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest {
.setSeed(1)
.fit(df)
val predictionDf = model.transform(df)
+ checkNominalOnDF(predictionDf, "prediction", model.getK)
+
assert(predictionDf.select("prediction").distinct().count() == 3)
val predictionsMap = predictionDf.collect().map(row =>
row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 133536f763f4e..e570693c90e6e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -71,10 +71,15 @@ class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest {
assert(gm.getK === 2)
assert(gm.getFeaturesCol === "features")
assert(gm.getPredictionCol === "prediction")
+ assert(gm.getProbabilityCol === "probability")
assert(gm.getMaxIter === 100)
assert(gm.getTol === 0.01)
val model = gm.setMaxIter(1).fit(dataset)
+ val transformed = model.transform(dataset)
+ checkNominalOnDF(transformed, "prediction", model.weights.length)
+ checkVectorSizeOnDF(transformed, "probability", model.weights.length)
+
MLTestingUtils.checkCopyAndUids(gm, model)
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index e3c82fafca218..f6b1a8e9d6df3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -60,6 +60,9 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes
assert(kmeans.getDistanceMeasure === DistanceMeasure.EUCLIDEAN)
val model = kmeans.setMaxIter(1).fit(dataset)
+ val transformed = model.transform(dataset)
+ checkNominalOnDF(transformed, "prediction", model.clusterCenters.length)
+
MLTestingUtils.checkCopyAndUids(kmeans, model)
assert(model.hasSummary)
val copiedModel = model.copy(ParamMap.empty)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
index 079dabb3665be..19645b517d79c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature
import org.jtransforms.dct.DoubleDCT_1D
+import org.scalatest.exceptions.TestFailedException
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
@@ -74,5 +75,24 @@ class DCTSuite extends MLTest with DefaultReadWriteTest {
case Row(resultVec: Vector, wantedVec: Vector) =>
assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6)
}
+
+ val vectorSize = dataset
+ .select("vec")
+ .map { case Row(vec: Vector) => vec.size }
+ .head()
+
+ // Cannot infer the size of the output vector, since no metadata is provided
+ intercept[TestFailedException] {
+ val transformed = transformer.transform(dataset)
+ checkVectorSizeOnDF(transformed, "resultVec", vectorSize)
+ }
+
+ val dataset2 = new VectorSizeHint()
+ .setSize(vectorSize)
+ .setInputCol("vec")
+ .transform(dataset)
+
+ val transformed2 = transformer.transform(dataset2)
+ checkVectorSizeOnDF(transformed2, "resultVec", vectorSize)
}
}
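
The pattern above applies to any size-preserving UnaryTransformer: the output size can only be propagated when the input column already declares one. A sketch (toy names) of the size metadata VectorSizeHint attaches, which is what makes the second check pass:

import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.VectorSizeHint

// df is assumed to be any DataFrame with a Vector column named "vec".
val hinted = new VectorSizeHint()
  .setSize(3)
  .setInputCol("vec")
  .transform(df)
assert(AttributeGroup.fromStructField(hinted.schema("vec")).size == 3)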
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index 73b2b82daaf43..b4e144ea5ba5e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -68,6 +68,9 @@ class IDFSuite extends MLTest with DefaultReadWriteTest {
.setOutputCol("idfValue")
val idfModel = idfEst.fit(df)
+ val transformed = idfModel.transform(df)
+ checkVectorSizeOnDF(transformed, "idfValue", idfModel.idf.size)
+
MLTestingUtils.checkCopyAndUids(idfEst, idfModel)
testTransformer[(Vector, Vector)](df, idfModel, "idfValue", "expected") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala
index 8dd0f0cb91e37..5de938fa40c4d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala
@@ -44,6 +44,9 @@ class MaxAbsScalerSuite extends MLTest with DefaultReadWriteTest {
.setOutputCol("scaled")
val model = scaler.fit(df)
+ val transformed = model.transform(df)
+ checkVectorSizeOnDF(transformed, "scaled", model.maxAbs.size)
+
testTransformer[(Vector, Vector)](df, model, "expected", "scaled") {
case Row(expectedVec: Vector, actualVec: Vector) =>
assert(expectedVec === actualVec,
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index 2d965f2ca2c54..9b2b0c48f4f61 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -46,6 +46,9 @@ class MinMaxScalerSuite extends MLTest with DefaultReadWriteTest {
.setMax(5)
val model = scaler.fit(df)
+ val transformed = model.transform(df)
+ checkVectorSizeOnDF(transformed, "scaled", model.originalMin.size)
+
testTransformer[(Vector, Vector)](df, model, "expected", "scaled") {
case Row(vector1: Vector, vector2: Vector) =>
assert(vector1 === vector2, "Transformed vector is different with expected.")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index eff57f1223af4..d97df0050d74e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import org.scalatest.exceptions.TestFailedException
+
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
import org.apache.spark.ml.util.TestingUtils._
@@ -81,6 +83,22 @@ class NormalizerSuite extends MLTest with DefaultReadWriteTest {
assertTypeOfVector(normalized, features)
assertValues(normalized, expected)
}
+
+ val vectorSize = data.head.size
+
+ // Cannot infer the size of the output vector, since no metadata is provided
+ intercept[TestFailedException] {
+ val transformed = normalizer.transform(dataFrame)
+ checkVectorSizeOnDF(transformed, "normalized", vectorSize)
+ }
+
+ val dataFrame2 = new VectorSizeHint()
+ .setSize(vectorSize)
+ .setInputCol("features")
+ .transform(dataFrame)
+
+ val transformed2 = normalizer.transform(dataFrame2)
+ checkVectorSizeOnDF(transformed2, "normalized", vectorSize)
}
test("Normalization with setter") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
index 531b1d7c4d9f7..88c9867337e7c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -58,6 +58,8 @@ class PCASuite extends MLTest with DefaultReadWriteTest {
.setK(3)
val pcaModel = pca.fit(df)
+ val transformed = pcaModel.transform(df)
+ checkVectorSizeOnDF(transformed, "pca_features", pcaModel.getK)
MLTestingUtils.checkCopyAndUids(pca, pcaModel)
testTransformer[(Vector, Vector)](df, pcaModel, "pca_features", "expected") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index d28f1f4240ad0..11e1847ef235e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -61,6 +61,9 @@ class Word2VecSuite extends MLTest with DefaultReadWriteTest {
.setSeed(42L)
val model = w2v.fit(docDF)
+ val transformed = model.transform(docDF)
+ checkVectorSizeOnDF(transformed, "result", model.getVectorSize)
+
MLTestingUtils.checkCopyAndUids(w2v, model)
// These expectations are just magic values, characterizing the current
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
index 978a3cbe54c1e..3e1e2ad6a7f55 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -24,10 +24,11 @@ import org.scalatest.Suite
import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext, TestUtils}
import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK
import org.apache.spark.ml.{Model, PredictionModel, Transformer}
+import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row}
import org.apache.spark.sql.execution.streaming.MemoryStream
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.test.TestSparkSession
@@ -64,6 +65,38 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
}
}
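+ /**
+ * Check that the schema of `vecColName` declares vector size `vecSize` and
+ * that every vector in the column actually has that size.
+ */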
+ private[ml] def checkVectorSizeOnDF(
+ dataframe: DataFrame,
+ vecColName: String,
+ vecSize: Int): Unit = {
+ import dataframe.sparkSession.implicits._
+ val group = AttributeGroup.fromStructField(dataframe.schema(vecColName))
+ assert(group.size === vecSize,
+ s"the vector size obtained from schema should be $vecSize, but got ${group.size}")
+ val sizeUDF = udf { vector: Vector => vector.size }
+ assert(dataframe.select(sizeUDF(col(vecColName)))
+ .as[Int]
+ .collect()
+ .forall(_ === vecSize))
+ }
+
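+ /**
+ * Check that the schema of `colName` declares a nominal (or binary) attribute
+ * with `numValues` values, and that every value in the column is an integer
+ * in [0, numValues).
+ */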
+ private[ml] def checkNominalOnDF(
+ dataframe: DataFrame,
+ colName: String,
+ numValues: Int): Unit = {
+ import dataframe.sparkSession.implicits._
+ val n = Attribute.fromStructField(dataframe.schema(colName)) match {
+ case _: BinaryAttribute => Some(2)
+ case nomAttr: NominalAttribute => nomAttr.getNumValues
+ case _ => None // numeric or unresolved attribute carries no nominal metadata
+ }
+ assert(n.isDefined && n.get === numValues,
+ s"the number of values obtained from schema should be $numValues, but got $n")
+ assert(dataframe.select(colName)
+ .as[Double]
+ .collect()
+ .forall(v => v === v.toInt && v >= 0 && v < numValues))
+ }
+
private[util] def testTransformerOnStreamData[A : Encoder](
dataframe: DataFrame,
transformer: Transformer,