Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -646,13 +646,8 @@ class LogisticRegression @Since("1.2.0") (
numClasses, isMultinomial))
// TODO: implement summary model for multinomial case
val m = if (!isMultinomial) {
val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol()
val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
summaryModel.transform(dataset),
probabilityColName,
$(labelCol),
$(featuresCol),
objectiveHistory)
val logRegSummary = new BinaryLogisticRegressionTrainingSummary(dataset, model,
$(labelCol), $(featuresCol), objectiveHistory)
model.setSummary(Some(logRegSummary))
} else {
model
Expand Down Expand Up @@ -780,21 +775,6 @@ class LogisticRegressionModel private[spark] (
throw new SparkException("No training summary available for this LogisticRegressionModel")
}

/**
 * Resolves which model and probability column a summary should use.
 *
 * When the probability column param holds a non-empty string, this model and that column
 * name are returned unchanged. When it is the empty string, a copy of this model is returned
 * whose probability column is set to a freshly generated "probability_<UUID>" name, paired
 * with that generated name.
 *
 * @return the model to run `transform` with, and the name of its probability column.
 */
private[classification] def findSummaryModelAndProbabilityCol():
    (LogisticRegressionModel, String) = {
  val configuredCol = $(probabilityCol)
  if (configuredCol == "") {
    // No usable column name: generate a unique one and set it on a copy of this model.
    val generatedCol = "probability_" + java.util.UUID.randomUUID.toString
    (copy(ParamMap.empty).setProbabilityCol(generatedCol), generatedCol)
  } else {
    (this, configuredCol)
  }
}

private[classification]
def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = {
this.trainingSummary = summary
Expand All @@ -812,10 +792,7 @@ class LogisticRegressionModel private[spark] (
*/
@Since("2.0.0")
def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = {
// Handle possible missing or invalid prediction columns
val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol()
new BinaryLogisticRegressionSummary(summaryModel.transform(dataset),
probabilityColName, $(labelCol), $(featuresCol))
new BinaryLogisticRegressionSummary(dataset, this, $(labelCol), $(featuresCol))
}

/**
Expand Down Expand Up @@ -1125,22 +1102,22 @@ sealed trait LogisticRegressionSummary extends Serializable {
* :: Experimental ::
* Logistic regression training results.
*
* @param predictions dataframe output by the model's `transform` method.
* @param probabilityCol field in "predictions" which gives the probability of
* each class as a vector.
* @param dataset Dataset to be summarized.
* @param origModel Model to be summarized. This is copied to create an internal
* model which cannot be modified from outside.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
@Experimental
@Since("1.5.0")
class BinaryLogisticRegressionTrainingSummary private[classification] (
predictions: DataFrame,
probabilityCol: String,
dataset: Dataset[_],
origModel: LogisticRegressionModel,
labelCol: String,
featuresCol: String,
@Since("1.5.0") val objectiveHistory: Array[Double])
extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol)
extends BinaryLogisticRegressionSummary(dataset, origModel, labelCol, featuresCol)
with LogisticRegressionTrainingSummary {

}
Expand All @@ -1149,20 +1126,45 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
* :: Experimental ::
* Binary Logistic regression results for a given model.
*
* @param predictions dataframe output by the model's `transform` method.
* @param probabilityCol field in "predictions" which gives the probability of
* each class as a vector.
* @param dataset Dataset to be summarized.
* @param origModel Model to be summarized. This is copied to create an internal
* model which cannot be modified from outside.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
*/
@Experimental
@Since("1.5.0")
class BinaryLogisticRegressionSummary private[classification] (
@Since("1.5.0") @transient override val predictions: DataFrame,
@Since("1.5.0") override val probabilityCol: String,
dataset: Dataset[_],
origModel: LogisticRegressionModel,
@Since("1.5.0") override val labelCol: String,
@Since("1.6.0") override val featuresCol: String) extends LogisticRegressionSummary {

/**
* Field in "predictions" which gives the probability of each class as a vector.
* This is set to a new column name if the original model's `probabilityCol` is not set.
*/
@Since("1.5.0")
override val probabilityCol: String = {
if (origModel.isDefined(origModel.probabilityCol) && origModel.getProbabilityCol != "") {
origModel.getProbabilityCol
} else {
"probability_" + java.util.UUID.randomUUID().toString
}
}

/**
* Private copy of model to ensure Params are not modified outside this class.
* The coefficients are not deep-copied, but that is acceptable.
*
* NOTE: [[probabilityCol]] must be set correctly before the value of [[model]] is set,
* and [[model]] must be set before [[predictions]] is set!
*/
protected val model: LogisticRegressionModel =
origModel.copy(ParamMap.empty).setProbabilityCol(probabilityCol)

/** predictions output by the model's `transform` method */
@Since("1.5.0") @transient override val predictions: DataFrame = model.transform(dataset)

private val sparkSession = predictions.sparkSession
import sparkSession.implicits._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
// When it is trained by WeightedLeastSquares, training summary does not
// attach returned model.
val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept))
val (summaryModel, predictionColName) = lrModel.findSummaryModelAndPredictionCol()
val trainingSummary = new LinearRegressionTrainingSummary(
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
$(featuresCol),
summaryModel,
model.diagInvAtWA.toArray,
model.objectiveHistory)
val trainingSummary = new LinearRegressionTrainingSummary(dataset,
$(labelCol), $(featuresCol), lrModel, model.diagInvAtWA.toArray, model.objectiveHistory)

lrModel.setSummary(Some(trainingSummary))
instr.logSuccess(lrModel)
Expand Down Expand Up @@ -275,17 +268,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
val intercept = yMean

val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()

val trainingSummary = new LinearRegressionTrainingSummary(
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
$(featuresCol),
model,
Array(0D),
Array(0D))
val trainingSummary = new LinearRegressionTrainingSummary(dataset,
$(labelCol), $(featuresCol), model, Array(0D), Array(0D))

model.setSummary(Some(trainingSummary))
instr.logSuccess(model)
Expand Down Expand Up @@ -400,17 +384,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
if (handlePersistence) instances.unpersist()

val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()

val trainingSummary = new LinearRegressionTrainingSummary(
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
$(featuresCol),
model,
Array(0D),
objectiveHistory)
val trainingSummary = new LinearRegressionTrainingSummary(dataset,
$(labelCol), $(featuresCol), model, Array(0D), objectiveHistory)

model.setSummary(Some(trainingSummary))
instr.logSuccess(model)
Expand Down Expand Up @@ -477,27 +452,9 @@ class LinearRegressionModel private[ml] (
*/
@Since("2.0.0")
def evaluate(dataset: Dataset[_]): LinearRegressionSummary = {
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName,
$(labelCol), $(featuresCol), summaryModel, Array(0D))
new LinearRegressionSummary(dataset, $(labelCol), $(featuresCol), this, Array(0D))
}

/**
 * Resolves which model and prediction column a summary should use.
 *
 * When the prediction column param holds a non-empty string, this model and that column
 * name are returned unchanged. When it is the empty string, a copy of this model is returned
 * whose prediction column is set to a freshly generated "prediction_<UUID>" name, paired
 * with that generated name.
 *
 * @return the model to run `transform` with, and the name of its prediction column.
 */
private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = {
  val configuredCol = $(predictionCol)
  if (configuredCol == "") {
    // No usable column name: generate a unique one and set it on a copy of this model.
    val generatedCol = "prediction_" + java.util.UUID.randomUUID.toString
    (copy(ParamMap.empty).setPredictionCol(generatedCol), generatedCol)
  } else {
    (this, configuredCol)
  }
}


/**
 * Computes the prediction for one feature vector as `dot(features, coefficients)` plus
 * the intercept.
 *
 * @param features feature vector of the instance to score.
 * @return the predicted value.
 */
override protected def predict(features: Vector): Double = {
  val weightedSum = dot(features, coefficients)
  weightedSum + intercept
}
Expand Down Expand Up @@ -572,22 +529,20 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
* Linear regression training results. Currently, the training summary ignores the
* training weights except for the objective trace.
*
* @param predictions predictions output by the model's `transform` method.
* @param dataset Dataset to be summarized
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
@Since("1.5.0")
@Experimental
class LinearRegressionTrainingSummary private[regression] (
predictions: DataFrame,
predictionCol: String,
dataset: Dataset[_],
labelCol: String,
featuresCol: String,
model: LinearRegressionModel,
diagInvAtWA: Array[Double],
val objectiveHistory: Array[Double])
extends LinearRegressionSummary(
predictions,
predictionCol,
dataset,
labelCol,
featuresCol,
model,
Expand All @@ -609,22 +564,44 @@ class LinearRegressionTrainingSummary private[regression] (
* :: Experimental ::
* Linear regression results evaluated on a dataset.
*
* @param predictions predictions output by the model's `transform` method.
* @param predictionCol Field in "predictions" which gives the predicted value of the label at
* each instance.
* @param labelCol Field in "predictions" which gives the true label of each instance.
* @param featuresCol Field in "predictions" which gives the features of each instance as a vector.
*/
@Since("1.5.0")
@Experimental
class LinearRegressionSummary private[regression] (
@transient val predictions: DataFrame,
val predictionCol: String,
dataset: Dataset[_],
val labelCol: String,
val featuresCol: String,
private val privateModel: LinearRegressionModel,
private val diagInvAtWA: Array[Double]) extends Serializable {

/**
* Field in "predictions" which gives the prediction value of each instance.
* This is set to a new column name if the original model's `predictionCol` is not set.
*/
@Since("1.5.0")
val predictionCol: String = {
if (privateModel.isDefined(privateModel.predictionCol) && privateModel.getPredictionCol != "") {
privateModel.getPredictionCol
} else {
"prediction_" + java.util.UUID.randomUUID().toString
}
}

/**
* Private copy of model to ensure Params are not modified outside this class.
* The coefficients are not deep-copied, but that is acceptable.
*
* NOTE: [[predictionCol]] must be set correctly before the value of [[modelCopy]] is set,
* and [[modelCopy]] must be set before [[predictions]] is set!
*/
protected val modelCopy: LinearRegressionModel =
privateModel.copy(ParamMap.empty).setPredictionCol(predictionCol)

/** predictions output by the model's `transform` method. */
@Since("1.5.0") @transient val predictions: DataFrame = modelCopy.transform(dataset)

@transient private val metrics = new RegressionMetrics(
predictions
.select(col(predictionCol), col(labelCol).cast(DoubleType))
Expand Down
5 changes: 5 additions & 0 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ object MimaExcludes {

// Exclude rules for 2.2.x
lazy val v22excludes = v21excludes ++ Seq(
// [SPARK-14985][ML] Update LinearRegression, LogisticRegression summary internals to handle
// model copy
ProblemFilters.exclude[DirectMissingMethodProblem]
("org.apache.spark.ml.regression.LinearRegressionTrainingSummary.this"),

// [SPARK-19652][UI] Do auth checks for REST API access.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.withSparkUI"),
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.status.api.v1.UIRootFromServletContext"),
Expand Down