@@ -24,6 +24,7 @@ import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.col

/**
* A simple example demonstrating how to write your own learning algorithm using Estimator,
@@ -120,8 +121,10 @@ private class MyLogisticRegression(override val uid: String)

// This method is used by fit()
override protected def train(dataset: Dataset[_]): MyLogisticRegressionModel = {
// Extract columns from data using helper method.
val oldDataset = extractLabeledPoints(dataset)
// Extract columns from data.
val oldDataset = dataset.select(col($(labelCol)).cast("double"), col($(featuresCol)))
.rdd
.map { case Row(l: Double, f: Vector) => LabeledPoint(l, f) }

// Do learning to estimate the coefficients vector.
val numFeatures = oldDataset.take(1)(0).features.size
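The example now does the column extraction inline instead of calling the removed extractLabeledPoints helper. For reference, a minimal self-contained sketch of the same pattern, using a hypothetical toy DataFrame with "label" and "features" columns:

import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.col

object LabeledPointExtractionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()

    // Hypothetical toy data with a numeric label and a feature vector per row.
    val df = spark.createDataFrame(Seq(
      (1.0, Vectors.dense(0.0, 1.1, 0.1)),
      (0.0, Vectors.dense(2.0, 1.0, -1.0))
    )).toDF("label", "features")

    // Same extraction as above: cast the label to double and pattern-match each
    // Row into a strongly typed LabeledPoint.
    val points = df.select(col("label").cast("double"), col("features"))
      .rdd
      .map { case Row(l: Double, f: Vector) => LabeledPoint(l, f) }

    points.collect().foreach(println)
    spark.stop()
  }
}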
51 changes: 2 additions & 49 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -18,14 +18,11 @@
package org.apache.spark.ml

import org.apache.spark.annotation.Since
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}

@@ -63,40 +60,6 @@ private[ml] trait PredictorParams extends Params
}
SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
}

/**
* Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
* and put it in an RDD with strong types.
*/
protected def extractInstances(dataset: Dataset[_]): RDD[Instance] = {
val w = this match {
case p: HasWeightCol =>
if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
checkNonNegativeWeight((col($(p.weightCol)).cast(DoubleType)))
} else {
lit(1.0)
}
}

dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
}

/**
* Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
* and put it in an RDD with strong types.
* Validate the output instances with the given function.
*/
protected def extractInstances(
dataset: Dataset[_],
validateInstance: Instance => Unit): RDD[Instance] = {
extractInstances(dataset).map { instance =>
validateInstance(instance)
instance
}
}
}

/**
@@ -176,16 +139,6 @@ abstract class Predictor[
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = true, featuresDataType)
}

/**
* Extract [[labelCol]] and [[featuresCol]] from the given dataset,
* and put it in an RDD with strong types.
*/
protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) => LabeledPoint(label, features)
}
}
}
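For callers outside these traits, the equivalent of the deleted extractInstances now lives in DatasetUtils (see the org.apache.spark.ml.util.DatasetUtils.extractInstances import in the GBTClassifier hunk below). A minimal standalone sketch of what the removed trait method did, assuming hypothetical "label" and "features" column names and omitting the non-negative weight check of the original; WeightedInstance is a local stand-in for the package-private ml.feature.Instance:

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType

// Local stand-in for the package-private org.apache.spark.ml.feature.Instance.
case class WeightedInstance(label: Double, weight: Double, features: Vector)

def toInstances(df: DataFrame, weightCol: Option[String]): RDD[WeightedInstance] = {
  // Use the configured weight column when present, otherwise weight every row 1.0,
  // mirroring the removed PredictorParams.extractInstances.
  val w = weightCol.filter(_.nonEmpty).map(c => col(c).cast(DoubleType)).getOrElse(lit(1.0))
  df.select(col("label").cast(DoubleType), w, col("features")).rdd.map {
    case Row(label: Double, weight: Double, features: Vector) =>
      WeightedInstance(label, weight, features)
  }
}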

/**
@@ -17,17 +17,13 @@

package org.apache.spark.ml.classification

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StructType}

@@ -44,23 +40,6 @@ private[spark] trait ClassifierParams
val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT)
}

/**
* Extract [[labelCol]], weightCol(if any) and [[featuresCol]] from the given dataset,
* and put it in an RDD with strong types.
* Validates the label on the classifier is a valid integer in the range [0, numClasses).
*/
protected def extractInstances(
dataset: Dataset[_],
numClasses: Int): RDD[Instance] = {
val validateInstance = (instance: Instance) => {
val label = instance.label
require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
s" dataset with invalid label $label. Labels must be integers in range" +
s" [0, $numClasses).")
}
extractInstances(dataset, validateInstance)
}
}

/**
@@ -81,89 +60,6 @@ abstract class Classifier[
def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]

// TODO: defaultEvaluator (follow-up PR)

/**
* Extract [[labelCol]] and [[featuresCol]] from the given dataset,
* and put it in an RDD with strong types.
*
* @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
* and features (`Vector`).
* @param numClasses Number of classes label can take. Labels must be integers in the range
* [0, numClasses).
* @note Throws `SparkException` if any label is a non-integer or is negative
*/
protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
validateNumClasses(numClasses)
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
validateLabel(label, numClasses)
LabeledPoint(label, features)
}
}

/**
* Validates that number of classes is greater than zero.
*
* @param numClasses Number of classes label can take.
*/
protected def validateNumClasses(numClasses: Int): Unit = {
require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
s" $numClasses, but requires numClasses > 0.")
}

/**
* Validates the label on the classifier is a valid integer in the range [0, numClasses).
*
* @param label The label to validate.
* @param numClasses Number of classes label can take. Labels must be integers in the range
* [0, numClasses).
*/
protected def validateLabel(label: Double, numClasses: Int): Unit = {
require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
s" dataset with invalid label $label. Labels must be integers in range" +
s" [0, $numClasses).")
}

/**
* Get the number of classes. This looks in column metadata first, and if that is missing,
* then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
* by finding the maximum label value.
*
* Label validation (ensuring all labels are integers >= 0) needs to be handled elsewhere,
* such as in `extractLabeledPoints()`.
*
* @param dataset Dataset which contains a column [[labelCol]]
* @param maxNumClasses Maximum number of classes allowed when inferred from data. If numClasses
* is specified in the metadata, then maxNumClasses is ignored.
* @return number of classes
* @throws IllegalArgumentException if metadata does not specify numClasses, and the
* actual numClasses exceeds maxNumClasses
*/
protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
case Some(n: Int) => n
case None =>
// Get number of classes from dataset itself.
val maxLabelRow: Array[Row] = dataset
.select(max(checkClassificationLabels($(labelCol), Some(maxNumClasses))))
.take(1)
if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) {
throw new SparkException("ML algorithm was given empty dataset.")
}
val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
val numClasses = maxDoubleLabel.toInt + 1
require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
s" in column $labelCol, but this exceeded the max numClasses ($maxNumClasses) allowed" +
s" to be inferred from values. To avoid this error for labels with > $maxNumClasses" +
s" classes, specify numClasses explicitly in the metadata; this can be done by applying" +
s" StringIndexer to the label column.")
logInfo(this.getClass.getCanonicalName + s" inferred $numClasses classes for" +
s" labelCol=$labelCol since numClasses was not specified in the column metadata.")
numClasses
}
}
}
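Callers of the removed getNumClasses now pass the label column explicitly, e.g. getNumClasses(dataset, $(labelCol)) in the DecisionTreeClassifier and RandomForestClassifier hunks below. A simplified standalone sketch of the inference the removed method described (column metadata first, then a max(label) + 1 fallback); the function name is hypothetical, and the empty-dataset, binary-attribute and maxNumClasses handling of the original is omitted:

import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.sql.{Dataset, functions => F}

def inferNumClasses(dataset: Dataset[_], labelCol: String): Int = {
  // Prefer the class count recorded in the label column's ML attribute metadata.
  val fromMetadata = Attribute.fromStructField(dataset.schema(labelCol)) match {
    case nominal: NominalAttribute => nominal.getNumValues
    case _ => None
  }
  fromMetadata.getOrElse {
    // Fallback: assume labels are indexed 0, 1, ..., k-1 and infer k from the data.
    val maxLabel = dataset.select(F.max(F.col(labelCol).cast("double"))).head().getDouble(0)
    maxLabel.toInt + 1
  }
}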

/**
@@ -117,14 +117,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
instr.logPipelineStage(this)
instr.logDataset(dataset)
val categoricalFeatures = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses = getNumClasses(dataset)
val numClasses = getNumClasses(dataset, $(labelCol))

if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
validateNumClasses(numClasses)

val instances = dataset.select(
checkClassificationLabels($(labelCol), Some(numClasses)),
@@ -190,7 +190,7 @@ class FMClassifier @Since("3.0.0") (
miniBatchFraction, initStd, maxIter, stepSize, tol, solver, thresholds)
instr.logNumClasses(numClasses)

val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
val numFeatures = getNumFeatures(dataset, $(featuresCol))
instr.logNumFeatures(numFeatures)

val handlePersistence = dataset.storageLevel == StorageLevel.NONE
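getNumFeatures here replaces the MetadataUtils.getNumFeatures call. A simplified sketch of a metadata-first lookup with a first-row fallback; the function name and the exact fallback behaviour are assumptions, not the helper's actual implementation:

import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.{Dataset, functions => F}

def inferNumFeatures(dataset: Dataset[_], featuresCol: String): Int = {
  // AttributeGroup.size is -1 when the vector size is not recorded in the metadata.
  val group = AttributeGroup.fromStructField(dataset.schema(featuresCol))
  if (group.size >= 0) {
    group.size
  } else {
    // No size in the metadata: peek at the first feature vector instead.
    dataset.select(F.col(featuresCol)).head().getAs[Vector](0).size
  }
}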
@@ -22,14 +22,13 @@ import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DatasetUtils._
import org.apache.spark.ml.util.DatasetUtils.extractInstances
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -169,21 +168,12 @@ class GBTClassifier @Since("1.4.0") (

override protected def train(
dataset: Dataset[_]): GBTClassificationModel = instrumented { instr =>

def extractInstances(df: Dataset[_]) = {
df.select(
checkClassificationLabels($(labelCol), Some(2)),
checkNonNegativeWeights(get(weightCol)),
checkNonNanVectors($(featuresCol))
).rdd.map { case Row(l: Double, w: Double, v: Vector) => Instance(l, w, v) }
}

val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
val (trainDataset, validationDataset) = if (withValidation) {
(extractInstances(dataset.filter(not(col($(validationIndicatorCol))))),
extractInstances(dataset.filter(col($(validationIndicatorCol)))))
(extractInstances(this, dataset.filter(not(col($(validationIndicatorCol)))), Some(2)),
extractInstances(this, dataset.filter(col($(validationIndicatorCol))), Some(2)))
} else {
(extractInstances(dataset), null)
(extractInstances(this, dataset, Some(2)), null)
}

val numClasses = 2
@@ -390,7 +380,7 @@ class GBTClassificationModel private[ml](
*/
@Since("2.4.0")
def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = {
val data = extractInstances(dataset)
val data = extractInstances(this, dataset, Some(2))
GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss,
OldAlgo.Classification)
}
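The validation split above is driven entirely by filtering on the boolean validationIndicatorCol. A minimal usage sketch with a hypothetical toy DataFrame, showing how a caller exercises that path:

import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("gbt-sketch").getOrCreate()

// Hypothetical toy data: label, features, and a boolean flag marking validation rows.
val df = spark.createDataFrame(Seq(
  (0.0, Vectors.dense(0.0, 1.0), false),
  (0.0, Vectors.dense(0.2, 0.9), false),
  (1.0, Vectors.dense(1.0, 0.1), false),
  (1.0, Vectors.dense(0.9, 0.2), false),
  (0.0, Vectors.dense(0.1, 1.1), true),
  (1.0, Vectors.dense(1.1, 0.0), true)
)).toDF("label", "features", "isVal")

// Rows with isVal = true go to the validation set via the filter shown above.
val gbt = new GBTClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setValidationIndicatorCol("isVal")
  .setMaxIter(5)

val model = gbt.fit(df)
spark.stop()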
@@ -141,7 +141,7 @@ class RandomForestClassifier @Since("1.4.0") (
instr.logDataset(dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
val numClasses = getNumClasses(dataset, $(labelCol))

if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -381,7 +381,7 @@ class GaussianMixture @Since("2.0.0") (
val spark = dataset.sparkSession
import spark.implicits._

val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol))
val numFeatures = getNumFeatures(dataset, $(featuresCol))
require(numFeatures < GaussianMixture.MAX_NUM_FEATURES, s"GaussianMixture cannot handle more " +
s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" +
s" matrix is quadratic in the number of features.")
@@ -18,11 +18,10 @@
package org.apache.spark.ml.evaluation

import org.apache.spark.annotation.Since
import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, MetadataUtils, SchemaUtils}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions._
@@ -129,8 +128,8 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
dataset.select(
col($(rawPredictionCol)),
col($(labelCol)).cast(DoubleType),
if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))).rdd.map {
DatasetUtils.checkNonNegativeWeights(get(weightCol))
).rdd.map {
case Row(rawPrediction: Vector, label: Double, weight: Double) =>
(rawPrediction(1), label, weight)
case Row(rawPrediction: Double, label: Double, weight: Double) =>
@@ -18,13 +18,11 @@
package org.apache.spark.ml.evaluation

import org.apache.spark.annotation.Since
import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType

/**
* Evaluator for clustering results.
@@ -130,18 +128,13 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
SchemaUtils.checkNumericType(schema, $(weightCol))
}

val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)

val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol))
val df = if (!isDefined(weightCol) || $(weightCol).isEmpty) {
dataset.select(col($(predictionCol)),
vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
lit(1.0).as(weightColName))
} else {
dataset.select(col($(predictionCol)),
vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
checkNonNegativeWeight(col(weightColName).cast(DoubleType)))
}
val df = dataset.select(
col($(predictionCol)),
DatasetUtils.columnToVector(dataset, $(featuresCol))
.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
DatasetUtils.checkNonNegativeWeights(get(weightCol))
.as(if (!isDefined(weightCol)) "weightCol" else $(weightCol))
)

val metrics = new ClusteringMetrics(df)
metrics.setDistanceMeasure($(distanceMeasure))
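Both evaluator hunks in this section swap the inline weight handling for DatasetUtils.checkNonNegativeWeights(get(weightCol)). A standalone sketch of the behaviour the old inline code implemented (constant weight 1.0 when no weight column is configured, otherwise cast to double and reject negative values); the helper below is hypothetical, not the DatasetUtils implementation:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, lit, udf}

def weightColumn(weightCol: Option[String]): Column = weightCol match {
  case Some(name) if name.nonEmpty =>
    // Fail fast on negative weights while passing valid values through unchanged.
    val checkNonNegative = udf { (w: Double) =>
      require(w >= 0.0, s"Illegal weight value: $w. Weights must be non-negative.")
      w
    }
    checkNonNegative(col(name).cast("double"))
  case _ => lit(1.0)
}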