
Commit c1c6c22

add AreaUnderCurve
1 parent 65461b2 commit c1c6c22

2 files changed: 102 additions, 0 deletions
AreaUnderCurve.scala (new file): 55 additions & 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.mllib.evaluation

// sliding() on RDD comes from an implicit conversion in mllib's RDDFunctions.
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd.RDD

/**
 * Computes the area under the curve (AUC) using the trapezoidal rule.
 */
object AreaUnderCurve {

  /** Computes the area of the trapezoid spanned by two consecutive points of the curve. */
  private def trapezoid(points: Array[(Double, Double)]): Double = {
    require(points.length == 2)
    (points(1)._1 - points(0)._1) * (points(1)._2 + points(0)._2) / 2.0
  }

  /**
   * Returns the area under the given curve.
   *
   * @param curve an RDD of ordered 2D points stored in pairs representing a curve
   */
  def of(curve: RDD[(Double, Double)]): Double = {
    curve.sliding(2).aggregate(0.0)(
      seqOp = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points),
      combOp = (_ + _)
    )
  }

  /**
   * Returns the area under the given curve.
   *
   * @param curve an iterable of ordered 2D points stored in pairs representing a curve
   */
  def of(curve: Iterable[(Double, Double)]): Double = {
    // Drop the partial window produced when the curve has fewer than two points.
    // Note: aggregate on Scala collections uses lower-case parameter names (seqop/combop).
    curve.sliding(2).map(_.toArray).filter(_.size == 2).aggregate(0.0)(
      seqop = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points),
      combop = (_ + _)
    )
  }
}
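
For reference, a minimal plain-Scala sketch of the same trapezoidal-rule computation without Spark; the object name TrapezoidSketch, the helper aucOf, and the sample points are illustrative only and are not part of this commit.

object TrapezoidSketch {
  // Each adjacent pair (x1, y1), (x2, y2) contributes a trapezoid of area
  // (x2 - x1) * (y2 + y1) / 2; windows shorter than two points are skipped.
  def aucOf(points: Seq[(Double, Double)]): Double =
    points.sliding(2).collect { case Seq((x1, y1), (x2, y2)) =>
      (x2 - x1) * (y2 + y1) / 2.0
    }.sum

  def main(args: Array[String]): Unit = {
    val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
    println(aucOf(curve)) // 0.5 + 2.0 + 1.5 = 4.0
  }
}

This mirrors the sliding-window logic above: empty and single-point curves fall through the pattern match and yield 0.0.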
AreaUnderCurveSuite.scala (new file): 47 additions & 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.mllib.evaluation

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.LocalSparkContext

class AreaUnderCurveSuite extends FunSuite with LocalSparkContext {

  test("auc computation") {
    val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
    // Trapezoid areas: 0.5 + 2.0 + 1.5 = 4.0
    val auc = 4.0
    assert(AreaUnderCurve.of(curve) === auc)
    val rddCurve = sc.parallelize(curve, 2)
    assert(AreaUnderCurve.of(rddCurve) === auc)
  }

  test("auc of an empty curve") {
    val curve = Seq.empty[(Double, Double)]
    assert(AreaUnderCurve.of(curve) === 0.0)
    val rddCurve = sc.parallelize(curve, 2)
    assert(AreaUnderCurve.of(rddCurve) === 0.0)
  }

  test("auc of a curve with a single point") {
    val curve = Seq((1.0, 1.0))
    assert(AreaUnderCurve.of(curve) === 0.0)
    val rddCurve = sc.parallelize(curve, 2)
    assert(AreaUnderCurve.of(rddCurve) === 0.0)
  }
}
