diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index c305b36278e87..f8276de4f23d4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -234,11 +234,11 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { val features = selectorType match { case ChiSqSelector.KBest => chiSqTestResult - .sortBy { case (res, _) => -res.statistic } + .sortBy { case (res, _) => res.pValue } .take(numTopFeatures) case ChiSqSelector.Percentile => chiSqTestResult - .sortBy { case (res, _) => -res.statistic } + .sortBy { case (res, _) => res.pValue } .take((chiSqTestResult.length * percentile).toInt) case ChiSqSelector.FPR => chiSqTestResult diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index dfebfc87ea1d3..6af06d82d671a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -38,10 +38,10 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext ) val preFilteredData = Seq( - Vectors.dense(0.0), - Vectors.dense(6.0), Vectors.dense(8.0), - Vectors.dense(5.0) + Vectors.dense(0.0), + Vectors.dense(0.0), + Vectors.dense(8.0) ) val df = sc.parallelize(data.zip(preFilteredData)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index ec23a4aa7364d..ac702b4b7c69e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -54,10 +54,10 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2) val preFilteredData = - Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))), - LabeledPoint(1.0, Vectors.dense(Array(6.0))), - LabeledPoint(1.0, Vectors.dense(Array(8.0))), - LabeledPoint(2.0, Vectors.dense(Array(5.0)))) + Set(LabeledPoint(0.0, Vectors.dense(Array(8.0))), + LabeledPoint(1.0, Vectors.dense(Array(0.0))), + LabeledPoint(1.0, Vectors.dense(Array(0.0))), + LabeledPoint(2.0, Vectors.dense(Array(8.0)))) val model = new ChiSqSelector(1).fit(labeledDiscreteData) val filteredData = labeledDiscreteData.map { lp => LabeledPoint(lp.label, model.transform(lp.features)) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 64b21caa616ec..f1f9b6cba00e7 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2569,9 +2569,9 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja >>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures") >>> model = selector.fit(df) >>> model.transform(df).head().selectedFeatures - DenseVector([1.0]) + DenseVector([18.0]) >>> model.selectedFeatures - [3] + [2] >>> chiSqSelectorPath = temp_path + "/chi-sq-selector" >>> selector.save(chiSqSelectorPath) >>> loadedSelector = ChiSqSelector.load(chiSqSelectorPath) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 4aea81840a162..50ef7c7901c2c 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -288,15 +288,15 @@ class ChiSqSelector(object): ... ] >>> model = ChiSqSelector().setNumTopFeatures(1).fit(sc.parallelize(data)) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) - SparseVector(1, {0: 6.0}) + SparseVector(1, {}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) - DenseVector([5.0]) + DenseVector([8.0]) >>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit( ... sc.parallelize(data)) >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0})) - SparseVector(1, {0: 6.0}) + SparseVector(1, {}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) - DenseVector([5.0]) + DenseVector([8.0]) >>> data = [ ... LabeledPoint(0.0, SparseVector(4, {0: 8.0, 1: 7.0})), ... LabeledPoint(1.0, SparseVector(4, {1: 9.0, 2: 6.0, 3: 4.0})),