@@ -63,6 +63,41 @@ def _convert_to_vector(l):
6363 raise TypeError ("Cannot convert type %s into Vector" % type (l ))
6464
6565
def _vector_size(v):
    """
    Return the number of elements of a vector-like object.

    Handles Vector instances, array.array, list and tuple values,
    NumPy ndarrays that are one-dimensional or two-dimensional with a
    single column, and (when SciPy is available) sparse matrices with
    a single column.

    >>> _vector_size([1., 2., 3.])
    3
    >>> _vector_size((1., 2., 3.))
    3
    >>> _vector_size(array.array('d', [1., 2., 3.]))
    3
    >>> _vector_size(np.zeros(3))
    3
    >>> _vector_size(np.zeros((3, 1)))
    3
    >>> _vector_size(np.zeros((1, 3)))
    Traceback (most recent call last):
        ...
    ValueError: Cannot treat an ndarray of shape (1, 3) as a vector
    """
    # Anything with a meaningful len(): our own vectors and plain sequences.
    if isinstance(v, Vector) or type(v) in (array.array, list, tuple):
        return len(v)

    if type(v) == np.ndarray:
        one_dimensional = v.ndim == 1
        single_column = v.ndim == 2 and v.shape[1] == 1
        if not (one_dimensional or single_column):
            raise ValueError("Cannot treat an ndarray of shape %s as a vector" % str(v.shape))
        # len() of an ndarray is the extent of its first axis.
        return len(v)

    if _have_scipy and scipy.sparse.issparse(v):
        assert v.shape[1] == 1, "Expected column vector"
        return v.shape[0]

    raise TypeError("Cannot treat type %s as a vector" % type(v))
99+
100+
66101class Vector (object ):
67102 """
68103 Abstract class for DenseVector and SparseVector
@@ -76,6 +111,9 @@ def toArray(self):
76111
77112
78113class DenseVector (Vector ):
114+ """
115+ A dense vector represented by a value array.
116+ """
79117 def __init__ (self , ar ):
80118 if not isinstance (ar , array .array ):
81119 ar = array .array ('d' , ar )
@@ -100,15 +138,31 @@ def dot(self, other):
100138 5.0
101139 >>> dense.dot(np.array(range(1, 3)))
102140 5.0
141+ >>> dense.dot([1.,])
142+ Traceback (most recent call last):
143+ ...
144+ AssertionError: dimension mismatch
145+ >>> dense.dot(np.reshape([1., 2., 3., 4.], (2, 2), order='F'))
146+ array([ 5., 11.])
147+ >>> dense.dot(np.reshape([1., 2., 3.], (3, 1), order='F'))
148+ Traceback (most recent call last):
149+ ...
150+ AssertionError: dimension mismatch
103151 """
104- if isinstance (other , SparseVector ):
105- return other .dot (self )
152+ if type (other ) == np .ndarray and other .ndim > 1 :
153+ assert len (self ) == other .shape [0 ], "dimension mismatch"
154+ return np .dot (self .toArray (), other )
106155 elif _have_scipy and scipy .sparse .issparse (other ):
107- return other .transpose ().dot (self .toArray ())[0 ]
108- elif isinstance (other , Vector ):
109- return np .dot (self .toArray (), other .toArray ())
156+ assert len (self ) == other .shape [0 ], "dimension mismatch"
157+ return other .transpose ().dot (self .toArray ())
110158 else :
111- return np .dot (self .toArray (), other )
159+ assert len (self ) == _vector_size (other ), "dimension mismatch"
160+ if isinstance (other , SparseVector ):
161+ return other .dot (self )
162+ elif isinstance (other , Vector ):
163+ return np .dot (self .toArray (), other .toArray ())
164+ else :
165+ return np .dot (self .toArray (), other )
112166
113167 def squared_distance (self , other ):
114168 """
@@ -126,7 +180,16 @@ def squared_distance(self, other):
126180 >>> sparse1 = SparseVector(2, [0, 1], [2., 1.])
127181 >>> dense1.squared_distance(sparse1)
128182 2.0
183+ >>> dense1.squared_distance([1.,])
184+ Traceback (most recent call last):
185+ ...
186+ AssertionError: dimension mismatch
187+ >>> dense1.squared_distance(SparseVector(1, [0,], [1.,]))
188+ Traceback (most recent call last):
189+ ...
190+ AssertionError: dimension mismatch
129191 """
192+ assert len (self ) == _vector_size (other ), "dimension mismatch"
130193 if isinstance (other , SparseVector ):
131194 return other .squared_distance (self )
132195 elif _have_scipy and scipy .sparse .issparse (other ):
@@ -165,12 +228,10 @@ def __getattr__(self, item):
165228
166229
167230class SparseVector (Vector ):
168-
169231 """
170232 A simple sparse vector class for passing data to MLlib. Users may
171233 alternatively pass SciPy's {scipy.sparse} data types.
172234 """
173-
174235 def __init__ (self , size , * args ):
175236 """
176237 Create a sparse vector, using either a dictionary, a list of
@@ -222,20 +283,33 @@ def dot(self, other):
222283 0.0
223284 >>> a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]))
224285 array([ 22., 22.])
286+ >>> a.dot([1., 2., 3.])
287+ Traceback (most recent call last):
288+ ...
289+ AssertionError: dimension mismatch
290+ >>> a.dot(np.array([1., 2.]))
291+ Traceback (most recent call last):
292+ ...
293+ AssertionError: dimension mismatch
294+ >>> a.dot(DenseVector([1., 2.]))
295+ Traceback (most recent call last):
296+ ...
297+ AssertionError: dimension mismatch
298+ >>> a.dot(np.zeros((3, 2)))
299+ Traceback (most recent call last):
300+ ...
301+ AssertionError: dimension mismatch
225302 """
226303 if type (other ) == np .ndarray :
227- if other .ndim == 1 :
228- result = 0.0
229- for i in xrange (len (self .indices )):
230- result += self .values [i ] * other [self .indices [i ]]
231- return result
232- elif other .ndim == 2 :
304+ if other .ndim == 2 :
233305 results = [self .dot (other [:, i ]) for i in xrange (other .shape [1 ])]
234306 return np .array (results )
235- else :
236- raise Exception ("Cannot call dot with %d-dimensional array" % other .ndim )
307+ elif other .ndim > 2 :
308+ raise ValueError ("Cannot call dot with %d-dimensional array" % other .ndim )
309+
310+ assert len (self ) == _vector_size (other ), "dimension mismatch"
237311
238- elif type (other ) in (array .array , DenseVector ):
312+ if type (other ) in (np . ndarray , array .array , DenseVector ):
239313 result = 0.0
240314 for i in xrange (len (self .indices )):
241315 result += self .values [i ] * other [self .indices [i ]]
@@ -254,6 +328,7 @@ def dot(self, other):
254328 else :
255329 j += 1
256330 return result
331+
257332 else :
258333 return self .dot (_convert_to_vector (other ))
259334
@@ -273,7 +348,16 @@ def squared_distance(self, other):
273348 30.0
274349 >>> b.squared_distance(a)
275350 30.0
351+ >>> b.squared_distance([1., 2.])
352+ Traceback (most recent call last):
353+ ...
354+ AssertionError: dimension mismatch
355+ >>> b.squared_distance(SparseVector(3, [1,], [1.0,]))
356+ Traceback (most recent call last):
357+ ...
358+ AssertionError: dimension mismatch
276359 """
360+ assert len (self ) == _vector_size (other ), "dimension mismatch"
277361 if type (other ) in (list , array .array , DenseVector , np .array , np .ndarray ):
278362 if type (other ) is np .array and other .ndim != 1 :
279363 raise Exception ("Cannot call squared_distance with %d-dimensional array" %
@@ -348,7 +432,6 @@ def __eq__(self, other):
348432 >>> v1 != v2
349433 False
350434 """
351-
352435 return (isinstance (other , self .__class__ )
353436 and other .size == self .size
354437 and other .indices == self .indices
@@ -414,23 +497,32 @@ def stringify(vector):
414497
415498
class Matrix(object):
    """
    Base class for local matrices; stores the row and column counts.
    """

    def __init__(self, numRows, numCols):
        # Record the dimensions as given; no validation is performed here.
        self.numRows, self.numCols = numRows, numCols

    def toArray(self):
        """
        Return the elements as a NumPy ndarray.

        This base implementation always raises NotImplementedError.
        """
        raise NotImplementedError
424513
425514
class DenseMatrix(Matrix):
    """
    Dense matrix whose values are stored in column-major order.
    """
    def __init__(self, numRows, numCols, values):
        Matrix.__init__(self, numRows, numCols)
        expected_length = numRows * numCols
        assert len(values) == expected_length
        self.values = values

    def __reduce__(self):
        # Pickle support: rebuild from the constructor arguments.
        constructor_args = (self.numRows, self.numCols, self.values)
        return DenseMatrix, constructor_args

    def toArray(self):
        """
        Return the values as a NumPy ndarray of shape (numRows, numCols).

        >>> arr = array.array('d', [float(i) for i in range(4)])
        >>> m = DenseMatrix(2, 2, arr)
        >>> m.toArray()
        array([[ 0.,  2.],
               [ 1.,  3.]])
        """
        shape = (self.numRows, self.numCols)
        # order='F' fills column by column, matching the column-major storage.
        return np.reshape(self.values, shape, order='F')
446538
447539
448540def _test ():
0 commit comments