From 67b932a4a4a08002e32b892d56b2fdad0f7c6a70 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Wed, 31 Jul 2019 16:23:10 +0800 Subject: [PATCH 1/3] init --- .../org/apache/spark/ml/feature/MaxAbsScaler.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 90eceb0d61b40..c04af2657e422 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -121,13 +121,13 @@ class MaxAbsScalerModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - // TODO: this looks hack, we may have to handle sparse and dense vectors separately. - val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x)) - val reScale = udf { (vector: Vector) => - val brz = vector.asBreeze / maxAbsUnzero.asBreeze - Vectors.fromBreeze(brz) - } - dataset.withColumn($(outputCol), reScale(col($(inputCol)))) + + val scale = maxAbs.toArray.map { v => if (v == 0) 1.0 else 1 / v } + val func = StandardScalerModel.getTransformFunc( + Array.empty, scale, false, true) + val transformer = udf(func) + + dataset.withColumn($(outputCol), transformer(col($(inputCol)))) } @Since("2.0.0") From 471b46c00b080e6af0f2dd4cc9f7f236d3a01607 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Wed, 31 Jul 2019 16:37:06 +0800 Subject: [PATCH 2/3] nit --- .../scala/org/apache/spark/ml/feature/MaxAbsScaler.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index c04af2657e422..0c7db0a976ac1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -25,9 +25,8 @@ import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.stat.Statistics -import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -70,7 +69,7 @@ class MaxAbsScaler @Since("2.0.0") (@Since("2.0.0") override val uid: String) @Since("2.0.0") override def fit(dataset: Dataset[_]): MaxAbsScalerModel = { transformSchema(dataset.schema, logging = true) - val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map { + val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => OldVectors.fromML(v) } val summary = Statistics.colStats(input) From 8b923f99edb1aa85e275b22cc722e795f13166ed Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Wed, 31 Jul 2019 16:38:03 +0800 Subject: [PATCH 3/3] nit --- .../main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 0c7db0a976ac1..0f51e23323177 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -78,7 +78,7 @@ class MaxAbsScaler @Since("2.0.0") (@Since("2.0.0") override val uid: String) val n = minVals.length val maxAbs = Array.tabulate(n) { i => math.max(math.abs(minVals(i)), math.abs(maxVals(i))) } - copyValues(new MaxAbsScalerModel(uid, Vectors.dense(maxAbs)).setParent(this)) + copyValues(new MaxAbsScalerModel(uid, Vectors.dense(maxAbs).compressed).setParent(this)) } @Since("2.0.0")