Skip to content

Commit c75f3cd

Browse files
committed
[SPARK-4409] Added JavaAPI Tests, and fixed a couple of bugs
1 parent d662f9d commit c75f3cd

File tree

3 files changed

+163
-25
lines changed

3 files changed

+163
-25
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,6 @@ object SparseMatrix {
363363
var i = 0
364364
var nnz = 0
365365
var lastCol = -1
366-
367366
raw.foreach { v =>
368367
val r = i % numRows
369368
val c = (i - r) / numRows
@@ -378,7 +377,10 @@ object SparseMatrix {
378377
}
379378
i += 1
380379
}
381-
sCols.append(sparseA.length)
380+
while (numCols > lastCol){
381+
sCols.append(sparseA.length)
382+
lastCol += 1
383+
}
382384
new SparseMatrix(numRows, numCols, sCols.toArray, sRows.toArray, sparseA.toArray)
383385
}
384386

@@ -399,11 +401,11 @@ object SparseMatrix {
399401
s"0.0 < d < 1.0. Currently, density: $density")
400402
val rand = new XORShiftRandom(seed)
401403
val length = numRows * numCols
402-
val rawA = Array.fill(length)(0.0)
404+
val rawA = new Array[Double](length)
403405
var nnz = 0
404406
for (i <- 0 until length) {
405407
val p = rand.nextDouble()
406-
if (p < density) {
408+
if (p <= density) {
407409
rawA.update(i, rand.nextDouble())
408410
nnz += 1
409411
}
@@ -439,11 +441,11 @@ object SparseMatrix {
439441
s"0.0 < d < 1.0. Currently, density: $density")
440442
val rand = new XORShiftRandom(seed)
441443
val length = numRows * numCols
442-
val rawA = Array.fill(length)(0.0)
444+
val rawA = new Array[Double](length)
443445
var nnz = 0
444446
for (i <- 0 until length) {
445447
val p = rand.nextDouble()
446-
if (p < density) {
448+
if (p <= density) {
447449
rawA.update(i, rand.nextGaussian())
448450
nnz += 1
449451
}
@@ -476,21 +478,24 @@ object SparseMatrix {
476478
val values = sVec.values
477479
var i = 0
478480
var lastCol = -1
479-
val colPtrs = new ArrayBuffer[Int](n)
481+
val colPtrs = new ArrayBuffer[Int](n + 1)
480482
rows.foreach { r =>
481483
while (r != lastCol) {
482484
colPtrs.append(i)
483485
lastCol += 1
484486
}
485487
i += 1
486488
}
487-
colPtrs.append(n)
489+
while (n > lastCol) {
490+
colPtrs.append(i)
491+
lastCol += 1
492+
}
488493
new SparseMatrix(n, n, colPtrs.toArray, rows, values)
489494
case dVec: DenseVector =>
490495
val values = dVec.values
491496
var i = 0
492497
var nnz = 0
493-
val sVals = values.filter( v => v != 0.0)
498+
val sVals = values.filter(v => v != 0.0)
494499
var lastCol = -1
495500
val colPtrs = new ArrayBuffer[Int](n + 1)
496501
val sRows = new ArrayBuffer[Int](sVals.length)
@@ -687,10 +692,10 @@ object Matrices {
687692
* Horizontally concatenate a sequence of matrices. The returned matrix will be in the format
688693
* the matrices are supplied in. Supplying a mix of dense and sparse matrices will result in
689694
* a dense matrix.
690-
* @param matrices sequence of matrices
695+
* @param matrices array of matrices
691696
* @return a single `Matrix` composed of the matrices that were horizontally concatenated
692697
*/
693-
private[mllib] def horzCat(matrices: Seq[Matrix]): Matrix = {
698+
def horzcat(matrices: Array[Matrix]): Matrix = {
694699
if (matrices.size == 1) {
695700
return matrices(0)
696701
}
@@ -744,7 +749,7 @@ object Matrices {
744749
* @param matrices sequence of matrices
745750
* @return a single `Matrix` composed of the matrices that were horizontally concatenated
746751
*/
747-
private[mllib] def vertCat(matrices: Seq[Matrix]): Matrix = {
752+
def vertcat(matrices: Array[Matrix]): Matrix = {
748753
if (matrices.size == 1) {
749754
return matrices(0)
750755
}
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.linalg;
19+
20+
import static org.junit.Assert.*;
21+
import org.junit.Test;
22+
23+
import java.io.Serializable;
24+
25+
public class JavaMatricesSuite implements Serializable {
26+
27+
@Test
28+
public void randMatrixConstruction() {
29+
Matrix r = Matrices.rand(3, 4, 24);
30+
DenseMatrix dr = DenseMatrix.rand(3, 4, 24);
31+
assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
32+
33+
Matrix rn = Matrices.randn(3, 4, 24);
34+
DenseMatrix drn = DenseMatrix.randn(3, 4, 24);
35+
assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
36+
37+
Matrix s = Matrices.sprand(3, 4, 0.5, 24);
38+
SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, 24);
39+
assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
40+
41+
Matrix sn = Matrices.sprandn(3, 4, 0.5, 24);
42+
SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, 24);
43+
assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
44+
}
45+
46+
@Test
47+
public void identityMatrixConstruction() {
48+
Matrix r = Matrices.eye(2);
49+
DenseMatrix dr = DenseMatrix.eye(2);
50+
SparseMatrix sr = SparseMatrix.speye(2);
51+
assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
52+
assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
53+
assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
54+
}
55+
56+
@Test
57+
public void diagonalMatrixConstruction() {
58+
Vector v = Vectors.dense(1.0, 0.0, 2.0);
59+
Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
60+
61+
Matrix m = Matrices.diag(v);
62+
Matrix sm = Matrices.diag(sv);
63+
DenseMatrix d = DenseMatrix.diag(v);
64+
DenseMatrix sd = DenseMatrix.diag(sv);
65+
SparseMatrix s = SparseMatrix.diag(v);
66+
SparseMatrix ss = SparseMatrix.diag(sv);
67+
68+
assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
69+
assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
70+
assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
71+
assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
72+
assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
73+
assertArrayEquals(s.values(), ss.values(), 0.0);
74+
assert(s.values().length == 2);
75+
assert(ss.values().length == 2);
76+
assert(s.colPtrs().length == 2);
77+
assert(ss.colPtrs().length == 2);
78+
}
79+
80+
@Test
81+
public void zerosMatrixConstruction() {
82+
Matrix z = Matrices.zeros(2, 2);
83+
Matrix one = Matrices.ones(2, 2);
84+
DenseMatrix dz = DenseMatrix.zeros(2, 2);
85+
DenseMatrix done = DenseMatrix.ones(2, 2);
86+
87+
assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
88+
assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
89+
assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
90+
assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
91+
}
92+
93+
@Test
94+
public void concatenateMatrices() {
95+
int m = 3;
96+
int n = 2;
97+
98+
SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, 42);
99+
DenseMatrix deMat1 = DenseMatrix.rand(m, n, 42);
100+
Matrix deMat2 = Matrices.eye(3);
101+
Matrix spMat2 = Matrices.speye(3);
102+
Matrix deMat3 = Matrices.eye(2);
103+
Matrix spMat3 = Matrices.speye(2);
104+
105+
Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
106+
Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
107+
Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
108+
Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
109+
110+
assert(deHorz1.numRows() == 3);
111+
assert(deHorz2.numRows() == 3);
112+
assert(deHorz3.numRows() == 3);
113+
assert(spHorz.numRows() == 3);
114+
assert(deHorz1.numCols() == 5);
115+
assert(deHorz2.numCols() == 5);
116+
assert(deHorz3.numCols() == 5);
117+
assert(spHorz.numCols() == 5);
118+
119+
Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
120+
Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
121+
Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
122+
Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
123+
124+
assert(deVert1.numRows() == 5);
125+
assert(deVert2.numRows() == 5);
126+
assert(deVert3.numRows() == 5);
127+
assert(spVert.numRows() == 5);
128+
assert(deVert1.numCols() == 2);
129+
assert(deVert2.numCols() == 2);
130+
assert(deVert3.numCols() == 2);
131+
assert(spVert.numCols() == 2);
132+
}
133+
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ class MatricesSuite extends FunSuite {
132132
assert(deMat1.toArray === deMat2.toArray)
133133
}
134134

135-
test("horzCat, vertCat, eye, speye") {
135+
test("horzcat, vertcat, eye, speye") {
136136
val m = 3
137137
val n = 2
138138
val values = Array(1.0, 2.0, 4.0, 5.0)
@@ -147,10 +147,10 @@ class MatricesSuite extends FunSuite {
147147
val deMat3 = Matrices.eye(2)
148148
val spMat3 = Matrices.speye(2)
149149

150-
val spHorz = Matrices.horzCat(Seq(spMat1, spMat2))
151-
val deHorz1 = Matrices.horzCat(Seq(deMat1, deMat2))
152-
val deHorz2 = Matrices.horzCat(Seq(spMat1, deMat2))
153-
val deHorz3 = Matrices.horzCat(Seq(deMat1, spMat2))
150+
val spHorz = Matrices.horzcat(Array(spMat1, spMat2))
151+
val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2))
152+
val deHorz2 = Matrices.horzcat(Array(spMat1, deMat2))
153+
val deHorz3 = Matrices.horzcat(Array(deMat1, spMat2))
154154

155155
assert(deHorz1.numRows === 3)
156156
assert(deHorz2.numRows === 3)
@@ -179,17 +179,17 @@ class MatricesSuite extends FunSuite {
179179
assert(deHorz1(1, 4) === 0.0)
180180

181181
intercept[IllegalArgumentException] {
182-
Matrices.horzCat(Seq(spMat1, spMat3))
182+
Matrices.horzcat(Array(spMat1, spMat3))
183183
}
184184

185185
intercept[IllegalArgumentException] {
186-
Matrices.horzCat(Seq(deMat1, spMat3))
186+
Matrices.horzcat(Array(deMat1, spMat3))
187187
}
188188

189-
val spVert = Matrices.vertCat(Seq(spMat1, spMat3))
190-
val deVert1 = Matrices.vertCat(Seq(deMat1, deMat3))
191-
val deVert2 = Matrices.vertCat(Seq(spMat1, deMat3))
192-
val deVert3 = Matrices.vertCat(Seq(deMat1, spMat3))
189+
val spVert = Matrices.vertcat(Array(spMat1, spMat3))
190+
val deVert1 = Matrices.vertcat(Array(deMat1, deMat3))
191+
val deVert2 = Matrices.vertcat(Array(spMat1, deMat3))
192+
val deVert3 = Matrices.vertcat(Array(deMat1, spMat3))
193193

194194
assert(deVert1.numRows === 5)
195195
assert(deVert2.numRows === 5)
@@ -214,11 +214,11 @@ class MatricesSuite extends FunSuite {
214214
assert(deVert1(4, 1) === 1.0)
215215

216216
intercept[IllegalArgumentException] {
217-
Matrices.vertCat(Seq(spMat1, spMat2))
217+
Matrices.vertcat(Array(spMat1, spMat2))
218218
}
219219

220220
intercept[IllegalArgumentException] {
221-
Matrices.vertCat(Seq(deMat1, spMat2))
221+
Matrices.vertcat(Array(deMat1, spMat2))
222222
}
223223
}
224224
}

0 commit comments

Comments
 (0)