-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-19704][ML] AFTSurvivalRegression should support numeric censorCol #17034
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,8 @@ import org.apache.spark.ml.util.TestingUtils._ | |
| import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} | ||
| import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
| import org.apache.spark.sql.{DataFrame, Row} | ||
| import org.apache.spark.sql.functions.{col, lit} | ||
| import org.apache.spark.sql.types._ | ||
|
|
||
| class AFTSurvivalRegressionSuite | ||
| extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { | ||
|
|
@@ -352,7 +354,7 @@ class AFTSurvivalRegressionSuite | |
| } | ||
| } | ||
|
|
||
| test("should support all NumericType labels") { | ||
| test("should support all NumericType labels, and not support other types") { | ||
| val aft = new AFTSurvivalRegression().setMaxIter(1) | ||
| MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression]( | ||
| aft, spark, isClassification = false) { (expected, actual) => | ||
|
|
@@ -361,6 +363,36 @@ class AFTSurvivalRegressionSuite | |
| } | ||
| } | ||
|
|
||
| test("should support all NumericType censors, and not support other types") { | ||
| val df = spark.createDataFrame(Seq( | ||
| (0, Vectors.dense(0)), | ||
| (1, Vectors.dense(1)), | ||
| (2, Vectors.dense(2)), | ||
| (3, Vectors.dense(3)), | ||
| (4, Vectors.dense(4)) | ||
| )).toDF("label", "features") | ||
| .withColumn("censor", lit(0.0)) | ||
| val aft = new AFTSurvivalRegression().setMaxIter(1) | ||
| val expected = aft.fit(df) | ||
|
|
||
| val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DecimalType(10, 0)) | ||
| types.foreach { t => | ||
| val actual = aft.fit(df.select(col("label"), col("features"), | ||
| col("censor").cast(t))) | ||
| assert(expected.intercept === actual.intercept) | ||
| assert(expected.coefficients === actual.coefficients) | ||
| } | ||
|
|
||
| val dfWithStringCensors = spark.createDataFrame(Seq( | ||
| (0, Vectors.dense(0, 2, 3), "0") | ||
| )).toDF("label", "features", "censor") | ||
| val thrown = intercept[IllegalArgumentException] { | ||
|
||
| aft.fit(dfWithStringCensors) | ||
| } | ||
| assert(thrown.getMessage.contains( | ||
| "Column censor must be of type NumericType but was actually of type StringType")) | ||
| } | ||
|
|
||
| test("numerical stability of standardization") { | ||
| val trainer = new AFTSurvivalRegression() | ||
| val model1 = trainer.fit(datasetUnivariate) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Technically I guess this could be part of
checkNumericTypessimilar to checking weight and label cols, but since it is specific to AFT this is ok.