Skip to content

Commit eeda560

Browse files
George
authored and committed
Fixing SparseVector argmax function to ignore zero values while doing the calculation.
1 parent 4526acc commit eeda560

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -728,7 +728,7 @@ class SparseVector(
728728
var maxValue = values(0)
729729

730730
foreachActive { (i, v) =>
731-
if(v > maxValue){
731+
if(v != 0.0 && v > maxValue){
732732
maxIdx = i
733733
maxValue = v
734734
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -71,6 +71,10 @@ class VectorsSuite extends FunSuite {
7171
val vec2 = Vectors.dense(arr).asInstanceOf[DenseVector]
7272
val max = vec2.argmax
7373
assert(max === 3)
74+
75+
val vec3 = Vectors.dense(Array(-1.0, 0.0, -2.0, 1.0)).asInstanceOf[DenseVector]
76+
val max2 = vec3.argmax
77+
assert(max === 3)
7478
}
7579

7680
test("sparse to array") {
@@ -87,9 +91,10 @@ class VectorsSuite extends FunSuite {
8791
val max = vec2.argmax
8892
assert(max === 3)
8993

90-
val vec3 = Vectors.sparse(5,Array(1,3,4),Array(1.0,.5,.7))
94+
// check the case where a sparse vector is mistakenly created with an explicit zero value
95+
val vec3 = Vectors.sparse(5,Array(0, 2, 4),Array(-1.0, 0.0, -.7))
9196
val max2 = vec3.argmax
92-
assert(max2 === 1)
97+
assert(max2 === 4)
9398
}
9499

95100
test("vector equals") {

0 commit comments

Comments
 (0)