@@ -36,20 +36,30 @@ import org.apache.spark.sql.types.{StructField, StructType}
 private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {
 
   /**
-   * Centers the data with mean before scaling.
+   * Whether to center the data with mean before scaling.
    * It will build a dense output, so this does not work on sparse input
    * and will raise an exception.
+   * Default: false
    * @group param
    */
-  val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
+  val withMean: BooleanParam = new BooleanParam(this, "withMean",
+    "Whether to center data with mean")
 
   /** @group getParam */
   def getWithMean: Boolean = $(withMean)
 
   /**
-   * Scales the data to unit standard deviation.
+   * Whether to scale the data to unit standard deviation.
+   * Default: true
    * @group param
    */
-  val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation")
+  val withStd: BooleanParam = new BooleanParam(this, "withStd",
+    "Whether to scale the data to unit standard deviation")
 
   /** @group getParam */
   def getWithStd: Boolean = $(withStd)
 
+  setDefault(withMean -> false, withStd -> true)
 }
 
 /**
@@ -63,8 +73,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel]
 
   def this() = this(Identifiable.randomUID("stdScal"))
 
-  setDefault(withMean -> false, withStd -> true)
-
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
@@ -123,14 +131,6 @@ class StandardScalerModel private[ml] (
   /** Mean of the StandardScalerModel */
   val mean: Vector = scaler.mean
 
-  /** Whether to scale to unit standard deviation. */
-  @Since("1.6.0")
-  def getWithStd: Boolean = scaler.withStd
-
-  /** Whether to center data with mean. */
-  @Since("1.6.0")
-  def getWithMean: Boolean = scaler.withMean
-
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
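For context, here is a minimal usage sketch of the estimator whose params the patch documents above. The column names and toy data are illustrative, not from the patch, and it assumes the Spark 1.6-era sqlContext/DataFrame API used elsewhere in this PR:

```scala
import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors

// A toy DataFrame with a Vector column named "features".
val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.5)),
  (1, Vectors.dense(2.0, 1.5))
)).toDF("id", "features")

val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setWithMean(false) // default; true builds dense output and fails on sparse input
  .setWithStd(true)   // default

// fit() computes the column means/stds; they are stored on the resulting model.
val model = scaler.fit(df)
model.transform(df).show()
```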
@@ -116,23 +116,19 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
     assertResult(standardScaler3.transform(df3))
   }
 
-  test("StandardScaler read/write") {
-    val t = new StandardScaler()
-      .setInputCol("myInputCol")
-      .setOutputCol("myOutputCol")
-      .setWithStd(false)
-      .setWithMean(true)
-    testDefaultReadWrite(t)
-  }
-
-  test("StandardScalerModel read/write") {
-    val oldModel = new feature.StandardScalerModel(
-      Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true)
-    val instance = new StandardScalerModel("myStandardScalerModel", oldModel)
-    val newInstance = testDefaultReadWrite(instance)
-    assert(newInstance.std === instance.std)
-    assert(newInstance.mean === instance.mean)
-    assert(newInstance.getWithStd === instance.getWithStd)
-    assert(newInstance.getWithMean === instance.getWithMean)
+  test("read/write") {
+    def checkModelData(model1: StandardScalerModel, model2: StandardScalerModel): Unit = {
+      assert(model1.mean === model2.mean)
+      assert(model1.std === model2.std)
+    }
Contributor Author: We only need to check mean and std, which are part of the model; withStd and withMean are params.

+    val allParams: Map[String, Any] = Map(
+      "inputCol" -> "features",
+      "outputCol" -> "standardized_features",
+      "withMean" -> true,
+      "withStd" -> true
+    )
+    val df = sqlContext.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected")
+    val standardScaler = new StandardScaler()
+    testEstimatorAndModelReadWrite(standardScaler, df, allParams, checkModelData)
   }
Contributor Author: withStd and withMean of StandardScalerModel must be inherited from StandardScaler, so we cannot construct a StandardScalerModel directly by specifying the two variables. Here we combine the original test cases into one using testEstimatorAndModelReadWrite, which tests both the estimator and the model.

Contributor: I think this is not an ideal unit test for read/write, because model fitting shouldn't be part of it; fitting is already covered by other tests. Constructing the estimator and model directly would save some test time.
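A sketch of that alternative, reusing the construction from the removed "StandardScalerModel read/write" test above (the StandardScalerModel constructor is private[ml], so this only compiles inside the spark.ml package, where the suite already lives):

```scala
// Build the model directly from an mllib StandardScalerModel instead of fitting,
// then round-trip it through testDefaultReadWrite, as the removed test did.
val oldModel = new feature.StandardScalerModel(
  Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true)
val instance = new StandardScalerModel("myStandardScalerModel", oldModel)
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.std === instance.std)
assert(newInstance.mean === instance.mean)
```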

 }