Commit 0bfacd5
[SPARK-6090][MLLIB] add a basic BinaryClassificationMetrics to PySpark/MLlib
A simple wrapper around the Scala implementation. `DataFrame` is used for serialization/deserialization. Methods that return `RDD`s are not supported in this PR.

davies: If we recognize Scala's `Product`s in Py4J, we can easily add wrappers for Scala methods that return `RDD[(Double, Double)]`. Is it easy to register a serializer for `Product` in PySpark?

Author: Xiangrui Meng <[email protected]>

Closes #4863 from mengxr/SPARK-6090 and squashes the following commits:

009a3a3 [Xiangrui Meng] provide schema
dcddab5 [Xiangrui Meng] add a basic BinaryClassificationMetrics to PySpark/MLlib
1 parent c9cfba0 commit 0bfacd5
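
For orientation, here is a minimal usage sketch of the wrapper this commit adds, mirroring the doctest in the new evaluation.py below. It assumes a live SparkContext named `sc`; the printed values are the approximate figures from that doctest, not independently verified here.

from pyspark.mllib.evaluation import BinaryClassificationMetrics

# (score, label) pairs; labels are 0.0/1.0 doubles
scoreAndLabels = sc.parallelize(
    [(0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
metrics = BinaryClassificationMetrics(scoreAndLabels)
print(metrics.areaUnderROC())  # ~0.70 per the doctest
print(metrics.areaUnderPR())   # ~0.83 per the doctest
metrics.unpersist()            # release intermediate RDDs held on the JVM side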

File tree

4 files changed
+99 −0 lines changed

mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala

Lines changed: 8 additions & 0 deletions

@@ -22,6 +22,7 @@ import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.evaluation.binary._
 import org.apache.spark.rdd.{RDD, UnionRDD}
+import org.apache.spark.sql.DataFrame
 
 /**
  * :: Experimental ::
@@ -53,6 +54,13 @@ class BinaryClassificationMetrics(
    */
   def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0)
 
+  /**
+   * An auxiliary constructor taking a DataFrame.
+   * @param scoreAndLabels a DataFrame with two double columns: score and label
+   */
+  private[mllib] def this(scoreAndLabels: DataFrame) =
+    this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1))))
+
   /** Unpersist intermediate RDDs used in the computation. */
   def unpersist() {
     cumulativeCounts.unpersist()
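
The new constructor is the JVM-side half of the Python/Scala handoff: PySpark builds a two-column DataFrame and passes its underlying Java handle, and the constructor above maps each Row back to a (score, label) pair, so no custom Py4J serializer is needed. A sketch of the Python-side half, condensed from the __init__ added in evaluation.py below (assumes a live SparkContext `sc` and the Spark 1.3-era SQLContext API):

from pyspark.sql import SQLContext
from pyspark.sql.types import StructType, StructField, DoubleType

scoreAndLabels = sc.parallelize([(0.1, 0.0), (0.8, 1.0)])
schema = StructType([StructField("score", DoubleType(), nullable=False),
                     StructField("label", DoubleType(), nullable=False)])
df = SQLContext(sc).createDataFrame(scoreAndLabels, schema=schema)
# df._jdf is the JVM DataFrame that feeds the private[mllib] constructor above
java_metrics = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics(df._jdf)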

python/docs/pyspark.mllib.rst

Lines changed: 7 additions & 0 deletions

@@ -16,6 +16,13 @@ pyspark.mllib.clustering module
     :members:
     :undoc-members:
 
+pyspark.mllib.evaluation module
+-------------------------------
+
+.. automodule:: pyspark.mllib.evaluation
+    :members:
+    :undoc-members:
+
 pyspark.mllib.feature module
 -------------------------------

python/pyspark/mllib/evaluation.py

Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.mllib.common import JavaModelWrapper
from pyspark.sql import SQLContext
from pyspark.sql.types import StructField, StructType, DoubleType


class BinaryClassificationMetrics(JavaModelWrapper):
    """
    Evaluator for binary classification.

    >>> scoreAndLabels = sc.parallelize([
    ...     (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
    >>> metrics = BinaryClassificationMetrics(scoreAndLabels)
    >>> metrics.areaUnderROC()
    0.70...
    >>> metrics.areaUnderPR()
    0.83...
    >>> metrics.unpersist()
    """

    def __init__(self, scoreAndLabels):
        """
        :param scoreAndLabels: an RDD of (score, label) pairs
        """
        sc = scoreAndLabels.ctx
        sql_ctx = SQLContext(sc)
        df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
            StructField("score", DoubleType(), nullable=False),
            StructField("label", DoubleType(), nullable=False)]))
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
        java_model = java_class(df._jdf)
        super(BinaryClassificationMetrics, self).__init__(java_model)

    def areaUnderROC(self):
        """
        Computes the area under the receiver operating characteristic
        (ROC) curve.
        """
        return self.call("areaUnderROC")

    def areaUnderPR(self):
        """
        Computes the area under the precision-recall curve.
        """
        return self.call("areaUnderPR")

    def unpersist(self):
        """
        Unpersists intermediate RDDs used in the computation.
        """
        self.call("unpersist")


def _test():
    import doctest
    from pyspark import SparkContext
    import pyspark.mllib.evaluation
    globs = pyspark.mllib.evaluation.__dict__.copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest')
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()
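
As a rough cross-check of the 0.70... value in the doctest, areaUnderROC can be approximated in plain Python with the rank-statistic form of AUC (a random positive outranking a random negative, ties counted as half). This is an assumption on my part about equivalence with Spark's threshold-based curve; on this particular data set it agrees with the doctest value.

pairs = [(0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)]
pos = [s for s, l in pairs if l == 1.0]   # scores of positive examples
neg = [s for s, l in pairs if l == 0.0]   # scores of negative examples

# Fraction of positive/negative pairs where the positive scores higher (ties count 1/2).
auc = sum(1.0 if p > n else 0.5 if p == n else 0.0
          for p in pos for n in neg) / (len(pos) * len(neg))
print(auc)  # 0.7083..., consistent with the 0.70... shown in the doctest above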

python/run-tests

Lines changed: 1 addition & 0 deletions

@@ -75,6 +75,7 @@ function run_mllib_tests() {
     echo "Run mllib tests ..."
     run_test "pyspark/mllib/classification.py"
     run_test "pyspark/mllib/clustering.py"
+    run_test "pyspark/mllib/evaluation.py"
     run_test "pyspark/mllib/feature.py"
     run_test "pyspark/mllib/linalg.py"
     run_test "pyspark/mllib/rand.py"

0 commit comments