Skip to content

Commit ef853f9

Browse files
committed
Update Python linalg API and small fixes
1 parent ec9df6a commit ef853f9

File tree

1 file changed

+27
-14
lines changed

1 file changed

+27
-14
lines changed

python/pyspark/mllib/linalg.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ def toArray(self):
7676

7777

7878
class DenseVector(Vector):
79+
"""
80+
A dense vector represented by a value array.
81+
"""
7982
def __init__(self, ar):
8083
if not isinstance(ar, array.array):
8184
ar = array.array('d', ar)
@@ -101,6 +104,7 @@ def dot(self, other):
101104
>>> dense.dot(np.array(range(1, 3)))
102105
5.0
103106
"""
107+
assert len(self) == len(other), "vector sizes mismatch"
104108
if isinstance(other, SparseVector):
105109
return other.dot(self)
106110
elif _have_scipy and scipy.sparse.issparse(other):
@@ -127,6 +131,7 @@ def squared_distance(self, other):
127131
>>> dense1.squared_distance(sparse1)
128132
2.0
129133
"""
134+
assert len(self) == len(other), "vector sizes mismatch"
130135
if isinstance(other, SparseVector):
131136
return other.squared_distance(self)
132137
elif _have_scipy and scipy.sparse.issparse(other):
@@ -165,12 +170,10 @@ def __getattr__(self, item):
165170

166171

167172
class SparseVector(Vector):
168-
169173
"""
170174
A simple sparse vector class for passing data to MLlib. Users may
171175
alternatively pass SciPy's {scipy.sparse} data types.
172176
"""
173-
174177
def __init__(self, size, *args):
175178
"""
176179
Create a sparse vector, using either a dictionary, a list of
@@ -223,6 +226,7 @@ def dot(self, other):
223226
>>> a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]))
224227
array([ 22., 22.])
225228
"""
229+
assert len(self) == len(other), "vector sizes mismatch"
226230
if type(other) == np.ndarray:
227231
if other.ndim == 1:
228232
result = 0.0
@@ -348,7 +352,6 @@ def __eq__(self, other):
348352
>>> v1 != v2
349353
False
350354
"""
351-
352355
return (isinstance(other, self.__class__)
353356
and other.size == self.size
354357
and other.indices == self.indices
@@ -414,23 +417,32 @@ def stringify(vector):
414417

415418

416419
class Matrix(object):
417-
""" the Matrix """
418-
def __init__(self, nRow, nCol):
419-
self.nRow = nRow
420-
self.nCol = nCol
420+
"""
421+
Represents a local matrix.
422+
"""
423+
424+
def __init__(self, numRows, numCols):
425+
self.numRows = numRows
426+
self.numCols = numCols
421427

422428
def toArray(self):
429+
"""
430+
Returns its elements in a NumPy ndarray.
431+
"""
423432
raise NotImplementedError
424433

425434

426435
class DenseMatrix(Matrix):
427-
def __init__(self, nRow, nCol, values):
428-
Matrix.__init__(self, nRow, nCol)
429-
assert len(values) == nRow * nCol
436+
"""
437+
Column-major dense matrix.
438+
"""
439+
def __init__(self, numRows, numCols, values):
440+
Matrix.__init__(self, numRows, numCols)
441+
assert len(values) == numRows * numCols
430442
self.values = values
431443

432444
def __reduce__(self):
433-
return DenseMatrix, (self.nRow, self.nCol, self.values)
445+
return DenseMatrix, (self.numRows, self.numCols, self.values)
434446

435447
def toArray(self):
436448
"""
@@ -439,10 +451,11 @@ def toArray(self):
439451
>>> arr = array.array('d', [float(i) for i in range(4)])
440452
>>> m = DenseMatrix(2, 2, arr)
441453
>>> m.toArray()
442-
array([[ 0., 1.],
443-
[ 2., 3.]])
454+
array([[ 0., 2.],
455+
[ 1., 3.]])
444456
"""
445-
return np.ndarray((self.nRow, self.nCol), np.float64, buffer=self.values.tostring())
457+
return np.ndarray((self.numRows, self.numCols), np.float64,
458+
order='F', buffer=self.values.tostring())
446459

447460

448461
def _test():

0 commit comments

Comments
 (0)