@@ -200,3 +200,52 @@ test_that("naiveBayes", {
200200 expect_equal(as.character(predict(m , t1 [1 , ])), " Yes" )
201201 }
202202})
203+
204+ test_that(" survreg" , {
205+ # R code to reproduce the result.
206+ #
207+ # ' rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
208+ # ' x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
209+ # ' library(survival)
210+ # ' model <- survreg(Surv(time, status) ~ x + sex, rData)
211+ # ' summary(model)
212+ # ' predict(model, data)
213+ #
214+ # -- output of 'summary(model)'
215+ #
216+ # Value Std. Error z p
217+ # (Intercept) 1.315 0.270 4.88 1.07e-06
218+ # x -0.190 0.173 -1.10 2.72e-01
219+ # sex -0.253 0.329 -0.77 4.42e-01
220+ # Log(scale) -1.160 0.396 -2.93 3.41e-03
221+ #
222+ # -- output of 'predict(model, data)'
223+ #
224+ # 1 2 3 4 5 6 7
225+ # 3.724591 2.545368 3.079035 3.079035 2.390146 2.891269 2.891269
226+ #
227+ data <- list (list (4 , 1 , 0 , 0 ), list (3 , 1 , 2 , 0 ), list (1 , 1 , 1 , 0 ),
228+ list (1 , 0 , 1 , 0 ), list (2 , 1 , 1 , 1 ), list (2 , 1 , 0 , 1 ), list (3 , 0 , 0 , 1 ))
229+ df <- createDataFrame(sqlContext , data , c(" time" , " status" , " x" , " sex" ))
230+ model <- survreg(Surv(time , status ) ~ x + sex , df )
231+ stats <- summary(model )
232+ coefs <- as.vector(stats $ coefficients [, 1 ])
233+ rCoefs <- c(1.3149571 , - 0.1903409 , - 0.2532618 , - 1.1599800 )
234+ expect_equal(coefs , rCoefs , tolerance = 1e-4 )
235+ expect_true(all(
236+ rownames(stats $ coefficients ) ==
237+ c(" (Intercept)" , " x" , " sex" , " Log(scale)" )))
238+ p <- collect(select(predict(model , df ), " prediction" ))
239+ expect_equal(p $ prediction , c(3.724591 , 2.545368 , 3.079035 , 3.079035 ,
240+ 2.390146 , 2.891269 , 2.891269 ), tolerance = 1e-4 )
241+
242+ # Test survival::survreg
243+ if (requireNamespace(" survival" , quietly = TRUE )) {
244+ rData <- list (time = c(4 , 3 , 1 , 1 , 2 , 2 , 3 ), status = c(1 , 1 , 1 , 0 , 1 , 1 , 0 ),
245+ x = c(0 , 2 , 1 , 1 , 1 , 0 , 0 ), sex = c(0 , 0 , 0 , 0 , 1 , 1 , 1 ))
246+ expect_that(
247+ model <- survival :: survreg(formula = survival :: Surv(time , status ) ~ x + sex , data = rData ),
248+ not(throws_error()))
249+ expect_equal(predict(model , rData )[[1 ]], 3.724591 , tolerance = 1e-4 )
250+ }
251+ })
0 commit comments