diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 0ee895a95a288..8336df8e34ae0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -1220,10 +1220,41 @@ class GeneralizedLinearRegressionSummary private[regression] (
 
   private[regression] lazy val link: Link = familyLink.link
 
+  /**
+   * summary row containing:
+   * numInstances, weightSum, deviance, rss, weighted average of label - offset.
+   */
+  private lazy val glrSummary = {
+    val devUDF = udf { (label: Double, pred: Double, weight: Double) =>
+      family.deviance(label, pred, weight)
+    }
+    val devCol = sum(devUDF(label, prediction, weight))
+
+    val rssCol = if (model.getFamily.toLowerCase(Locale.ROOT) != Binomial.name &&
+      model.getFamily.toLowerCase(Locale.ROOT) != Poisson.name) {
+      val rssUDF = udf { (label: Double, pred: Double, weight: Double) =>
+        (label - pred) * (label - pred) * weight / family.variance(pred)
+      }
+      sum(rssUDF(label, prediction, weight))
+    } else {
+      lit(Double.NaN)
+    }
+
+    val avgCol = if (model.getFitIntercept &&
+      (!model.hasOffsetCol || (model.hasOffsetCol && family == Gaussian && link == Identity))) {
+      sum((label - offset) * weight) / sum(weight)
+    } else {
+      lit(Double.NaN)
+    }
+
+    predictions
+      .select(count(label), sum(weight), devCol, rssCol, avgCol)
+      .head()
+  }
+
   /** Number of instances in DataFrame predictions. */
   @Since("2.2.0")
-  lazy val numInstances: Long = predictions.count()
-
+  lazy val numInstances: Long = glrSummary.getLong(0)
 
   /**
    * Name of features. If the name cannot be retrieved from attributes,
@@ -1335,9 +1366,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
      */
     if (!model.hasOffsetCol ||
       (model.hasOffsetCol && family == Gaussian && link == Identity)) {
-      val agg = predictions.agg(sum(weight.multiply(
-        label.minus(offset))), sum(weight)).first()
-      link.link(agg.getDouble(0) / agg.getDouble(1))
+      link.link(glrSummary.getDouble(4))
     } else {
       // Create empty feature column and fit intercept only model using param setting from model
       val featureNull = "feature_" + java.util.UUID.randomUUID.toString
@@ -1362,12 +1391,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
    * The deviance for the fitted model.
    */
   @Since("2.0.0")
-  lazy val deviance: Double = {
-    predictions.select(label, prediction, weight).rdd.map {
-      case Row(label: Double, pred: Double, weight: Double) =>
-        family.deviance(label, pred, weight)
-    }.sum()
-  }
+  lazy val deviance: Double = glrSummary.getDouble(2)
 
   /**
    * The dispersion of the fitted model.
@@ -1381,14 +1405,14 @@ class GeneralizedLinearRegressionSummary private[regression] (
     model.getFamily.toLowerCase(Locale.ROOT) == Poisson.name) {
     1.0
   } else {
-    val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0)
+    val rss = glrSummary.getDouble(3)
     rss / degreesOfFreedom
   }
 
   /** Akaike Information Criterion (AIC) for the fitted model. */
   @Since("2.0.0")
   lazy val aic: Double = {
-    val weightSum = predictions.select(weight).agg(sum(weight)).first().getDouble(0)
+    val weightSum = glrSummary.getDouble(1)
     val t = predictions.select(
       label, prediction, weight).rdd.map {
       case Row(label: Double, pred: Double, weight: Double) =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index d9f09c097292a..de559142a9261 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -1037,7 +1037,7 @@ class LinearRegressionSummary private[regression] (
   }
 
   /** Number of instances in DataFrame predictions */
-  lazy val numInstances: Long = predictions.count()
+  lazy val numInstances: Long = metrics.count
 
   /** Degrees of freedom */
   @Since("2.2.0")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
index b697d2746ce7b..7938427544bd9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
@@ -131,4 +131,6 @@ class RegressionMetrics @Since("2.0.0") (
       1 - SSerr / SStot
     }
   }
+
+  private[spark] def count: Long = summary.count
 }