diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 8e166ba0ff51..3fbc0958a0f1 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -657,6 +657,8 @@ class SparseVector @Since("2.0.0") ( override def argmax: Int = { if (size == 0) { -1 + } else if (numActives == 0) { + 0 } else { // Find the max active entry. var maxIdx = indices(0) diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index dfbdaf19d374..4cd91afd6d7f 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -125,6 +125,13 @@ class VectorsSuite extends SparkMLFunSuite { val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0)) assert(vec8.argmax === 0) + + // Check for case when sparse vector is non-empty but the values are empty + val vec9 = Vectors.sparse(100, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec9.argmax === 0) + + val vec10 = Vectors.sparse(1, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec10.argmax === 0) } test("vector equals") { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 723addc7150d..f063420bec14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -846,6 +846,8 @@ class SparseVector @Since("1.0.0") ( override def argmax: Int = { if (size == 0) { -1 + } else if (numActives == 0) { + 0 } else { // Find the max active entry. var maxIdx = indices(0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 71a3ceac1b94..6172cffee861 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -122,6 +122,13 @@ class VectorsSuite extends SparkFunSuite with Logging { val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0)) assert(vec8.argmax === 0) + + // Check for case when sparse vector is non-empty but the values are empty + val vec9 = Vectors.sparse(100, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec9.argmax === 0) + + val vec10 = Vectors.sparse(1, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector] + assert(vec10.argmax === 0) } test("vector equals") {