@@ -38,7 +38,7 @@ object Correlation {

/**
* :: Experimental ::
- * Compute the correlation matrix for the input RDD of Vectors using the specified method.
+ * Compute the correlation matrix for the input Dataset of Vectors using the specified method.
* Methods currently supported: `pearson` (default), `spearman`.
*
* @param dataset A dataset or a dataframe
@@ -56,14 +56,14 @@ object Correlation {
* Here is how to access the correlation coefficient:
* {{{
* val data: Dataset[Vector] = ...
- * val Row(coeff: Matrix) = Statistics.corr(data, "value").head
+ * val Row(coeff: Matrix) = Correlation.corr(data, "value").head
* // coeff now contains the Pearson correlation matrix.
* }}}
*
Contributor: While we're here - below it says "cache the input RDD" but that should be "the input Dataset".

Member Author: OK. Fixed it.

Contributor: Also, since we're here, there is a reference to the input RDD up above in the docstring.

Member Author: oh, right. fixed. :-)

* @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column
* and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
- * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to
- * avoid recomputing the common lineage.
+ * which is fairly costly. Cache the input Dataset before calling corr with `method = "spearman"`
+ * to avoid recomputing the common lineage.
*/
@Since("2.2.0")
def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
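The caching advice in the note above carries over to the Python API added in this change. Below is a minimal PySpark sketch of the pattern, not taken from the patch itself: it assumes an active SparkSession bound to `spark`, and the input data is purely illustrative.

from pyspark.ml.linalg import Vectors
from pyspark.ml.stat import Correlation

# Illustrative input: a DataFrame with a single vector column named 'features'.
df = spark.createDataFrame(
    [(Vectors.dense([1.0, 2.0]),),
     (Vectors.dense([3.0, 1.0]),),
     (Vectors.dense([5.0, 4.0]),)],
    ["features"])

# Spearman ranks each column separately and joins the ranked columns back
# together, so cache the input first to avoid recomputing its lineage.
df.cache()
spearmanCorr = Correlation.corr(df, "features", method="spearman")
spearmanCorr.show(truncate=False)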
python/pyspark/ml/stat.py (61 additions, 0 deletions)
@@ -71,6 +71,67 @@ def test(dataset, featuresCol, labelCol):
        return _java2py(sc, javaTestObj.test(*args))


class Correlation(object):
    """
    .. note:: Experimental

    Compute the correlation matrix for the input dataset of Vectors using the specified method.
    Methods currently supported: `pearson` (default), `spearman`.
Contributor: So the Scala documentation had a warning suggesting caching when using Spearman - would it make sense to copy this warning over as well?

Member Author: Sounds good. Fixed.


    .. note:: For Spearman, a rank correlation, we need to create an RDD[Double] for each column
Contributor: Sorry, I picked up that the doc gen will fail here - there needs to be 2 spaces before the start of each subsequent line, like this:

.. note:: For Spearman, a rank correlation, we need to create an RDD[Double] for each column
  and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
  which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'`
  to avoid recomputing the common lineage.

Member Author: ah. ok. fixed. see if this time it's ok.

      and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector],
      which is fairly costly. Cache the input Dataset before calling corr with `method = 'spearman'`
      to avoid recomputing the common lineage.

    :param dataset:
      A dataset or a dataframe.
    :param column:
      The name of the column of vectors for which the correlation coefficient needs
      to be computed. This must be a column of the dataset, and it must contain
      Vector objects.
    :param method:
      String specifying the method to use for computing correlation.
      Supported: `pearson` (default), `spearman`.
    :return:
      A dataframe that contains the correlation matrix of the column of vectors. This
      dataframe contains a single row and a single column of name
      '$METHODNAME($COLUMN)'.

    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.stat import Correlation
    >>> dataset = [[Vectors.dense([1, 0, 0, -2])],
    ...            [Vectors.dense([4, 5, 0, 3])],
    ...            [Vectors.dense([6, 7, 0, 8])],
    ...            [Vectors.dense([9, 0, 0, 1])]]
    >>> dataset = spark.createDataFrame(dataset, ['features'])
    >>> pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()[0][0]
    >>> print(str(pearsonCorr).replace('nan', 'NaN'))
    DenseMatrix([[ 1.        ,  0.0556...,         NaN,  0.4004...],
                 [ 0.0556...,  1.        ,         NaN,  0.9135...],
                 [        NaN,         NaN,  1.        ,         NaN],
                 [ 0.4004...,  0.9135...,         NaN,  1.        ]])
    >>> spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()[0][0]
    >>> print(str(spearmanCorr).replace('nan', 'NaN'))
    DenseMatrix([[ 1.        ,  0.1054...,         NaN,  0.4       ],
                 [ 0.1054...,  1.        ,         NaN,  0.9486... ],
                 [        NaN,         NaN,  1.        ,         NaN],
                 [ 0.4       ,  0.9486... ,         NaN,  1.        ]])

    .. versionadded:: 2.2.0

    """
    @staticmethod
    @since("2.2.0")
    def corr(dataset, column, method="pearson"):
        """
        Compute the correlation matrix with specified method using dataset.
        """
        sc = SparkContext._active_spark_context
        javaCorrObj = _jvm().org.apache.spark.ml.stat.Correlation
        args = [_py2java(sc, arg) for arg in (dataset, column, method)]
        return _java2py(sc, javaCorrObj.corr(*args))


if __name__ == "__main__":
    import doctest
    import pyspark.ml.stat
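For reference, the DataFrame returned by corr has a single row with a single Matrix column, so retrieving the result on the driver is short. A small usage sketch, reusing the illustrative `df` from the earlier snippet; `head()` and `DenseMatrix.toArray()` are standard PySpark calls:

result = Correlation.corr(df, "features")   # Pearson by default
matrix = result.head()[0]                   # the single row's single column: a DenseMatrix
print(matrix.toArray())                     # materialize as a local numpy ndarray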