From 6895c67d31be453a5c41b0885bb42cfe900a4885 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 3 Jan 2020 11:16:31 -0800 Subject: [PATCH 1/2] [SPARK-30418][ML] Make FM call super class method extractLabeledPoints --- .../apache/spark/ml/classification/FMClassifier.scala | 10 ++-------- .../org/apache/spark/ml/regression/FMRegressor.scala | 7 ++----- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala index 15d71757a6724..1adacb9098114 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala @@ -204,14 +204,8 @@ class FMClassifier @Since("3.0.0") ( instr.logNumFeatures(numFeatures) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val data: RDD[(Double, OldVector)] = - dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { - case Row(label: Double, features: Vector) => - require(label == 0 || label == 1, s"FMClassifier was given" + - s" dataset with invalid label $label. Labels must be in {0,1}; note that" + - s" FMClassifier currently only supports binary classification.") - (label, features) - } + val labeledPoint = extractLabeledPoints (dataset, numClasses) + val data: RDD[(Double, OldVector)] = labeledPoint.map(x => (x.label, x.features)) if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala index 0bf1836edbd47..bef853d5220b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala @@ -419,11 +419,8 @@ class FMRegressor @Since("3.0.0") ( instr.logNumFeatures(numFeatures) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val data: RDD[(Double, OldVector)] = - dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { - case Row(label: Double, features: Vector) => - (label, features) - } + val labeledPoint = extractLabeledPoints (dataset) + val data: RDD[(Double, OldVector)] = labeledPoint.map(x => (x.label, x.features)) if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK) From b0c01c4a219fe27104977501545e1829394a9d7a Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sat, 4 Jan 2020 19:40:19 -0800 Subject: [PATCH 2/2] remove extra space --- .../scala/org/apache/spark/ml/classification/FMClassifier.scala | 2 +- .../main/scala/org/apache/spark/ml/regression/FMRegressor.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala index 1adacb9098114..d511c1b5dda98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala @@ -204,7 +204,7 @@ class FMClassifier @Since("3.0.0") ( instr.logNumFeatures(numFeatures) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val labeledPoint = extractLabeledPoints (dataset, numClasses) + val labeledPoint = extractLabeledPoints(dataset, numClasses) val data: RDD[(Double, OldVector)] = labeledPoint.map(x => (x.label, x.features)) if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala index bef853d5220b2..0bdd0b4d9146b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala @@ -419,7 +419,7 @@ class FMRegressor @Since("3.0.0") ( instr.logNumFeatures(numFeatures) val handlePersistence = dataset.storageLevel == StorageLevel.NONE - val labeledPoint = extractLabeledPoints (dataset) + val labeledPoint = extractLabeledPoints(dataset) val data: RDD[(Double, OldVector)] = labeledPoint.map(x => (x.label, x.features)) if (handlePersistence) data.persist(StorageLevel.MEMORY_AND_DISK)