Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1220,10 +1220,41 @@ class GeneralizedLinearRegressionSummary private[regression] (

private[regression] lazy val link: Link = familyLink.link

/**
 * One-row aggregate over `predictions`, gathered with a single `select`.
 * Fields, by position:
 *   0 - numInstances: count of label
 *   1 - weightSum: sum of weight
 *   2 - deviance: sum of family.deviance(label, prediction, weight)
 *   3 - rss: sum of (label - prediction)^2 * weight / family.variance(prediction),
 *       or NaN when the family name equals Binomial.name or Poisson.name
 *   4 - weighted average of (label - offset): sum((label - offset) * weight) / sum(weight),
 *       or NaN unless the model fits an intercept and either has no offset column
 *       or uses the Gaussian family with the Identity link
 */
private lazy val glrSummary = {
  // Per-row deviance contribution, summed over all rows.
  val deviancePerRow = udf { (y: Double, mu: Double, w: Double) =>
    family.deviance(y, mu, w)
  }
  val devianceSum = sum(deviancePerRow(label, prediction, weight))

  // The rss aggregate is only computed for families other than binomial / poisson.
  val familyName = model.getFamily.toLowerCase(Locale.ROOT)
  val needsRss = familyName != Binomial.name && familyName != Poisson.name
  val rssSum =
    if (needsRss) {
      val weightedSquaredResidual = udf { (y: Double, mu: Double, w: Double) =>
        val diff = y - mu
        diff * diff * w / family.variance(mu)
      }
      sum(weightedSquaredResidual(label, prediction, weight))
    } else {
      lit(Double.NaN)
    }

  // The weighted mean of (label - offset) is only meaningful in the closed-form case.
  val closedFormIntercept = model.getFitIntercept &&
    (!model.hasOffsetCol || (family == Gaussian && link == Identity))
  val weightedMean =
    if (closedFormIntercept) {
      sum((label - offset) * weight) / sum(weight)
    } else {
      lit(Double.NaN)
    }

  val aggregated = predictions.select(
    count(label), sum(weight), devianceSum, rssSum, weightedMean)
  aggregated.head()
}

/** Number of instances in DataFrame predictions. */
@Since("2.2.0")
lazy val numInstances: Long = predictions.count()

lazy val numInstances: Long = glrSummary.getLong(0)

/**
* Name of features. If the name cannot be retrieved from attributes,
Expand Down Expand Up @@ -1335,9 +1366,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
*/
if (!model.hasOffsetCol ||
(model.hasOffsetCol && family == Gaussian && link == Identity)) {
val agg = predictions.agg(sum(weight.multiply(
label.minus(offset))), sum(weight)).first()
link.link(agg.getDouble(0) / agg.getDouble(1))
link.link(glrSummary.getDouble(4))
} else {
// Create empty feature column and fit intercept only model using param setting from model
val featureNull = "feature_" + java.util.UUID.randomUUID.toString
Expand All @@ -1362,12 +1391,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
* The deviance for the fitted model.
*/
@Since("2.0.0")
lazy val deviance: Double = {
predictions.select(label, prediction, weight).rdd.map {
case Row(label: Double, pred: Double, weight: Double) =>
family.deviance(label, pred, weight)
}.sum()
}
lazy val deviance: Double = glrSummary.getDouble(2)

/**
* The dispersion of the fitted model.
Expand All @@ -1381,14 +1405,14 @@ class GeneralizedLinearRegressionSummary private[regression] (
model.getFamily.toLowerCase(Locale.ROOT) == Poisson.name) {
1.0
} else {
val rss = pearsonResiduals.agg(sum(pow(col("pearsonResiduals"), 2.0))).first().getDouble(0)
val rss = glrSummary.getDouble(3)
rss / degreesOfFreedom
}

/** Akaike Information Criterion (AIC) for the fitted model. */
@Since("2.0.0")
lazy val aic: Double = {
val weightSum = predictions.select(weight).agg(sum(weight)).first().getDouble(0)
val weightSum = glrSummary.getDouble(1)
val t = predictions.select(
label, prediction, weight).rdd.map {
case Row(label: Double, pred: Double, weight: Double) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ class LinearRegressionSummary private[regression] (
}

/** Number of instances in DataFrame predictions */
lazy val numInstances: Long = predictions.count()
lazy val numInstances: Long = metrics.count

/** Degrees of freedom */
@Since("2.2.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,6 @@ class RegressionMetrics @Since("2.0.0") (
1 - SSerr / SStot
}
}

/**
 * Exposes `summary.count` to callers inside the `spark` package.
 * NOTE(review): presumably the number of observations behind these metrics;
 * confirm against the definition of `summary`.
 */
private[spark] def count: Long = {
  summary.count
}
}