Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
fitting: Boolean): StructType = {
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
if (fitting) {
SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
SchemaUtils.checkNumericType(schema, $(censorCol))
SchemaUtils.checkNumericType(schema, $(labelCol))
}
if (hasQuantilesCol) {
Expand Down Expand Up @@ -200,8 +200,8 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
* and put it in an RDD with strong types.
*/
protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = {
dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
.rdd.map {
dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType),
col($(censorCol)).cast(DoubleType)).rdd.map {
case Row(features: Vector, label: Double, censor: Double) =>
AFTPoint(features, label, censor)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
*/
final val isotonic: BooleanParam =
new BooleanParam(this, "isotonic",
"whether the output sequence should be isotonic/increasing (true) or" +
"whether the output sequence should be isotonic/increasing (true) or " +
"antitonic/decreasing (false)")

/** @group getParam */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types._

class AFTSurvivalRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
Expand Down Expand Up @@ -352,7 +354,7 @@ class AFTSurvivalRegressionSuite
}
}

test("should support all NumericType labels") {
test("should support all NumericType labels, and not support other types") {
val aft = new AFTSurvivalRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
aft, spark, isClassification = false) { (expected, actual) =>
Expand All @@ -361,6 +363,36 @@ class AFTSurvivalRegressionSuite
}
}

test("should support all NumericType censors, and not support other types") {
val df = spark.createDataFrame(Seq(
(0, Vectors.dense(0)),
(1, Vectors.dense(1)),
(2, Vectors.dense(2)),
(3, Vectors.dense(3)),
(4, Vectors.dense(4))
)).toDF("label", "features")
.withColumn("censor", lit(0.0))
val aft = new AFTSurvivalRegression().setMaxIter(1)
val expected = aft.fit(df)

val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DecimalType(10, 0))
types.foreach { t =>
val actual = aft.fit(df.select(col("label"), col("features"),
col("censor").cast(t)))
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}

val dfWithStringCensors = spark.createDataFrame(Seq(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, I guess this could be part of checkNumericTypes, similar to the checks for the weight and label columns, but since it is specific to AFT, this is OK.

(0, Vectors.dense(0, 2, 3), "0")
)).toDF("label", "features", "censor")
val thrown = intercept[IllegalArgumentException] {
Copy link
Contributor

@imatiach-msft imatiach-msft Mar 2, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you can wrap this in a withClue("Column censor must be of type NumericType but was actually of type StringType") {
...
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code follows the implementation in MLTestingUtils.checkNumericTypes, so I would prefer not to change it.

aft.fit(dfWithStringCensors)
}
assert(thrown.getMessage.contains(
"Column censor must be of type NumericType but was actually of type StringType"))
}

test("numerical stability of standardization") {
val trainer = new AFTSurvivalRegression()
val model1 = trainer.fit(datasetUnivariate)
Expand Down