Skip to content

Commit a5c8770

Browse files
MrBago authored and jkbradley committed
[SPARK-20040][ML][PYTHON] pyspark wrapper for ChiSquareTest
## What changes were proposed in this pull request?

A pyspark wrapper for spark.ml.stat.ChiSquareTest

## How was this patch tested?

unit tests
doctests

Author: Bago Amirbekian <[email protected]>

Closes #17421 from MrBago/chiSquareTestWrapper.
1 parent 7d432af commit a5c8770

File tree

5 files changed

+127
-12
lines changed

5 files changed

+127
-12
lines changed

dev/sparktestsupport/modules.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -431,6 +431,7 @@ def __hash__(self):
431431
"pyspark.ml.linalg.__init__",
432432
"pyspark.ml.recommendation",
433433
"pyspark.ml.regression",
434+
"pyspark.ml.stat",
434435
"pyspark.ml.tuning",
435436
"pyspark.ml.tests",
436437
],

mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -46,9 +46,9 @@ object ChiSquareTest {
4646
statistics: Vector)
4747

4848
/**
49-
* Conduct Pearson's independence test for every feature against the label across the input RDD.
50-
* For each feature, the (feature, label) pairs are converted into a contingency matrix for which
51-
* the Chi-squared statistic is computed. All label and feature values must be categorical.
49+
* Conduct Pearson's independence test for every feature against the label. For each feature, the
50+
* (feature, label) pairs are converted into a contingency matrix for which the Chi-squared
51+
* statistic is computed. All label and feature values must be categorical.
5252
*
5353
* The null hypothesis is that the occurrence of the outcomes is statistically independent.
5454
*

python/docs/pyspark.ml.rst

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,14 @@ pyspark.ml.regression module
6565
:undoc-members:
6666
:inherited-members:
6767

68+
pyspark.ml.stat module
69+
----------------------
70+
71+
.. automodule:: pyspark.ml.stat
72+
:members:
73+
:undoc-members:
74+
:inherited-members:
75+
6876
pyspark.ml.tuning module
6977
------------------------
7078

python/pyspark/ml/stat.py

Lines changed: 93 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,93 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from pyspark import since, SparkContext
19+
from pyspark.ml.common import _java2py, _py2java
20+
from pyspark.ml.wrapper import _jvm
21+
22+
23+
class ChiSquareTest(object):
    """
    .. note:: Experimental

    Pearson's independence test of every feature against the label.

    For each feature, the (feature, label) pairs are turned into a contingency
    matrix and the Chi-squared statistic is computed from it. All label and
    feature values must be categorical.

    The null hypothesis is that the occurrence of the outcomes is statistically
    independent.

    Use the static method :py:meth:`test` to run the test; its parameters are:

    :param dataset:
      DataFrame of categorical labels and categorical features.
      Real-valued features will be treated as categorical for each distinct value.
    :param featuresCol:
      Name of features column in dataset, of type `Vector` (`VectorUDT`).
    :param labelCol:
      Name of label column in dataset, of any numerical type.
    :return:
      DataFrame containing the test result for every feature against the label.
      This DataFrame will contain a single Row with the following fields:
      - `pValues: Vector`
      - `degreesOfFreedom: Array[Int]`
      - `statistics: Vector`
      Each of these fields has one value per feature.

    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.stat import ChiSquareTest
    >>> dataset = [[0, Vectors.dense([0, 0, 1])],
    ...            [0, Vectors.dense([1, 0, 1])],
    ...            [1, Vectors.dense([2, 1, 1])],
    ...            [1, Vectors.dense([3, 1, 1])]]
    >>> dataset = spark.createDataFrame(dataset, ["label", "features"])
    >>> chiSqResult = ChiSquareTest.test(dataset, 'features', 'label')
    >>> chiSqResult.select("degreesOfFreedom").collect()[0]
    Row(degreesOfFreedom=[3, 1, 0])

    .. versionadded:: 2.2.0

    """
    @staticmethod
    @since("2.2.0")
    def test(dataset, featuresCol, labelCol):
        """
        Run Pearson's independence test on `dataset`.

        The three arguments are converted to their Java counterparts and
        forwarded to the JVM-side ``org.apache.spark.ml.stat.ChiSquareTest.test``;
        the Java result is converted back to Python before being returned.

        :param dataset: DataFrame holding the label and features columns.
        :param featuresCol: Name of the features column in `dataset`.
        :param labelCol: Name of the label column in `dataset`.
        :return: The converted result of the JVM call (a DataFrame of test results).
        """
        ctx = SparkContext._active_spark_context
        java_test = _jvm().org.apache.spark.ml.stat.ChiSquareTest
        # Convert the arguments one by one, in call order, to their Java form.
        java_args = []
        for py_arg in (dataset, featuresCol, labelCol):
            java_args.append(_py2java(ctx, py_arg))
        java_result = java_test.test(*java_args)
        return _java2py(ctx, java_result)
72+
73+
74+
if __name__ == "__main__":
    import doctest
    import sys
    import pyspark.ml.stat
    from pyspark.sql import SparkSession

    # Run this module's doctests with the module namespace as their globals.
    globs = pyspark.ml.stat.__dict__.copy()
    # Start a local two-thread session and expose it to the doctests as
    # `spark` (and its SparkContext as `sc`).
    spark = SparkSession.builder \
        .master("local[2]") \
        .appName("ml.stat tests") \
        .getOrCreate()
    sc = spark.sparkContext
    globs['sc'] = sc
    globs['spark'] = spark

    try:
        failure_count, test_count = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Always shut the session down, even if the doctest run itself raises.
        spark.stop()
    if failure_count:
        # sys.exit instead of the bare `exit`, which is only injected by the
        # `site` module and is not guaranteed to exist.
        sys.exit(-1)

python/pyspark/ml/tests.py

Lines changed: 22 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -41,9 +41,7 @@
4141
import tempfile
4242
import array as pyarray
4343
import numpy as np
44-
from numpy import (
45-
abs, all, arange, array, array_equal, dot, exp, inf, mean, ones, random, tile, zeros)
46-
from numpy import sum as array_sum
44+
from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros
4745
import inspect
4846

4947
from pyspark import keyword_only, SparkContext
@@ -54,20 +52,19 @@
5452
from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
5553
from pyspark.ml.feature import *
5654
from pyspark.ml.fpm import FPGrowth, FPGrowthModel
57-
from pyspark.ml.linalg import (
58-
DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT,
59-
SparseMatrix, SparseVector, Vector, VectorUDT, Vectors, _convert_to_vector)
55+
from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \
56+
SparseMatrix, SparseVector, Vector, VectorUDT, Vectors
6057
from pyspark.ml.param import Param, Params, TypeConverters
6158
from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed
6259
from pyspark.ml.recommendation import ALS
63-
from pyspark.ml.regression import (
64-
DecisionTreeRegressor, GeneralizedLinearRegression, LinearRegression)
60+
from pyspark.ml.regression import DecisionTreeRegressor, GeneralizedLinearRegression, \
61+
LinearRegression
62+
from pyspark.ml.stat import ChiSquareTest
6563
from pyspark.ml.tuning import *
6664
from pyspark.ml.wrapper import JavaParams, JavaWrapper
6765
from pyspark.serializers import PickleSerializer
6866
from pyspark.sql import DataFrame, Row, SparkSession
6967
from pyspark.sql.functions import rand
70-
from pyspark.sql.utils import IllegalArgumentException
7168
from pyspark.storagelevel import *
7269
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
7370

@@ -1741,6 +1738,22 @@ def test_new_java_array(self):
17411738
self.assertEqual(_java2py(self.sc, java_array), [])
17421739

17431740

1741+
class ChiSquareTestTests(SparkSessionTestCase):
    """Smoke test for the Python wrapper around ChiSquareTest."""

    def test_chisquaretest(self):
        """ChiSquareTest.test returns a DataFrame exposing the three result fields."""
        data = [[0, Vectors.dense([0, 1, 2])],
                [1, Vectors.dense([1, 1, 1])],
                [2, Vectors.dense([2, 1, 0])]]
        df = self.spark.createDataFrame(data, ['label', 'feat'])
        res = ChiSquareTest.test(df, 'feat', 'label')
        # TODO: collecting the result hits the collect bug described in #17218;
        # once that is resolved, also check the values, e.g.:
        # degreesOfFreedom = res.select("degreesOfFreedom").collect()
        self.assertIsInstance(res, DataFrame)
        fieldNames = set(field.name for field in res.schema.fields)
        expectedFields = ["pValues", "degreesOfFreedom", "statistics"]
        # Assert each field individually instead of all(<generator>): this module
        # imports numpy's `all`, which does not evaluate a generator element-wise,
        # so the combined check could not fail when a field was missing.
        for field in expectedFields:
            self.assertIn(field, fieldNames)
1755+
1756+
17441757
if __name__ == "__main__":
17451758
from pyspark.ml.tests import *
17461759
if xmlrunner:

0 commit comments

Comments
 (0)