Skip to content

Commit 1bd3c04

Browse files
committed
Robust parser and removed unnecessary methods
1 parent f779561 commit 1bd3c04

File tree

2 files changed

+40
-103
lines changed

2 files changed

+40
-103
lines changed

python/pyspark/mllib/linalg.py

Lines changed: 35 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import sys
2727
import array
28-
from math import sqrt
2928

3029
if sys.version >= '3':
3130
basestring = str
@@ -209,35 +208,24 @@ def __init__(self, ar):
209208
ar = ar.astype(np.float64)
210209
self.array = ar
211210

212-
def toString(self):
213-
"""
214-
Convert DenseVector to string representation.
215-
216-
>>> a = DenseVector([0, 1, 2, 3])
217-
>>> a.toString()
218-
'[0.0,1.0,2.0,3.0]'
219-
"""
220-
return str(self)
221-
222-
def copy(self):
223-
return DenseVector(np.copy(self.array))
224-
225211
@staticmethod
226212
def parse(vectorString):
227213
"""
228214
Parse string representation back into the DenseVector.
229215
230-
>>> DenseVector.parse('[0.0,1.0,2.0,3.0]')
216+
>>> DenseVector.parse(' [ 0.0,1.0,2.0, 3.0]')
231217
DenseVector([0.0, 1.0, 2.0, 3.0])
232218
"""
233-
vectorString = vectorString[1:-1]
219+
start = vectorString.find('[')
220+
end = vectorString.find(']')
221+
vectorString = vectorString[start + 1: end]
234222
return DenseVector([float(val) for val in vectorString.split(',')])
235223

236224
def __reduce__(self):
237225
return DenseVector, (self.array.tostring(),)
238226

239227
def numNonzeros(self):
240-
return np.nonzero(self.array)[0].size
228+
return np.count_nonzero(self.array)
241229

242230
def norm(self, p):
243231
"""
@@ -249,14 +237,7 @@ def norm(self, p):
249237
>>> a.norm(1)
250238
6.0
251239
"""
252-
if p == 1:
253-
return np.sum(np.abs(self.array))
254-
elif p == 2:
255-
return sqrt(np.dot(self.array, self.array))
256-
elif p == np.inf:
257-
return np.max(np.abs(self.array))
258-
else:
259-
return pow(np.power(self.array, p), 1.0 / p)
240+
return np.linalg.norm(self.array, p)
260241

261242
def dot(self, other):
262243
"""
@@ -434,11 +415,8 @@ def __init__(self, size, *args):
434415
if self.indices[i] >= self.indices[i + 1]:
435416
raise TypeError("indices array must be sorted")
436417

437-
def copy(self):
438-
return SparseVector(self.size, np.copy(self.indices), np.copy(self.values))
439-
440418
def numNonzeros(self):
441-
return np.nonzero(self.values)[0].size
419+
return np.count_nonzero(self.values)
442420

443421
def norm(self, p):
444422
"""
@@ -450,42 +428,36 @@ def norm(self, p):
450428
>>> a.norm(2)
451429
5.0
452430
"""
453-
if p == 1:
454-
return np.sum(np.abs(self.values))
455-
elif p == 2:
456-
return sqrt(np.dot(self.values, self.values))
457-
elif p == np.inf:
458-
return np.max(np.abs(self.values))
459-
else:
460-
return pow(np.power(self.values, p), 1.0 / p)
431+
return np.linalg.norm(self.values, p)
461432

462433
def __reduce__(self):
463-
return (SparseVector, (self.size, self.indices.tostring(), self.values.tostring()))
464-
465-
def toString(self):
466-
"""
467-
Convert SparseVector to string representation.
468-
469-
>>> a = SparseVector(4, [0, 1], [4, 5])
470-
>>> a.toString()
471-
'(4,[0,1],[4.0,5.0])'
472-
"""
473-
return str(self)
434+
return (
435+
SparseVector,
436+
(self.size, self.indices.tostring(), self.values.tostring()))
474437

475438
@staticmethod
476439
def parse(vectorString):
477440
"""
478441
Parse string representation back into the DenseVector.
479442
480-
>>> SparseVector.parse('(4,[0,1],[4.0,5.0])')
443+
>>> SparseVector.parse(' (4, [0,1 ],[ 4.0,5.0] )')
481444
SparseVector(4, {0: 4.0, 1: 5.0})
482445
"""
483-
size = int(vectorString[1])
446+
start = vectorString.find('(')
447+
end = vectorString.find(')')
448+
vectorString = vectorString[start+1: end].strip()
449+
size = int(vectorString[0])
450+
451+
ind_start = vectorString.find('[')
484452
ind_end = vectorString.find(']')
485-
index_string = vectorString[4: ind_end]
486-
indices = [int(ind) for ind in index_string.split(',')]
487-
value_string = vectorString[ind_end + 3: -2]
488-
values = [float(val) for val in value_string.split(',')]
453+
ind_list = vectorString[ind_start + 1: ind_end].split(',')
454+
indices = [int(ind) for ind in ind_list]
455+
vectorString = vectorString[ind_end + 1:].strip()
456+
457+
val_start = vectorString.find('[')
458+
val_end = vectorString.find(']')
459+
val_list = vectorString[val_start + 1: val_end].split(',')
460+
values = [float(val) for val in val_list]
489461
return SparseVector(size, indices, values)
490462

491463
def dot(self, other):
@@ -528,15 +500,12 @@ def dot(self, other):
528500

529501
assert len(self) == _vector_size(other), "dimension mismatch"
530502

531-
if type(other) in (np.ndarray, array.array):
503+
if type(other) in (np.ndarray, array.array, DenseVector):
532504
result = 0.0
533-
for i, ind in enumerate(self.indices):
534-
result += self.values[i] * other[ind]
505+
for i in xrange(len(self.indices)):
506+
result += self.values[i] * other[self.indices[i]]
535507
return result
536508

537-
elif isinstance(other, DenseVector):
538-
return np.dot(other.toArray()[self.indices], self.values)
539-
540509
elif type(other) is SparseVector:
541510
result = 0.0
542511
i, j = 0, 0
@@ -580,28 +549,19 @@ def squared_distance(self, other):
580549
AssertionError: dimension mismatch
581550
"""
582551
assert len(self) == _vector_size(other), "dimension mismatch"
583-
if type(other) in (list, array.array, np.array, np.ndarray):
552+
if type(other) in (list, array.array, DenseVector, np.array, np.ndarray):
584553
if type(other) is np.array and other.ndim != 1:
585554
raise Exception("Cannot call squared_distance with %d-dimensional array" %
586555
other.ndim)
587556
result = 0.0
588557
j = 0 # index into our own array
589-
for i, other_ind in enumerate(other):
558+
for i in xrange(len(other)):
590559
if j < len(self.indices) and self.indices[j] == i:
591-
diff = self.values[j] - other_ind
560+
diff = self.values[j] - other[i]
592561
result += diff * diff
593562
j += 1
594563
else:
595-
result += other_ind * other_ind
596-
return result
597-
598-
elif isinstance(other, DenseVector):
599-
bool_ind = np.zeros(len(other), dtype=bool)
600-
bool_ind[self.indices] = True
601-
dist = other.toArray()[bool_ind] - self.values
602-
result = np.dot(dist, dist)
603-
other_values = other.toArray()[~bool_ind]
604-
result += np.dot(other_values, other_values)
564+
result += other[i] * other[i]
605565
return result
606566

607567
elif type(other) is SparseVector:
@@ -743,30 +703,11 @@ def stringify(vector):
743703
"""
744704
return str(vector)
745705

746-
@staticmethod
747-
def dot(a, b):
748-
"""
749-
Dot product between two vectors.
750-
a and b can be of type, SparseVector, DenseVector, np.ndarray
751-
or array.array.
752-
753-
>>> a = Vectors.sparse(4, [(0, 1), (3, 4)])
754-
>>> b = Vectors.dense([23, 41, 9, 1])
755-
>>> Vectors.dot(a, b)
756-
27.0
757-
>>> Vectors.dot(a, a)
758-
17.0
759-
>>> Vectors.dot(a, np.array([0, 1, 2, 4]))
760-
16.0
761-
"""
762-
a, b = _convert_to_vector(a), _convert_to_vector(b)
763-
return a.dot(b)
764-
765706
@staticmethod
766707
def squared_distance(a, b):
767708
"""
768709
Squared distance between two vectors.
769-
a and b can be of type, SparseVector, DenseVector, np.ndarray
710+
a and b can be of type SparseVector, DenseVector, np.ndarray
770711
or array.array.
771712
772713
>>> a = Vectors.sparse(4, [(0, 1), (3, 4)])
@@ -786,7 +727,7 @@ def norm(vec, p):
786727

787728
@staticmethod
788729
def parse(vectorString):
789-
if vectorString[0] == '[':
730+
if vectorString.find('(') == -1:
790731
return DenseVector.parse(vectorString)
791732
return SparseVector.parse(vectorString)
792733

python/pyspark/mllib/tests.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,6 @@ def test_dot(self):
110110
self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat)))
111111
self.assertEquals(30.0, lst.dot(dv))
112112
self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat)))
113-
self.assertEquals(Vectors.dot(sv, sv), 5.)
114-
self.assertEquals(Vectors.dot(sv, dv), 10.)
115-
self.assertEquals(Vectors.dot(dv, sv), 10.)
116-
self.assertEquals(Vectors.dot(sv, array([2, 5, 7, 8])), 21.0)
117113

118114
def test_squared_distance(self):
119115
sv = SparseVector(4, {1: 1, 3: 2})
@@ -224,13 +220,13 @@ def test_dense_matrix_is_transposed(self):
224220
self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5]))
225221
self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))
226222

227-
def test_parse_matrix(self):
223+
def test_parse_vector(self):
228224
a = DenseVector([3, 4, 6, 7])
229-
self.assertTrue(a.toString(), '[3.0,4.0,6.0,7.0]')
230-
self.assertTrue(Vectors.parse(a.toString()), a)
225+
self.assertTrue(str(a), '[3.0,4.0,6.0,7.0]')
226+
self.assertTrue(Vectors.parse(str(a)), a)
231227
a = SparseVector(4, [0, 2], [3, 4])
232-
self.assertTrue(a.toString(), '(4,[0,2],[3.0,4.0])')
233-
self.assertTrue(Vectors.parse(a.toString()), a)
228+
self.assertTrue(str(a), '(4,[0,2],[3.0,4.0])')
229+
self.assertTrue(Vectors.parse(str(a)), a)
234230

235231
def test_norms(self):
236232
a = DenseVector([0, 2, 3, -1])

0 commit comments

Comments
 (0)