[SPARK-20076][ML][PySpark] Add Python interface for ml.stats.Correlation #17494
Changes from all commits: e129e06, a684ac8, fbcc1fe, 5d9d70f, fd76901, 5d04326, 601d9eb
`@@ -71,6 +71,67 @@ def test(dataset, featuresCol, labelCol):`

```python
        return _java2py(sc, javaTestObj.test(*args))


class Correlation(object):
    """
    .. note:: Experimental

    Compute the correlation matrix for the input dataset of Vectors using the specified method.
    Methods currently supported: `pearson` (default), `spearman`.
```
**Contributor:** The Scala documentation had a warning suggesting caching when using Spearman; would it make sense to copy this warning over as well?

**Member (Author):** Sounds good. Fixed.
```python
    .. note:: For Spearman, a rank correlation, we need to create an RDD[Double] for each column
```
**Contributor:** Sorry, I picked up that the doc gen will fail here: there need to be 2 spaces before the start of each subsequent line.

**Member (Author):** ah. ok. fixed. see if this time it's ok.
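To make the reviewer's point concrete: in reStructuredText, continuation lines of a `.. note::` directive must be indented relative to the directive text, otherwise Sphinx treats them as a new paragraph and the doc build fails. A minimal sketch of the fix being discussed (two-space continuation indent):

```rst
.. note:: For Spearman, a rank correlation, we need to create an RDD[Double] for each column
  and sort it in order to retrieve the ranks and then join the columns back into an
  RDD[Vector], which is fairly costly.
```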
```python
      and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
      which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'`
      to avoid recomputing the common lineage.

    :param dataset:
        A dataset or a dataframe.
    :param column:
        The name of the column of vectors for which the correlation coefficient needs
        to be computed. This must be a column of the dataset, and it must contain
        Vector objects.
    :param method:
        String specifying the method to use for computing correlation.
        Supported: `pearson` (default), `spearman`.
    :return:
        A dataframe that contains the correlation matrix of the column of vectors. This
        dataframe contains a single row and a single column of name
        '$METHODNAME($COLUMN)'.

    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.stat import Correlation
    >>> dataset = [[Vectors.dense([1, 0, 0, -2])],
    ...            [Vectors.dense([4, 5, 0, 3])],
    ...            [Vectors.dense([6, 7, 0, 8])],
    ...            [Vectors.dense([9, 0, 0, 1])]]
    >>> dataset = spark.createDataFrame(dataset, ['features'])
    >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0]
    >>> print(str(pearsonCorr).replace('nan', 'NaN'))
    DenseMatrix([[ 1.        ,  0.0556...,         NaN,  0.4004...],
                 [ 0.0556...,  1.        ,         NaN,  0.9135...],
                 [        NaN,         NaN,  1.        ,         NaN],
                 [ 0.4004...,  0.9135...,         NaN,  1.        ]])
    >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0]
    >>> print(str(spearmanCorr).replace('nan', 'NaN'))
    DenseMatrix([[ 1.        ,  0.1054...,         NaN,  0.4       ],
                 [ 0.1054...,  1.        ,         NaN,  0.9486...],
                 [        NaN,         NaN,  1.        ,         NaN],
                 [ 0.4       ,  0.9486...,         NaN,  1.        ]])

    .. versionadded:: 2.2.0
    """
    @staticmethod
    @since("2.2.0")
    def corr(dataset, column, method="pearson"):
        """
        Compute the correlation matrix with specified method using dataset.
        """
        sc = SparkContext._active_spark_context
        javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation
        args = [_py2java(sc, arg) for arg in (dataset, column, method)]
        return _java2py(sc, javaCorrObj.corr(*args))


if __name__ == "__main__":
    import doctest
    import pyspark.ml.stat
```
**Member:** While we're here: below it says "cache the input RDD", but that should be "the input Dataset".

**Author:** OK. Fixed it.

**Member:** Also, since we are here, there is a reference to the input RDD up above in the docstring as well.

**Author:** oh, right. fixed. :-)
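For readers without a Spark session handy, the numbers in the doctest above can be sanity-checked in plain Python. This is a hedged sketch, not Spark's implementation: `pearson`, `spearman`, and `_rank` are illustrative helpers defined here, computing column-wise Pearson correlation and Spearman as Pearson on average ranks, with NaN for a zero-variance column (the all-zero third column in the example dataset).

```python
import math

def _rank(xs):
    # Average ranks (1-based); tied values share the mean of their positions,
    # which is the tie-handling convention for Spearman's rank correlation.
    order = sorted(range(len(xs)), key=lambda i: xs[i])
    ranks = [0.0] * len(xs)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and xs[order[j + 1]] == xs[order[i]]:
            j += 1
        avg = (i + j) / 2.0 + 1.0  # mean of the 1-based positions i..j
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks

def pearson(x, y):
    # Sample Pearson correlation; NaN when either column is constant,
    # matching the NaN entries in the DenseMatrix doctest output.
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    num = sum((a - mx) * (b - my) for a, b in zip(x, y))
    den = math.sqrt(sum((a - mx) ** 2 for a in x) *
                    sum((b - my) ** 2 for b in y))
    return num / den if den else float("nan")

def spearman(x, y):
    # Spearman = Pearson applied to the ranks of each column.
    return pearson(_rank(x), _rank(y))

# Same four rows as the doctest's dense vectors, transposed into columns.
rows = [[1, 0, 0, -2], [4, 5, 0, 3], [6, 7, 0, 8], [9, 0, 0, 1]]
cols = list(zip(*rows))

print(round(pearson(cols[0], cols[1]), 4))   # 0.0556
print(round(spearman(cols[0], cols[1]), 4))  # 0.1054
print(pearson(cols[0], cols[2]))             # nan (third column is constant)
```

The values agree with the `0.0556...` and `0.1054...` entries in the doctest, which is a useful cross-check that the doctest output was transcribed correctly.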