diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 9eac8ed22a3f6..98dd692cbe55d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -88,8 +88,9 @@ private[ml] trait PredictorParams extends Params
    * and put it in an RDD with strong types.
    * Validate the output instances with the given function.
    */
-  protected def extractInstances(dataset: Dataset[_],
-      validateInstance: Instance => Unit): RDD[Instance] = {
+  protected def extractInstances(
+      dataset: Dataset[_],
+      validateInstance: Instance => Unit): RDD[Instance] = {
     extractInstances(dataset).map { instance =>
       validateInstance(instance)
       instance
@@ -222,7 +223,11 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
   protected def featuresDataType: DataType = new VectorUDT

   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema, fitting = false, featuresDataType)
+    var outputSchema = validateAndTransformSchema(schema, fitting = false, featuresDataType)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumeric(outputSchema, $(predictionCol))
+    }
+    outputSchema
   }

   /**
@@ -244,10 +249,12 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
   }

   protected def transformImpl(dataset: Dataset[_]): DataFrame = {
-    val predictUDF = udf { (features: Any) =>
+    val outputSchema = transformSchema(dataset.schema, logging = true)
+    val predictUDF = udf { features: Any =>
       predict(features.asInstanceOf[FeaturesType])
     }
-    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))),
+      outputSchema($(predictionCol)).metadata)
   }

   /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index a3a2b55adc25d..7874fc29db6c8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -117,9 +117,10 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
   }

   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val transformUDF = udf(this.createTransformFunc, outputDataType)
-    dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
+    dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))),
+      outputSchema($(outputCol)).metadata)
   }

   override def copy(extra: ParamMap): T = defaultCopy(extra)
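Note: throughout this patch the three-argument `Dataset.withColumn(colName, col, metadata)` overload replaces the two-argument form, so the metadata computed in `transformSchema` is attached in the same step that adds the column. A minimal sketch of the mechanism, assuming `df` is any DataFrame with a numeric column `x` (the column names and UDF are illustrative, not part of the patch):

```scala
import org.apache.spark.ml.attribute.NumericAttribute
import org.apache.spark.sql.functions.{col, udf}

// Build the metadata once from an ML attribute ("prediction" is illustrative).
val predMeta = NumericAttribute.defaultAttr.withName("prediction")
  .toStructField().metadata

val predictUDF = udf { x: Double => x * 2.0 } // stand-in for a model's predict

// The third argument attaches the metadata to the new column's StructField.
val out = df.withColumn("prediction", predictUDF(col("x")), predMeta)
assert(out.schema("prediction").metadata == predMeta)
```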
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 3bff236677e6b..be552e6be0b50 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -48,8 +48,9 @@ private[spark] trait ClassifierParams
    * and put it in an RDD with strong types.
    * Validates the label on the classifier is a valid integer in the range [0, numClasses).
    */
-  protected def extractInstances(dataset: Dataset[_],
-      numClasses: Int): RDD[Instance] = {
+  protected def extractInstances(
+      dataset: Dataset[_],
+      numClasses: Int): RDD[Instance] = {
     val validateInstance = (instance: Instance) => {
       val label = instance.label
       require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
@@ -183,6 +184,19 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
   /** Number of classes (values which the label can take). */
   def numClasses: Int

+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumValues(outputSchema,
+        $(predictionCol), numClasses)
+    }
+    if ($(rawPredictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(rawPredictionCol), numClasses)
+    }
+    outputSchema
+  }
+
   /**
    * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
    * parameters:
@@ -193,29 +207,31 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
    * @return transformed dataset
    */
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)

     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
     var outputData = dataset
     var numColsOutput = 0
     if (getRawPredictionCol != "") {
-      val predictRawUDF = udf { (features: Any) =>
+      val predictRawUDF = udf { features: Any =>
         predictRaw(features.asInstanceOf[FeaturesType])
       }
-      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
+      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)),
+        outputSchema($(rawPredictionCol)).metadata)
       numColsOutput += 1
     }
     if (getPredictionCol != "") {
-      val predUDF = if (getRawPredictionCol != "") {
+      val predCol = if (getRawPredictionCol != "") {
         udf(raw2prediction _).apply(col(getRawPredictionCol))
       } else {
-        val predictUDF = udf { (features: Any) =>
+        val predictUDF = udf { features: Any =>
           predict(features.asInstanceOf[FeaturesType])
         }
         predictUDF(col(getFeaturesCol))
       }
-      outputData = outputData.withColumn(getPredictionCol, predUDF)
+      outputData = outputData.withColumn(getPredictionCol, predCol,
+        outputSchema($(predictionCol)).metadata)
       numColsOutput += 1
     }
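The payoff of the `ClassificationModel.transformSchema` override above is that `numClasses` becomes recoverable from the transformed schema alone, with no pass over the data and no reference back to the model. A hedged sketch, assuming `model` is any fitted `ClassificationModel` and the default column names:

```scala
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}

val transformed = model.transform(dataset)
Attribute.fromStructField(transformed.schema("prediction")) match {
  case nominal: NominalAttribute =>
    // Written by transformSchema via SchemaUtils.updateNumValues.
    println(s"numClasses from metadata: ${nominal.getNumValues.getOrElse(-1)}")
  case _ =>
    println("prediction column carries no nominal metadata")
}
```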
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index e02109375373e..d10f684f0dcf1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType

 /**
  * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -202,13 +203,23 @@ class DecisionTreeClassificationModel private[ml] (
     rootNode.predictImpl(features).prediction
   }

+  @Since("3.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(leafCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+    }
+    outputSchema
+  }
+
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val outputData = super.transform(dataset)
     if ($(leafCol).nonEmpty) {
       val leafUDF = udf { features: Vector => predictLeaf(features) }
-      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
+      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+        outputSchema($(leafCol)).metadata)
     } else {
       outputData
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index e1f5338f34899..6e54e0f15b85c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType

 /**
  * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting)
@@ -291,13 +292,23 @@ class GBTClassificationModel private[ml](
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights

+  @Since("3.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(leafCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+    }
+    outputSchema
+  }
+
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val outputData = super.transform(dataset)
     if ($(leafCol).nonEmpty) {
       val leafUDF = udf { features: Vector => predictLeaf(features) }
-      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
+      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+        outputSchema($(leafCol)).metadata)
     } else {
       outputData
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 51a624795cdd4..bbf8e8fc90ad5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -161,13 +161,23 @@ final class OneVsRestModel private[ml] (

   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
+    var outputSchema = validateAndTransformSchema(schema, fitting = false,
+      getClassifier.featuresDataType)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumValues(outputSchema,
+        $(predictionCol), numClasses)
+    }
+    if ($(rawPredictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(rawPredictionCol), numClasses)
+    }
+    outputSchema
   }

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     // Check schema
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)

     if (getPredictionCol.isEmpty && getRawPredictionCol.isEmpty) {
       logWarning(s"$uid: OneVsRestModel.transform() does nothing" +
@@ -230,6 +240,7 @@ final class OneVsRestModel private[ml] (

       predictionColNames :+= getRawPredictionCol
       predictionColumns :+= rawPredictionUDF(col(accColName))
+        .as($(rawPredictionCol), outputSchema($(rawPredictionCol)).metadata)
     }

     if (getPredictionCol.nonEmpty) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 2171ac335e7b8..2e4d69330d132 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -90,6 +90,15 @@ abstract class ProbabilisticClassificationModel[
     set(thresholds, value).asInstanceOf[M]
   }

+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(probabilityCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(probabilityCol), numClasses)
+    }
+    outputSchema
+  }
+
   /**
    * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
    * parameters:
@@ -101,7 +110,7 @@ abstract class ProbabilisticClassificationModel[
    * @return transformed dataset
    */
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
         ".transform() called with non-matching numClasses and thresholds.length." +
@@ -113,36 +122,39 @@ abstract class ProbabilisticClassificationModel[
     var outputData = dataset
     var numColsOutput = 0
     if ($(rawPredictionCol).nonEmpty) {
-      val predictRawUDF = udf { (features: Any) =>
+      val predictRawUDF = udf { features: Any =>
         predictRaw(features.asInstanceOf[FeaturesType])
       }
-      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)))
+      outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)),
+        outputSchema($(rawPredictionCol)).metadata)
       numColsOutput += 1
     }
     if ($(probabilityCol).nonEmpty) {
-      val probUDF = if ($(rawPredictionCol).nonEmpty) {
+      val probCol = if ($(rawPredictionCol).nonEmpty) {
         udf(raw2probability _).apply(col($(rawPredictionCol)))
       } else {
-        val probabilityUDF = udf { (features: Any) =>
+        val probabilityUDF = udf { features: Any =>
           predictProbability(features.asInstanceOf[FeaturesType])
         }
         probabilityUDF(col($(featuresCol)))
       }
-      outputData = outputData.withColumn($(probabilityCol), probUDF)
+      outputData = outputData.withColumn($(probabilityCol), probCol,
+        outputSchema($(probabilityCol)).metadata)
       numColsOutput += 1
     }
     if ($(predictionCol).nonEmpty) {
-      val predUDF = if ($(rawPredictionCol).nonEmpty) {
+      val predCol = if ($(rawPredictionCol).nonEmpty) {
         udf(raw2prediction _).apply(col($(rawPredictionCol)))
       } else if ($(probabilityCol).nonEmpty) {
         udf(probability2prediction _).apply(col($(probabilityCol)))
       } else {
-        val predictUDF = udf { (features: Any) =>
+        val predictUDF = udf { features: Any =>
           predict(features.asInstanceOf[FeaturesType])
         }
         predictUDF(col($(featuresCol)))
       }
-      outputData = outputData.withColumn($(predictionCol), predUDF)
+      outputData = outputData.withColumn($(predictionCol), predCol,
+        outputSchema($(predictionCol)).metadata)
       numColsOutput += 1
     }
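`OneVsRestModel` (and the tree regressors later in this patch) assemble their output columns and attach them in one pass rather than via `withColumn`, so metadata travels through the `Column.as(alias, metadata)` overload instead. A minimal sketch of that second mechanism; `rawPredictionUDF` and `rawPredMeta` are illustrative stand-ins, not names from this patch:

```scala
import org.apache.spark.sql.functions.col

// Alias the computed column and attach the metadata in one step; the
// resulting StructField in the selected DataFrame carries rawPredMeta.
val rawPredCol = rawPredictionUDF(col("accumulated")).as("rawPrediction", rawPredMeta)
val out = dataset.select(col("*"), rawPredCol)
```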
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index bc28d783ed962..f88fc2a6a0914 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestMo
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType

 /**
  * Random Forest learning algorithm for
@@ -210,13 +211,23 @@ class RandomForestClassificationModel private[ml] (
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights

+  @Since("3.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(leafCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+    }
+    outputSchema
+  }
+
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val outputData = super.transform(dataset)
     if ($(leafCol).nonEmpty) {
       val leafUDF = udf { features: Vector => predictLeaf(features) }
-      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))))
+      outputData.withColumn($(leafCol), leafUDF(col($(featuresCol))),
+        outputSchema($(leafCol)).metadata)
     } else {
       outputData
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 5f2316fa7ce18..79760d69489c6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -110,15 +110,21 @@ class BisectingKMeansModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val predictUDF = udf((vector: Vector) => predict(vector))
     dataset.withColumn($(predictionCol),
-      predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
+      predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)),
+      outputSchema($(predictionCol)).metadata)
   }

   @Since("2.0.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    var outputSchema = validateAndTransformSchema(schema)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumValues(outputSchema,
+        $(predictionCol), parentModel.k)
+    }
+    outputSchema
   }

   @Since("3.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 916f326ab5615..9d00d6a8bcbe4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -112,7 +112,7 @@ class GaussianMixtureModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol))

     var outputData = dataset
@@ -120,17 +120,20 @@ class GaussianMixtureModel private[ml] (

     if ($(probabilityCol).nonEmpty) {
       val probUDF = udf((vector: Vector) => predictProbability(vector))
-      outputData = outputData.withColumn($(probabilityCol), probUDF(vectorCol))
+      outputData = outputData.withColumn($(probabilityCol), probUDF(vectorCol),
+        outputSchema($(probabilityCol)).metadata)
       numColsOutput += 1
     }

     if ($(predictionCol).nonEmpty) {
       if ($(probabilityCol).nonEmpty) {
         val predUDF = udf((vector: Vector) => vector.argmax)
-        outputData = outputData.withColumn($(predictionCol), predUDF(col($(probabilityCol))))
+        outputData = outputData.withColumn($(predictionCol), predUDF(col($(probabilityCol))),
+          outputSchema($(predictionCol)).metadata)
       } else {
         val predUDF = udf((vector: Vector) => predict(vector))
-        outputData = outputData.withColumn($(predictionCol), predUDF(vectorCol))
+        outputData = outputData.withColumn($(predictionCol), predUDF(vectorCol),
+          outputSchema($(predictionCol)).metadata)
       }
       numColsOutput += 1
     }
@@ -144,7 +147,16 @@ class GaussianMixtureModel private[ml] (

   @Since("2.0.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    var outputSchema = validateAndTransformSchema(schema)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumValues(outputSchema,
+        $(predictionCol), weights.length)
+    }
+    if ($(probabilityCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(probabilityCol), weights.length)
+    }
+    outputSchema
   }

   @Since("3.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index caeded400f9aa..5cbba6c77f9fb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -127,17 +127,23 @@ class KMeansModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val predictUDF = udf((vector: Vector) => predict(vector))
     dataset.withColumn($(predictionCol),
-      predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
+      predictUDF(DatasetUtils.columnToVector(dataset, getFeaturesCol)),
+      outputSchema($(predictionCol)).metadata)
   }

   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    var outputSchema = validateAndTransformSchema(schema)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumValues(outputSchema,
+        $(predictionCol), parentModel.k)
+    }
+    outputSchema
   }

   @Since("3.0.0")
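For the clustering models, recording `k` as a `NominalAttribute` on the prediction column makes the number of categories visible to metadata-aware stages downstream. A hedged sketch: `OneHotEncoder`, for example, consults column metadata before falling back to scanning the data. `clustered` stands for the output of `KMeansModel.transform` above; the column names are illustrative:

```scala
import org.apache.spark.ml.feature.OneHotEncoder

// "prediction" now carries NominalAttribute metadata with k values,
// so the encoder can size its output from the schema.
val encoder = new OneHotEncoder()
  .setInputCols(Array("prediction"))
  .setOutputCols(Array("clusterVec"))
val encoded = encoder.fit(clustered).transform(clustered)
```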
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 9b0005b3747dc..e30be8c20dcc3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -459,12 +459,13 @@ abstract class LDAModel private[ml] (
    */
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val func = getTopicDistributionMethod
     val transformer = udf(func)
     dataset.withColumn($(topicDistributionCol),
-      transformer(DatasetUtils.columnToVector(dataset, getFeaturesCol)))
+      transformer(DatasetUtils.columnToVector(dataset, getFeaturesCol)),
+      outputSchema($(topicDistributionCol)).metadata)
   }

   /**
@@ -504,7 +505,12 @@ abstract class LDAModel private[ml] (

   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    var outputSchema = validateAndTransformSchema(schema)
+    if ($(topicDistributionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(topicDistributionCol), oldLocalModel.k)
+    }
+    outputSchema
   }

   /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 07a4f91443bc5..381ab6eb7355b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuilder

 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.attribute.BinaryAttribute
+import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -193,7 +193,12 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
       case DoubleType =>
         BinaryAttribute.defaultAttr.withName(outputColName).toStructField()
       case _: VectorUDT =>
-        StructField(outputColName, new VectorUDT)
+        val size = AttributeGroup.fromStructField(schema(inputColName)).size
+        if (size < 0) {
+          StructField(outputColName, new VectorUDT)
+        } else {
+          new AttributeGroup(outputColName, numAttributes = size).toStructField()
+        }
       case _ =>
         throw new IllegalArgumentException(s"Data type $inputType is not supported.")
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 9103e4feac454..76f4f944f11d5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -289,8 +289,7 @@ final class ChiSqSelectorModel private[ml] (
   override def transformSchema(schema: StructType): StructType = {
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     val newField = prepOutputField(schema)
-    val outputFields = schema.fields :+ newField
-    StructType(outputFields)
+    SchemaUtils.appendColumn(schema, newField)
   }

   /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index c58d44d492342..7ba6f640b1e49 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -300,7 +300,7 @@ class CountVectorizerModel(

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     if (broadcastDict.isEmpty) {
       val dict = vocabulary.zipWithIndex.toMap
       broadcastDict = Some(dataset.sparkSession.sparkContext.broadcast(dict))
@@ -326,14 +326,19 @@ class CountVectorizerModel(
       Vectors.sparse(dictBr.value.size, effectiveCounts)
     }

-    val attrs = vocabulary.map(_ => new NumericAttribute).asInstanceOf[Array[Attribute]]
-    val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
-    dataset.withColumn($(outputCol), vectorizer(col($(inputCol))), metadata)
+    dataset.withColumn($(outputCol), vectorizer(col($(inputCol))),
+      outputSchema($(outputCol)).metadata)
   }

   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    var outputSchema = validateAndTransformSchema(schema)
+    if ($(outputCol).nonEmpty) {
+      val attrs: Array[Attribute] = vocabulary.map(_ => new NumericAttribute)
+      val field = new AttributeGroup($(outputCol), attrs).toStructField()
+      outputSchema = SchemaUtils.updateField(outputSchema, field)
+    }
+    outputSchema
   }

   @Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
index e2167f01281da..d057e5a62e507 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
@@ -21,10 +21,11 @@ import org.jtransforms.dct._

 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param.BooleanParam
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._

 /**
  * A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero
@@ -75,6 +76,18 @@ class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: String)

   override protected def outputDataType: DataType = new VectorUDT

+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(inputCol).nonEmpty && $(outputCol).nonEmpty) {
+      val size = AttributeGroup.fromStructField(schema($(inputCol))).size
+      if (size >= 0) {
+        outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+          $(outputCol), size)
+      }
+    }
+    outputSchema
+  }
+
   @Since("3.0.0")
   override def toString: String = {
     s"DCT: uid=$uid, inverse=$inverse"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index 227c13d60fd8f..3b328f2fd8cee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -21,10 +21,10 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.Param
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature.{ElementwiseProduct => OldElementwiseProduct}
 import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._

 /**
  * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a
@@ -82,6 +82,15 @@ class ElementwiseProduct @Since("1.4.0") (@Since("1.4.0") override val uid: Stri

   override protected def outputDataType: DataType = new VectorUDT()

+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(outputCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(outputCol), $(scalingVec).size)
+    }
+    outputSchema
+  }
+
   @Since("3.0.0")
   override def toString: String = {
     s"ElementwiseProduct: uid=$uid" +
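The same idea applies to the vector-valued feature columns: once `transformSchema` records an `AttributeGroup` with a size, the output dimension is visible from the schema before any job runs. A hedged sketch, assuming `fitted` is an illustrative fitted model such as a `CountVectorizerModel` and the default column names:

```scala
import org.apache.spark.ml.attribute.AttributeGroup

val outDF = fitted.transform(dataset)
val group = AttributeGroup.fromStructField(outDF.schema("features"))
// The size recorded by transformSchema; -1 would mean "unknown".
println(s"output vector size: ${group.size}")
```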
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 5f4103abcf50f..e6f124ef7d666 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -131,7 +131,7 @@ class IDFModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val func = { vector: Vector =>
       vector match {
@@ -149,12 +149,18 @@ class IDFModel private[ml] (
     }

     val transformer = udf(func)
-    dataset.withColumn($(outputCol), transformer(col($(inputCol))))
+    dataset.withColumn($(outputCol), transformer(col($(inputCol))),
+      outputSchema($(outputCol)).metadata)
   }

   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    var outputSchema = validateAndTransformSchema(schema)
+    if ($(outputCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(outputCol), idf.size)
+    }
+    outputSchema
   }

   @Since("1.4.1")
@@ -180,7 +186,7 @@ class IDFModel private[ml] (

   @Since("3.0.0")
   override def toString: String = {
-    s"IDFModel: uid=$uid, numDocs=$numDocs"
+    s"IDFModel: uid=$uid, numDocs=$numDocs, numFeatures=${idf.size}"
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 6bab70e502ed7..2d48a5f9f4915 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -117,19 +117,25 @@ class MaxAbsScalerModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val scale = maxAbs.toArray.map { v => if (v == 0) 1.0 else 1 / v }
     val func = StandardScalerModel.getTransformFunc(
       Array.empty, scale, false, true)
     val transformer = udf(func)
-    dataset.withColumn($(outputCol), transformer(col($(inputCol))))
+    dataset.withColumn($(outputCol), transformer(col($(inputCol))),
+      outputSchema($(outputCol)).metadata)
   }

   @Since("2.0.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    var outputSchema = validateAndTransformSchema(schema)
+    if ($(outputCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(outputCol), maxAbs.size)
+    }
+    outputSchema
   }

   @Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index e381a0435e9eb..c84892c974b90 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -174,7 +174,7 @@ class MinMaxScalerModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val numFeatures = originalMax.size
     val scale = $(max) - $(min)

@@ -210,12 +210,18 @@ class MinMaxScalerModel private[ml] (
       Vectors.dense(values).compressed
     }

-    dataset.withColumn($(outputCol), transformer(col($(inputCol))))
+    dataset.withColumn($(outputCol), transformer(col($(inputCol))),
+      outputSchema($(outputCol)).metadata)
   }

   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    var outputSchema = validateAndTransformSchema(schema)
+    if ($(outputCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(outputCol), originalMin.size)
+    }
+    outputSchema
   }

   @Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index d129c2b2c2dc1..4c7583b8381dc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -19,12 +19,13 @@ package org.apache.spark.ml.feature

 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param.{DoubleParam, ParamValidators}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._

 /**
  * Normalize a vector to have unit norm using the given p-norm.
@@ -66,6 +67,19 @@ class Normalizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)

   override protected def outputDataType: DataType = new VectorUDT()

+  @Since("3.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(inputCol).nonEmpty && $(outputCol).nonEmpty) {
+      val size = AttributeGroup.fromStructField(schema($(inputCol))).size
+      if (size >= 0) {
+        outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+          $(outputCol), size)
+      }
+    }
+    outputSchema
+  }
+
   @Since("3.0.0")
   override def toString: String = {
     s"Normalizer: uid=$uid, p=${$(p)}"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 27a3854d39b47..9eeb4c8ca2506 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -29,7 +29,7 @@ import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{DenseMatrix => OldDenseMatrix, Vectors => OldVectors}
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.VersionUtils.majorVersion

 /**
@@ -52,10 +52,8 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC
     SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
     require(!schema.fieldNames.contains($(outputCol)),
       s"Output column ${$(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    SchemaUtils.updateAttributeGroupSize(schema, $(outputCol), $(k))
   }
-
 }

 /**
@@ -145,16 +143,22 @@ class PCAModel private[ml] (
    */
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val transposed = pc.transpose
     val transformer = udf { vector: Vector => transposed.multiply(vector) }
-    dataset.withColumn($(outputCol), transformer(col($(inputCol))))
+    dataset.withColumn($(outputCol), transformer(col($(inputCol))),
+      outputSchema($(outputCol)).metadata)
   }

   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    var outputSchema = validateAndTransformSchema(schema)
+    if ($(outputCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(outputCol), $(k))
+    }
+    outputSchema
   }

   @Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
index 1b9b8082931a5..f02ef0dfc70f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
@@ -227,7 +227,7 @@ class RobustScalerModel private[ml] (
   def setOutputCol(value: String): this.type = set(outputCol, value)

   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val shift = if ($(withCentering)) median.toArray else Array.emptyDoubleArray
     val scale = if ($(withScaling)) {
@@ -238,11 +238,17 @@ class RobustScalerModel private[ml] (
       shift, scale, $(withCentering), $(withScaling))
     val transformer = udf(func)
-    dataset.withColumn($(outputCol), transformer(col($(inputCol))))
+    dataset.withColumn($(outputCol), transformer(col($(inputCol))),
+      outputSchema($(outputCol)).metadata)
   }

   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    var outputSchema = validateAndTransformSchema(schema)
+    if ($(outputCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(outputCol), median.size)
+    }
+    outputSchema
   }

   override def copy(extra: ParamMap): RobustScalerModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 8d4d4197e7db2..c6b1b29a6d9bc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -157,7 +157,7 @@ class StandardScalerModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val shift = if ($(withMean)) mean.toArray else Array.emptyDoubleArray
     val scale = if ($(withStd)) {
       std.toArray.map { v => if (v == 0) 0.0 else 1.0 / v }
@@ -166,12 +166,18 @@ class StandardScalerModel private[ml] (
     val func = getTransformFunc(shift, scale, $(withMean), $(withStd))
     val transformer = udf(func)
-    dataset.withColumn($(outputCol), transformer(col($(inputCol))))
+    dataset.withColumn($(outputCol), transformer(col($(inputCol))),
+      outputSchema($(outputCol)).metadata)
   }

   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    var outputSchema = validateAndTransformSchema(schema)
+    if ($(outputCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(outputCol), mean.size)
+    }
+    outputSchema
   }

   @Since("1.4.1")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index b84b8af4e8a94..45bb4b8e6e65d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -153,8 +153,7 @@ final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
     }
     val numFeaturesSelected = $(indices).length + $(names).length
     val outputAttr = new AttributeGroup($(outputCol), numFeaturesSelected)
-    val outputFields = schema.fields :+ outputAttr.toStructField()
-    StructType(outputFields)
+    SchemaUtils.appendColumn(schema, outputAttr.toStructField)
   }

   @Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 81dde0315c190..bbfcbfbe038ef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -288,15 +288,16 @@ class Word2VecModel private[ml] (
    */
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val vectors = wordVectors.getVectors
       .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
       .map(identity) // mapValues doesn't return a serializable map (SI-7005)
     val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors)
     val d = $(vectorSize)
+    val emptyVec = Vectors.sparse(d, Array.emptyIntArray, Array.emptyDoubleArray)
     val word2Vec = udf { sentence: Seq[String] =>
       if (sentence.isEmpty) {
-        Vectors.sparse(d, Array.empty[Int], Array.empty[Double])
+        emptyVec
       } else {
         val sum = Vectors.zeros(d)
         sentence.foreach { word =>
@@ -308,12 +309,18 @@ class Word2VecModel private[ml] (
         sum
       }
     }
-    dataset.withColumn($(outputCol), word2Vec(col($(inputCol))))
+    dataset.withColumn($(outputCol), word2Vec(col($(inputCol))),
+      outputSchema($(outputCol)).metadata)
   }

   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    var outputSchema = validateAndTransformSchema(schema)
+    if ($(outputCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(outputCol), $(vectorSize))
+    }
+    outputSchema
   }

   @Since("1.4.1")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index faf77252cb738..7079ac89dcc93 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -349,7 +349,7 @@ class AFTSurvivalRegressionModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     var predictionColNames = Seq.empty[String]
     var predictionColumns = Seq.empty[Column]

@@ -358,12 +358,14 @@ class AFTSurvivalRegressionModel private[ml] (
       val predictUDF = udf { features: Vector => predict(features) }
       predictionColNames :+= $(predictionCol)
       predictionColumns :+= predictUDF(col($(featuresCol)))
+        .as($(predictionCol), outputSchema($(predictionCol)).metadata)
     }

     if (hasQuantilesCol) {
       val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}
       predictionColNames :+= $(quantilesCol)
       predictionColumns :+= predictQuantilesUDF(col($(featuresCol)))
+        .as($(quantilesCol), outputSchema($(quantilesCol)).metadata)
     }

     if (predictionColNames.nonEmpty) {
@@ -377,7 +379,15 @@ class AFTSurvivalRegressionModel private[ml] (

   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema, fitting = false)
+    var outputSchema = validateAndTransformSchema(schema, fitting = false)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumeric(outputSchema, $(predictionCol))
+    }
+    if (isDefined(quantilesCol) && $(quantilesCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
+        $(quantilesCol), $(quantileProbabilities).length)
+    }
+    outputSchema
   }

   @Since("1.6.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 4a97997a1deb8..447e6f90a44e7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -36,7 +36,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeMo
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Column, DataFrame, Dataset}
 import org.apache.spark.sql.functions._
-
+import org.apache.spark.sql.types.StructType

 /**
  * Decision tree
@@ -202,9 +202,21 @@ class DecisionTreeRegressionModel private[ml] (
     rootNode.predictImpl(features).impurityStats.calculate()
   }

+  @Since("3.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if (isDefined(varianceCol) && $(varianceCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumeric(outputSchema, $(varianceCol))
+    }
+    if ($(leafCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+    }
+    outputSchema
+  }
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     var predictionColNames = Seq.empty[String]
     var predictionColumns = Seq.empty[Column]

@@ -213,18 +225,21 @@ class DecisionTreeRegressionModel private[ml] (
       val predictUDF = udf { features: Vector => predict(features) }
       predictionColNames :+= $(predictionCol)
       predictionColumns :+= predictUDF(col($(featuresCol)))
+        .as($(predictionCol), outputSchema($(predictionCol)).metadata)
     }

     if (isDefined(varianceCol) && $(varianceCol).nonEmpty) {
       val predictVarianceUDF = udf { features: Vector => predictVariance(features) }
       predictionColNames :+= $(varianceCol)
       predictionColumns :+= predictVarianceUDF(col($(featuresCol)))
+        .as($(varianceCol), outputSchema($(varianceCol)).metadata)
     }

     if ($(leafCol).nonEmpty) {
       val leafUDF = udf { features: Vector => predictLeaf(features) }
       predictionColNames :+= $(leafCol)
       predictionColumns :+= leafUDF(col($(featuresCol)))
+        .as($(leafCol), outputSchema($(leafCol)).metadata)
     }

     if (predictionColNames.nonEmpty) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 700f7a2075a91..eb0f2362af570 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -35,6 +35,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.sql.{Column, DataFrame, Dataset}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType

 /**
  * Gradient-Boosted Trees (GBTs)
@@ -255,8 +256,17 @@ class GBTRegressionModel private[ml](
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights

+  @Since("3.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(leafCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+    }
+    outputSchema
+  }
+
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     var predictionColNames = Seq.empty[String]
     var predictionColumns = Seq.empty[Column]

@@ -267,12 +277,14 @@ class GBTRegressionModel private[ml](
       val predictUDF = udf { features: Vector => bcastModel.value.predict(features) }
       predictionColNames :+= $(predictionCol)
       predictionColumns :+= predictUDF(col($(featuresCol)))
+        .as($(predictionCol), outputSchema($(predictionCol)).metadata)
     }

     if ($(leafCol).nonEmpty) {
       val leafUDF = udf { features: Vector => bcastModel.value.predictLeaf(features) }
       predictionColNames :+= $(leafCol)
       predictionColumns :+= leafUDF(col($(featuresCol)))
+        .as($(leafCol), outputSchema($(leafCol)).metadata)
     }

     if (predictionColNames.nonEmpty) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 53b29102f01be..f24eeff682110 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -27,7 +27,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.PredictorParams
-import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.feature.{Instance, OffsetInstance}
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
 import org.apache.spark.ml.optim._
@@ -213,7 +213,9 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
     }

     if (hasLinkPredictionCol) {
-      SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType)
+      val attr = NumericAttribute.defaultAttr
+        .withName($(linkPredictionCol))
+      SchemaUtils.appendColumn(newSchema, attr.toStructField())
     } else {
       newSchema
     }
@@ -1043,6 +1045,8 @@ class GeneralizedLinearRegressionModel private[ml] (
   }

   override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
+    val outputSchema = transformSchema(dataset.schema, logging = true)
+
     val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)

     var outputData = dataset
     var numColsOutput = 0
@@ -1050,17 +1054,20 @@ class GeneralizedLinearRegressionModel private[ml] (
     if (hasLinkPredictionCol) {
       val predLinkUDF = udf((features: Vector, offset: Double) => predictLink(features, offset))
       outputData = outputData
-        .withColumn($(linkPredictionCol), predLinkUDF(col($(featuresCol)), offset))
+        .withColumn($(linkPredictionCol), predLinkUDF(col($(featuresCol)), offset),
+          outputSchema($(linkPredictionCol)).metadata)
       numColsOutput += 1
     }

     if ($(predictionCol).nonEmpty) {
       if (hasLinkPredictionCol) {
         val predUDF = udf((eta: Double) => familyAndLink.fitted(eta))
-        outputData = outputData.withColumn($(predictionCol), predUDF(col($(linkPredictionCol))))
+        outputData = outputData.withColumn($(predictionCol), predUDF(col($(linkPredictionCol))),
+          outputSchema($(predictionCol)).metadata)
       } else {
         val predUDF = udf((features: Vector, offset: Double) => predict(features, offset))
-        outputData = outputData.withColumn($(predictionCol), predUDF(col($(featuresCol)), offset))
+        outputData = outputData.withColumn($(predictionCol), predUDF(col($(featuresCol)), offset),
+          outputSchema($(predictionCol)).metadata)
       }
       numColsOutput += 1
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 47f9e4bfb8333..d12e5daabebbf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -240,7 +240,7 @@ class IsotonicRegressionModel private[ml] (

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val predict = dataset.schema($(featuresCol)).dataType match {
       case DoubleType =>
         udf { feature: Double => oldModel.predict(feature) }
@@ -248,12 +248,17 @@ class IsotonicRegressionModel private[ml] (
         val idx = $(featureIndex)
         udf { features: Vector => oldModel.predict(features(idx)) }
     }
-    dataset.withColumn($(predictionCol), predict(col($(featuresCol))))
+    dataset.withColumn($(predictionCol), predict(col($(featuresCol))),
+      outputSchema($(predictionCol)).metadata)
   }

   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema, fitting = false)
+    var outputSchema = validateAndTransformSchema(schema, fitting = false)
+    if ($(predictionCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateNumeric(outputSchema, $(predictionCol))
+    }
+    outputSchema
   }

   @Since("1.6.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index c3afab57a49c7..fa4dbbb47079f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -33,6 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.sql.{Column, DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType

 /**
  * Random Forest
@@ -192,8 +193,17 @@ class RandomForestRegressionModel private[ml] (
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights

+  @Since("3.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var outputSchema = super.transformSchema(schema)
+    if ($(leafCol).nonEmpty) {
+      outputSchema = SchemaUtils.updateField(outputSchema, getLeafField($(leafCol)))
+    }
+    outputSchema
+  }
+
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     var predictionColNames = Seq.empty[String]
     var predictionColumns = Seq.empty[Column]

@@ -204,12 +214,14 @@ class RandomForestRegressionModel private[ml] (
       val predictUDF = udf { features: Vector => bcastModel.value.predict(features) }
       predictionColNames :+= $(predictionCol)
       predictionColumns :+= predictUDF(col($(featuresCol)))
+        .as($(predictionCol), outputSchema($(predictionCol)).metadata)
     }

     if ($(leafCol).nonEmpty) {
       val leafUDF = udf { features: Vector => bcastModel.value.predictLeaf(features) }
       predictionColNames :+= $(leafCol)
       predictionColumns :+= leafUDF(col($(featuresCol)))
+        .as($(leafCol), outputSchema($(leafCol)).metadata)
     }

     if (predictionColNames.nonEmpty) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 10895d4fd11d9..3009b733d4fb7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path
 import org.json4s._
 import org.json4s.jackson.JsonMethods._

+import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.{Param, Params}
 import org.apache.spark.ml.tree.DecisionTreeModelReadWrite.NodeData
@@ -89,6 +90,18 @@ private[spark] trait DecisionTreeModel {
     }
   }

+  private[ml] lazy val numLeave: Int =
+    leafIterator(rootNode).size
+
+  private[ml] lazy val leafAttr = {
+    NominalAttribute.defaultAttr
+      .withNumValues(numLeave)
+  }
+
+  private[ml] def getLeafField(leafCol: String) = {
+    leafAttr.withName(leafCol).toStructField()
+  }
+
   @transient private lazy val leafIndices: Map[LeafNode, Int] = {
     leafIterator(rootNode).zipWithIndex.toMap
   }
@@ -146,6 +159,10 @@ private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
     val indices = trees.map(_.predictLeaf(features))
     Vectors.dense(indices)
   }
+
+  private[ml] def getLeafField(leafCol: String) = {
+    new AttributeGroup(leafCol, attrs = trees.map(_.leafAttr)).toStructField()
+  }
 }

 private[ml] object TreeEnsembleModel {
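The two `getLeafField` helpers above encode the shape difference between single trees and ensembles: a tree predicts one leaf index (a nominal value with `numLeave` categories), while an ensemble predicts one index per tree (a vector whose per-slot attributes are the trees' leaf attributes). A hedged sketch of inspecting both, where `treeOut` and `forestOut` are illustrative transformed DataFrames with `leafCol` set to `"predictedLeafId"`:

```scala
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}

// Single tree: the leaf column is one nominal value.
val treeLeaf = Attribute.fromStructField(treeOut.schema("predictedLeafId"))
  .asInstanceOf[NominalAttribute]
println(s"leaves in the tree: ${treeLeaf.getNumValues.getOrElse(-1)}")

// Ensemble: the leaf column is a vector, one nominal attribute per tree.
val forestLeaf = AttributeGroup.fromStructField(forestOut.schema("predictedLeafId"))
println(s"trees in the ensemble: ${forestLeaf.size}")
```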
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index c3894ebdd1785..752069daf8910 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -17,6 +17,7 @@

 package org.apache.spark.ml.util

+import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.VectorUDT
 import org.apache.spark.sql.types._

@@ -106,6 +107,91 @@ private[spark] object SchemaUtils {
     StructType(schema.fields :+ col)
   }

+  /**
+   * Update the size of an ML Vector column. If the column does not exist, append it.
+   * @param schema input schema
+   * @param colName column name
+   * @param size number of features
+   * @return new schema
+   */
+  def updateAttributeGroupSize(
+      schema: StructType,
+      colName: String,
+      size: Int): StructType = {
+    require(size > 0)
+    val attrGroup = new AttributeGroup(colName, size)
+    val field = attrGroup.toStructField
+    updateField(schema, field, true)
+  }
+
+  /**
+   * Update the number of values of an existing column. If the column does not exist, append it.
+   * @param schema input schema
+   * @param colName column name
+   * @param numValues number of values
+   * @return new schema
+   */
+  def updateNumValues(
+      schema: StructType,
+      colName: String,
+      numValues: Int): StructType = {
+    val attr = NominalAttribute.defaultAttr
+      .withName(colName)
+      .withNumValues(numValues)
+    val field = attr.toStructField
+    updateField(schema, field, true)
+  }
+
+  /**
+   * Update the numeric metadata of an existing column. If the column does not exist, append it.
+   * @param schema input schema
+   * @param colName column name
+   * @return new schema
+   */
+  def updateNumeric(
+      schema: StructType,
+      colName: String): StructType = {
+    val attr = NumericAttribute.defaultAttr
+      .withName(colName)
+    val field = attr.toStructField
+    updateField(schema, field, true)
+  }
+
+  /**
+   * Update the metadata of an existing column. If the column does not exist, append it.
+   * @param schema input schema
+   * @param field struct field
+   * @param overwriteMetadata whether to overwrite the metadata. If true, the metadata in the
+   *                          schema will be overwritten. If false, the metadata in `field`
+   *                          and `schema` will be merged to generate output metadata.
+   * @return new schema
+   */
+  def updateField(
+      schema: StructType,
+      field: StructField,
+      overwriteMetadata: Boolean = true): StructType = {
+    if (schema.fieldNames.contains(field.name)) {
+      val newFields = schema.fields.map { f =>
+        if (f.name == field.name) {
+          if (overwriteMetadata) {
+            field
+          } else {
+            val newMeta = new MetadataBuilder()
+              .withMetadata(field.metadata)
+              .withMetadata(f.metadata)
+              .build()
+            StructField(field.name, field.dataType, field.nullable, newMeta)
+          }
+        } else {
+          f
+        }
+      }
+      StructType(newFields)
+    } else {
+      appendColumn(schema, field)
+    }
+  }
+
   /**
    * Check whether the given column in the schema is one of the supporting vector type: Vector,
    * Array[Float]. Array[Double]
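`updateField` is the workhorse behind the three helpers above: replace-or-append, with an opt-in merge. In the `overwriteMetadata = false` branch, the existing field's metadata is applied last, so on duplicate keys the schema's values win and only missing keys are filled in from the new field. A self-contained sketch of that merge behavior (keys and values are illustrative):

```scala
import org.apache.spark.sql.types.MetadataBuilder

// Mirrors the overwriteMetadata = false branch of updateField: later
// withMetadata calls overwrite earlier ones on duplicate keys.
val newMeta = new MetadataBuilder()
  .putLong("num_values", 3L).putString("kind", "new").build()
val oldMeta = new MetadataBuilder().putString("kind", "old").build()

val merged = new MetadataBuilder()
  .withMetadata(newMeta)
  .withMetadata(oldMeta)
  .build()

assert(merged.getString("kind") == "old")  // the existing value wins
assert(merged.getLong("num_values") == 3L) // the gap is filled from the new field
```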
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index dc38f17d296f2..b23b4f4ac0d26 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -112,8 +112,13 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest { assert(lsvc.getFeaturesCol === "features") assert(lsvc.getPredictionCol === "prediction") assert(lsvc.getRawPredictionCol === "rawPrediction") + val model = lsvc.setMaxIter(5).fit(smallBinaryDataset) - model.transform(smallBinaryDataset) + val transformed = model.transform(smallBinaryDataset) + checkNominalOnDF(transformed, "prediction", model.numClasses) + checkVectorSizeOnDF(transformed, "rawPrediction", model.numClasses) + + transformed .select("label", "prediction", "rawPrediction") .collect() assert(model.getThreshold === 0.0) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 60c9cce6a4879..38bdfded9693e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -155,8 +155,14 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { assert(!lr.isDefined(lr.weightCol)) assert(lr.getFitIntercept) assert(lr.getStandardization) + val model = lr.fit(smallBinaryDataset) - model.transform(smallBinaryDataset) + val transformed = model.transform(smallBinaryDataset) + checkNominalOnDF(transformed, "prediction", model.numClasses) + checkVectorSizeOnDF(transformed, "rawPrediction", model.numClasses) + checkVectorSizeOnDF(transformed, "probability", model.numClasses) + + transformed .select("label", "probability", "prediction", "rawPrediction") .collect() assert(model.getThreshold === 0.5) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index adffd83ab1bd1..024a3870d8bca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -81,6 +81,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest { assert(ovaModel.numClasses === numClasses) val transformedDataset = ovaModel.transform(dataset) + checkNominalOnDF(transformedDataset, "prediction", ovaModel.numClasses) + checkVectorSizeOnDF(transformedDataset, "rawPrediction", ovaModel.numClasses) // check for label metadata in prediction col val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 5958bfcf5ea6d..379d3bd128a5d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -242,6 +242,13 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) val model = rf.fit(df) + model.setLeafCol("predictedLeafId") + + val transformed = model.transform(df) + checkNominalOnDF(transformed,
"prediction", model.numClasses) + checkVectorSizeOnDF(transformed, "predictedLeafId", model.trees.length) + checkVectorSizeOnDF(transformed, "rawPrediction", model.numClasses) + checkVectorSizeOnDF(transformed, "probability", model.numClasses) model.trees.foreach (i => { assert(i.getMaxDepth === model.getMaxDepth) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 5288595d2e239..7ac7b64adfdab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -174,6 +174,8 @@ class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { .setSeed(1) .fit(df) val predictionDf = model.transform(df) + checkNominalOnDF(predictionDf, "prediction", model.getK) + assert(predictionDf.select("prediction").distinct().count() == 3) val predictionsMap = predictionDf.collect().map(row => row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 133536f763f4e..e570693c90e6e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -71,10 +71,15 @@ class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest { assert(gm.getK === 2) assert(gm.getFeaturesCol === "features") assert(gm.getPredictionCol === "prediction") + assert(gm.getProbabilityCol === "probability") assert(gm.getMaxIter === 100) assert(gm.getTol === 0.01) val model = gm.setMaxIter(1).fit(dataset) + val transformed = model.transform(dataset) + checkNominalOnDF(transformed, "prediction", model.weights.length) + checkVectorSizeOnDF(transformed, "probability", model.weights.length) + MLTestingUtils.checkCopyAndUids(gm, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index e3c82fafca218..f6b1a8e9d6df3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -60,6 +60,9 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes assert(kmeans.getDistanceMeasure === DistanceMeasure.EUCLIDEAN) val model = kmeans.setMaxIter(1).fit(dataset) + val transformed = model.transform(dataset) + checkNominalOnDF(transformed, "prediction", model.clusterCenters.length) + MLTestingUtils.checkCopyAndUids(kmeans, model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 079dabb3665be..19645b517d79c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.jtransforms.dct.DoubleDCT_1D +import org.scalatest.exceptions.TestFailedException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} @@ -74,5 +75,24 @@ class DCTSuite 
extends MLTest with DefaultReadWriteTest { case Row(resultVec: Vector, wantedVec: Vector) => assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) } + + val vectorSize = dataset + .select("vec") + .map { case Row(vec: Vector) => vec.size } + .head() + + // Cannot infer the size of the output vector, since no metadata is provided + intercept[TestFailedException] { + val transformed = transformer.transform(dataset) + checkVectorSizeOnDF(transformed, "resultVec", vectorSize) + } + + val dataset2 = new VectorSizeHint() + .setSize(vectorSize) + .setInputCol("vec") + .transform(dataset) + + val transformed2 = transformer.transform(dataset2) + checkVectorSizeOnDF(transformed2, "resultVec", vectorSize) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index 73b2b82daaf43..b4e144ea5ba5e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -68,6 +68,9 @@ class IDFSuite extends MLTest with DefaultReadWriteTest { .setOutputCol("idfValue") val idfModel = idfEst.fit(df) + val transformed = idfModel.transform(df) + checkVectorSizeOnDF(transformed, "idfValue", idfModel.idf.size) + MLTestingUtils.checkCopyAndUids(idfEst, idfModel) testTransformer[(Vector, Vector)](df, idfModel, "idfValue", "expected") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala index 8dd0f0cb91e37..5de938fa40c4d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -44,6 +44,9 @@ class MaxAbsScalerSuite extends MLTest with DefaultReadWriteTest { .setOutputCol("scaled") val model = scaler.fit(df) + val transformed = model.transform(df) + checkVectorSizeOnDF(transformed, "scaled", model.maxAbs.size) + testTransformer[(Vector, Vector)](df, model, "expected", "scaled") { case Row(expectedVec: Vector, actualVec: Vector) => assert(expectedVec === actualVec, diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index 2d965f2ca2c54..9b2b0c48f4f61 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -46,6 +46,9 @@ class MinMaxScalerSuite extends MLTest with DefaultReadWriteTest { .setMax(5) val model = scaler.fit(df) + val transformed = model.transform(df) + checkVectorSizeOnDF(transformed, "scaled", model.originalMin.size) + testTransformer[(Vector, Vector)](df, model, "expected", "scaled") { case Row(vector1: Vector, vector2: Vector) => assert(vector1 === vector2, "Transformed vector is different with expected.")
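The DCTSuite change above and the NormalizerSuite change below exercise the same point: a UnaryTransformer cannot know its output vector size unless the input column already carries size metadata, and VectorSizeHint is the stock way to attach it. A minimal sketch of that pattern outside the test harness (assumes an active SparkSession named spark; column names are illustrative, and the metadata claim on the output is the behavior these tests assert for this patch):

import org.apache.spark.ml.feature.{Normalizer, VectorSizeHint}
import org.apache.spark.ml.linalg.Vectors

val df = spark.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 2.0, 4.0)),
  (1, Vectors.dense(2.0, 2.0, 2.0))
)).toDF("id", "features")

// Attach size metadata up front so transformSchema can propagate it downstream.
val hinted = new VectorSizeHint()
  .setInputCol("features")
  .setSize(3)
  .transform(df)

val normalized = new Normalizer()
  .setInputCol("features")
  .setOutputCol("normalized")
  .transform(hinted)
// With this patch applied, normalized.schema("normalized").metadata records a
// vector size of 3, which checkVectorSizeOnDF can then verify.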
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index eff57f1223af4..d97df0050d74e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.feature +import org.scalatest.exceptions.TestFailedException + import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.ml.util.TestingUtils._ @@ -81,6 +83,22 @@ class NormalizerSuite extends MLTest with DefaultReadWriteTest { assertTypeOfVector(normalized, features) assertValues(normalized, expected) } + + val vectorSize = data.head.size + + // Cannot infer the size of the output vector, since no metadata is provided + intercept[TestFailedException] { + val transformed = normalizer.transform(dataFrame) + checkVectorSizeOnDF(transformed, "normalized", vectorSize) + } + + val dataFrame2 = new VectorSizeHint() + .setSize(vectorSize) + .setInputCol("features") + .transform(dataFrame) + + val transformed2 = normalizer.transform(dataFrame2) + checkVectorSizeOnDF(transformed2, "normalized", vectorSize) } test("Normalization with setter") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 531b1d7c4d9f7..88c9867337e7c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -58,6 +58,8 @@ class PCASuite extends MLTest with DefaultReadWriteTest { .setK(3) val pcaModel = pca.fit(df) + val transformed = pcaModel.transform(df) + checkVectorSizeOnDF(transformed, "pca_features", pcaModel.getK) MLTestingUtils.checkCopyAndUids(pca, pcaModel) testTransformer[(Vector, Vector)](df, pcaModel, "pca_features", "expected") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index d28f1f4240ad0..11e1847ef235e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -61,6 +61,9 @@ class Word2VecSuite extends MLTest with DefaultReadWriteTest { .setSeed(42L) val model = w2v.fit(docDF) + val transformed = model.transform(docDF) + checkVectorSizeOnDF(transformed, "result", model.getVectorSize) + MLTestingUtils.checkCopyAndUids(w2v, model) // These expectations are just magic values, characterizing the current diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 978a3cbe54c1e..3e1e2ad6a7f55 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -24,10 +24,11 @@ import org.scalatest.Suite import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext, TestUtils} import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK import org.apache.spark.ml.{Model, PredictionModel, Transformer} +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.Vector import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row} import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.test.TestSparkSession @@ -64,6 +65,38 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => } } + private[ml] def checkVectorSizeOnDF( + dataframe: DataFrame, + vecColName: String, + vecSize: Int): Unit = { + import dataframe.sparkSession.implicits._ + val group = AttributeGroup.fromStructField(dataframe.schema(vecColName)) + assert(group.size === vecSize, + s"the vector size obtained from schema should be $vecSize, but got ${group.size}") + val sizeUDF
= udf { vector: Vector => vector.size } + assert(dataframe.select(sizeUDF(col(vecColName))) + .as[Int] + .collect() + .forall(_ === vecSize)) + } + + private[ml] def checkNominalOnDF( + dataframe: DataFrame, + colName: String, + numValues: Int): Unit = { + import dataframe.sparkSession.implicits._ + val n = Attribute.fromStructField(dataframe.schema(colName)) match { + case _: BinaryAttribute => Some(2) + case nomAttr: NominalAttribute => nomAttr.getNumValues + } + assert(n.isDefined && n.get === numValues, + s"the number of values obtained from schema should be $numValues, but got $n") + assert(dataframe.select(colName) + .as[Double] + .collect() + .forall(v => v === v.toInt && v >= 0 && v < numValues)) + } + private[util] def testTransformerOnStreamData[A : Encoder]( dataframe: DataFrame, transformer: Transformer,