@@ -68,14 +68,14 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
6868# ' @examples
6969# ' \dontrun{
7070# ' sparkR.session()
71- # ' data(iris )
72- # ' df <- createDataFrame(iris )
73- # ' model <- spark.glm(df, Sepal_Length ~ Sepal_Width , family = "gaussian")
71+ # ' t <- as. data.frame(Titanic )
72+ # ' df <- createDataFrame(t )
73+ # ' model <- spark.glm(df, Freq ~ Sex + Age , family = "gaussian")
7474# ' summary(model)
7575# '
7676# ' # fitted values on training data
7777# ' fitted <- predict(model, df)
78- # ' head(select(fitted, "Sepal_Length ", "prediction"))
78+ # ' head(select(fitted, "Freq ", "prediction"))
7979# '
8080# ' # save fitted model to input path
8181# ' path <- "path/to/model"
@@ -102,14 +102,16 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
102102 }
103103
104104 formula <- paste(deparse(formula ), collapse = " " )
105- if (is.null(weightCol )) {
106- weightCol <- " "
105+ if (! is.null(weightCol ) && weightCol == " " ) {
106+ weightCol <- NULL
107+ } else if (! is.null(weightCol )) {
108+ weightCol <- as.character(weightCol )
107109 }
108110
109111 # For known families, Gamma is upper-cased
110112 jobj <- callJStatic(" org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper" ,
111113 " fit" , formula , data @ sdf , tolower(family $ family ), family $ link ,
112- tol , as.integer(maxIter ), as.character( weightCol ) , regParam )
114+ tol , as.integer(maxIter ), weightCol , regParam )
113115 new(" GeneralizedLinearRegressionModel" , jobj = jobj )
114116 })
115117
@@ -135,9 +137,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
135137# ' @examples
136138# ' \dontrun{
137139# ' sparkR.session()
138- # ' data(iris )
139- # ' df <- createDataFrame(iris )
140- # ' model <- glm(Sepal_Length ~ Sepal_Width , df, family = "gaussian")
140+ # ' t <- as. data.frame(Titanic )
141+ # ' df <- createDataFrame(t )
142+ # ' model <- glm(Freq ~ Sex + Age , df, family = "gaussian")
141143# ' summary(model)
142144# ' }
143145# ' @note glm since 1.5.0
@@ -305,13 +307,15 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"
305307 function (data , formula , isotonic = TRUE , featureIndex = 0 , weightCol = NULL ) {
306308 formula <- paste(deparse(formula ), collapse = " " )
307309
308- if (is.null(weightCol )) {
309- weightCol <- " "
310+ if (! is.null(weightCol ) && weightCol == " " ) {
311+ weightCol <- NULL
312+ } else if (! is.null(weightCol )) {
313+ weightCol <- as.character(weightCol )
310314 }
311315
312316 jobj <- callJStatic(" org.apache.spark.ml.r.IsotonicRegressionWrapper" , " fit" ,
313317 data @ sdf , formula , as.logical(isotonic ), as.integer(featureIndex ),
314- as.character( weightCol ) )
318+ weightCol )
315319 new(" IsotonicRegressionModel" , jobj = jobj )
316320 })
317321
@@ -372,6 +376,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
372376# ' @param formula a symbolic description of the model to be fitted. Currently only a few formula
373377# ' operators are supported, including '~', ':', '+', and '-'.
374378# ' Note that operator '.' is not supported currently.
379+ # ' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
380+ # ' or the number of partitions are large, this param could be adjusted to a larger size.
381+ # ' This is an expert parameter. Default value should be good for most cases.
382+ # ' @param ... additional arguments passed to the method.
375383# ' @return \code{spark.survreg} returns a fitted AFT survival regression model.
376384# ' @rdname spark.survreg
377385# ' @seealso survival: \url{https://cran.r-project.org/package=survival}
@@ -396,10 +404,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
396404# ' }
397405# ' @note spark.survreg since 2.0.0
398406setMethod ("spark.survreg ", signature(data = "SparkDataFrame", formula = "formula"),
399- function (data , formula ) {
407+ function (data , formula , aggregationDepth = 2 ) {
400408 formula <- paste(deparse(formula ), collapse = " " )
401409 jobj <- callJStatic(" org.apache.spark.ml.r.AFTSurvivalRegressionWrapper" ,
402- " fit" , formula , data @ sdf )
410+ " fit" , formula , data @ sdf , as.integer( aggregationDepth ) )
403411 new(" AFTSurvivalRegressionModel" , jobj = jobj )
404412 })
405413
0 commit comments