Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
Expand Down Expand Up @@ -104,6 +104,27 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
@Since("1.5.0")
def setLabelCol(value: String): this.type = set(labelCol, value)

/**
 * Whether to index the label column regardless of whether it is numeric or string.
 * By default the label is indexed only when it is a string column.
 * When the formula is used with a classification algorithm, set this param to true
 * to force indexing of the label even if it is numeric.
 * Must not be true when the label column already exists in the input schema;
 * `transformSchema` rejects that combination.
 * Default: false.
 * @group param
 */
@Since("2.1.0")
val forceIndexLabel: BooleanParam = new BooleanParam(this, "forceIndexLabel",
"Force to index label whether it is numeric or string")
// Forced indexing is opt-in: without it only string labels are indexed.
setDefault(forceIndexLabel -> false)

/**
 * Returns the value of [[forceIndexLabel]], i.e. whether the label is indexed
 * even when it is not a string column. The param's default is false.
 * @group getParam
 */
@Since("2.1.0")
def getForceIndexLabel: Boolean = $(forceIndexLabel)

/**
 * Sets [[forceIndexLabel]].
 * @param value true to index the label even when it is numeric; false to index
 *              it only when it is a string column
 * @return the result of `set`, typed as `this.type` so calls can be chained
 * @group setParam
 */
@Since("2.1.0")
def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, value)

/** Whether the formula specifies fitting an intercept. */
private[ml] def hasIntercept: Boolean = {
require(isDefined(formula), "Formula must be defined first.")
Expand Down Expand Up @@ -167,8 +188,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)
encoderStages += new ColumnPruner(tempColumns.toSet)

if (dataset.schema.fieldNames.contains(resolvedFormula.label) &&
dataset.schema(resolvedFormula.label).dataType == StringType) {
if ((dataset.schema.fieldNames.contains(resolvedFormula.label) &&
dataset.schema(resolvedFormula.label).dataType == StringType) || $(forceIndexLabel)) {
encoderStages += new StringIndexer()
.setInputCol(resolvedFormula.label)
.setOutputCol($(labelCol))
Expand All @@ -181,6 +202,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
@Since("1.5.0")
// optimistic schema; does not contain any ML attributes
override def transformSchema(schema: StructType): StructType = {
require(!hasLabelCol(schema) || !$(forceIndexLabel),
"If label column already exists, forceIndexLabel can not be set with true.")
if (hasLabelCol(schema)) {
StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
}

test("label column already exists") {
test("label column already exists and forceIndexLabel was set with false") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
val original = Seq((0, 1.0), (2, 2.0)).toDF("x", "y")
val model = formula.fit(original)
Expand All @@ -66,6 +66,14 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(resultSchema.toString == model.transform(original).schema.toString)
}

test("label column already exists but forceIndexLabel was set with true") {
  // "y" is both the formula's response and the configured label column name,
  // so the label column is already present in the input data.
  val input = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y")
  val rFormula = new RFormula()
    .setFormula("y ~ x")
    .setLabelCol("y")
    .setForceIndexLabel(true)

  // Forcing label indexing on top of an existing label column must be rejected.
  intercept[IllegalArgumentException] {
    rFormula.fit(input)
  }
}

test("label column already exists but is not numeric type") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("y")
val original = Seq((0, true), (2, false)).toDF("x", "y")
Expand Down Expand Up @@ -137,6 +145,23 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(result.collect() === expected.collect())
}

test("force to index label even it is numeric type") {
  // Columns: numeric response `id`, string term `a`, integer term `b`.
  val inputRows = Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5))
  val input = spark.createDataFrame(inputRows).toDF("id", "a", "b")

  val rFormula = new RFormula()
    .setFormula("id ~ a + b")
    .setForceIndexLabel(true)
  val actual = rFormula.fit(input).transform(input)

  // With forced indexing the numeric label is re-encoded: in the expected rows
  // id 1.0 maps to label 0.0 and id 0.0 maps to label 1.0.
  val expectedRows = Seq(
    (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0),
    (1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
    (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0),
    (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0))
  val wanted = spark.createDataFrame(expectedRows).toDF("id", "a", "b", "features", "label")

  assert(actual.collect() === wanted.collect())
}

test("attribute generation") {
val formula = new RFormula().setFormula("id ~ a + b")
val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
Expand Down