Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ setGeneric("spark.randomForest",

#' @rdname spark.survreg
#' @export
setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") })
setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })

#' @rdname spark.svmLinear
#' @export
Expand Down
13 changes: 9 additions & 4 deletions R/pkg/R/mllib_classification.R
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ function(object, path, overwrite = FALSE) {
#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p
#' is the original probability of that class and t is the class's threshold.
#' @param weightCol The weight column name.
#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
#' or the number of partitions are large, this param could be adjusted to a larger size.
#' This is an expert parameter. Default value should be good for most cases.
#' @param ... additional arguments passed to the method.
#' @return \code{spark.logit} returns a fitted logistic regression model.
#' @rdname spark.logit
Expand Down Expand Up @@ -245,19 +248,21 @@ function(object, path, overwrite = FALSE) {
setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
tol = 1E-6, family = "auto", standardization = TRUE,
thresholds = 0.5, weightCol = NULL) {
thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) {
formula <- paste(deparse(formula), collapse = "")

if (is.null(weightCol)) {
weightCol <- ""
if (!is.null(weightCol) && weightCol == "") {
weightCol <- NULL
} else if (!is.null(weightCol)) {
weightCol <- as.character(weightCol)
}

jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
data@sdf, formula, as.numeric(regParam),
as.numeric(elasticNetParam), as.integer(maxIter),
as.numeric(tol), as.character(family),
as.logical(standardization), as.array(thresholds),
as.character(weightCol))
weightCol, as.integer(aggregationDepth))
new("LogisticRegressionModel", jobj = jobj)
})

Expand Down
24 changes: 16 additions & 8 deletions R/pkg/R/mllib_regression.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,16 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
}

formula <- paste(deparse(formula), collapse = "")
if (is.null(weightCol)) {
weightCol <- ""
if (!is.null(weightCol) && weightCol == "") {
weightCol <- NULL
} else if (!is.null(weightCol)) {
weightCol <- as.character(weightCol)
}

# For known families, Gamma is upper-cased
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, tolower(family$family), family$link,
tol, as.integer(maxIter), as.character(weightCol), regParam)
tol, as.integer(maxIter), weightCol, regParam)
new("GeneralizedLinearRegressionModel", jobj = jobj)
})

Expand Down Expand Up @@ -305,13 +307,15 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"
function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) {
formula <- paste(deparse(formula), collapse = "")

if (is.null(weightCol)) {
weightCol <- ""
if (!is.null(weightCol) && weightCol == "") {
weightCol <- NULL
} else if (!is.null(weightCol)) {
weightCol <- as.character(weightCol)
}

jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit",
data@sdf, formula, as.logical(isotonic), as.integer(featureIndex),
as.character(weightCol))
weightCol)
new("IsotonicRegressionModel", jobj = jobj)
})

Expand Down Expand Up @@ -372,6 +376,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', ':', '+', and '-'.
#' Note that operator '.' is not supported currently.
#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
#' or the number of partitions are large, this param could be adjusted to a larger size.
#' This is an expert parameter. Default value should be good for most cases.
#' @param ... additional arguments passed to the method.
#' @return \code{spark.survreg} returns a fitted AFT survival regression model.
#' @rdname spark.survreg
#' @seealso survival: \url{https://cran.r-project.org/package=survival}
Expand All @@ -396,10 +404,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
#' }
#' @note spark.survreg since 2.0.0
setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula) {
function(data, formula, aggregationDepth = 2) {
formula <- paste(deparse(formula), collapse = "")
jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
"fit", formula, data@sdf)
"fit", formula, data@sdf, as.integer(aggregationDepth))
new("AFTSurvivalRegressionModel", jobj = jobj)
})

Expand Down
10 changes: 9 additions & 1 deletion R/pkg/inst/tests/testthat/test_mllib_classification.R
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,15 @@ test_that("spark.logit", {
df <- createDataFrame(data)
model <- spark.logit(df, label ~ feature)
prediction <- collect(select(predict(model, df), "prediction"))
expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0"))
expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))

# Test prediction with weightCol
weight <- c(2.0, 2.0, 2.0, 1.0, 1.0)
data2 <- as.data.frame(cbind(label, feature, weight))
df2 <- createDataFrame(data2)
model2 <- spark.logit(df2, label ~ feature, weightCol = "weight")
prediction2 <- collect(select(predict(model2, df2), "prediction"))
expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0"))
})

test_that("spark.mlp", {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
}


def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = {
def fit(
formula: String,
data: DataFrame,
aggregationDepth: Int): AFTSurvivalRegressionWrapper = {

val (rewritedFormula, censorCol) = formulaRewrite(formula)

Expand All @@ -100,6 +103,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
.setCensorCol(censorCol)
.setFitIntercept(rFormula.hasIntercept)
.setFeaturesCol(rFormula.getFeaturesCol)
.setAggregationDepth(aggregationDepth)

val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, aft))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,11 @@ private[r] object GeneralizedLinearRegressionWrapper
.setFitIntercept(rFormula.hasIntercept)
.setTol(tol)
.setMaxIter(maxIter)
.setWeightCol(weightCol)
.setRegParam(regParam)
.setFeaturesCol(rFormula.getFeaturesCol)

if (weightCol != null) glr.setWeightCol(weightCol)

val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, glr))
.fit(data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ private[r] object IsotonicRegressionWrapper
val isotonicRegression = new IsotonicRegression()
.setIsotonic(isotonic)
.setFeatureIndex(featureIndex)
.setWeightCol(weightCol)
.setFeaturesCol(rFormula.getFeaturesCol)

if (weightCol != null) isotonicRegression.setWeightCol(weightCol)

val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, isotonicRegression))
.fit(data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ private[r] object LogisticRegressionWrapper
family: String,
standardization: Boolean,
thresholds: Array[Double],
weightCol: String
weightCol: String,
aggregationDepth: Int
): LogisticRegressionWrapper = {

val rFormula = new RFormula()
Expand All @@ -119,17 +120,19 @@ private[r] object LogisticRegressionWrapper
.setFitIntercept(fitIntercept)
.setFamily(family)
.setStandardization(standardization)
.setWeightCol(weightCol)
.setFeaturesCol(rFormula.getFeaturesCol)
.setLabelCol(rFormula.getLabelCol)
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
.setAggregationDepth(aggregationDepth)

if (thresholds.length > 1) {
lr.setThresholds(thresholds)
} else {
lr.setThreshold(thresholds(0))
}

if (weightCol != null) lr.setWeightCol(weightCol)

val idxToStr = new IndexToString()
.setInputCol(PREDICTED_LABEL_INDEX_COL)
.setOutputCol(PREDICTED_LABEL_COL)
Expand Down