Skip to content

Commit aec3402

Browse files
MechCoderjeanlyn
authored andcommitted
[SPARK-7328] [MLLIB] [PYSPARK] Pyspark.mllib.linalg.Vectors: Missing items
Add 1. Class methods squared_dist 3. parse 4. norm 5. numNonzeros 6. copy I made a few vectorizations wrt squared_dist and dot as well. I have added support for SparseMatrix serialization in a separate PR (apache#5775) and plan to complete support for Matrices in another PR. Author: MechCoder <[email protected]> Closes apache#5872 from MechCoder/local_linalg_api and squashes the following commits: a8ff1e0 [MechCoder] minor ce3e53e [MechCoder] Add error message for parser 1bd3c04 [MechCoder] Robust parser and removed unnecessary methods f779561 [MechCoder] [SPARK-7328] Pyspark.mllib.linalg.Vectors: Missing items
1 parent 23a927b commit aec3402

File tree

2 files changed

+171
-2
lines changed

2 files changed

+171
-2
lines changed

python/pyspark/mllib/linalg.py

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,46 @@ def __init__(self, ar):
208208
ar = ar.astype(np.float64)
209209
self.array = ar
210210

211+
@staticmethod
212+
def parse(s):
213+
"""
214+
Parse string representation back into the DenseVector.
215+
216+
>>> DenseVector.parse(' [ 0.0,1.0,2.0, 3.0]')
217+
DenseVector([0.0, 1.0, 2.0, 3.0])
218+
"""
219+
start = s.find('[')
220+
if start == -1:
221+
raise ValueError("Array should start with '['.")
222+
end = s.find(']')
223+
if end == -1:
224+
raise ValueError("Array should end with ']'.")
225+
s = s[start + 1: end]
226+
227+
try:
228+
values = [float(val) for val in s.split(',')]
229+
except ValueError:
230+
raise ValueError("Unable to parse values from %s" % s)
231+
return DenseVector(values)
232+
211233
def __reduce__(self):
212234
return DenseVector, (self.array.tostring(),)
213235

236+
def numNonzeros(self):
237+
return np.count_nonzero(self.array)
238+
239+
def norm(self, p):
240+
"""
241+
Calculte the norm of a DenseVector.
242+
243+
>>> a = DenseVector([0, -1, 2, -3])
244+
>>> a.norm(2)
245+
3.7...
246+
>>> a.norm(1)
247+
6.0
248+
"""
249+
return np.linalg.norm(self.array, p)
250+
214251
def dot(self, other):
215252
"""
216253
Compute the dot product of two Vectors. We support
@@ -387,8 +424,74 @@ def __init__(self, size, *args):
387424
if self.indices[i] >= self.indices[i + 1]:
388425
raise TypeError("indices array must be sorted")
389426

427+
def numNonzeros(self):
428+
return np.count_nonzero(self.values)
429+
430+
def norm(self, p):
431+
"""
432+
Calculte the norm of a SparseVector.
433+
434+
>>> a = SparseVector(4, [0, 1], [3., -4.])
435+
>>> a.norm(1)
436+
7.0
437+
>>> a.norm(2)
438+
5.0
439+
"""
440+
return np.linalg.norm(self.values, p)
441+
390442
def __reduce__(self):
391-
return (SparseVector, (self.size, self.indices.tostring(), self.values.tostring()))
443+
return (
444+
SparseVector,
445+
(self.size, self.indices.tostring(), self.values.tostring()))
446+
447+
@staticmethod
448+
def parse(s):
449+
"""
450+
Parse string representation back into the DenseVector.
451+
452+
>>> SparseVector.parse(' (4, [0,1 ],[ 4.0,5.0] )')
453+
SparseVector(4, {0: 4.0, 1: 5.0})
454+
"""
455+
start = s.find('(')
456+
if start == -1:
457+
raise ValueError("Tuple should start with '('")
458+
end = s.find(')')
459+
if start == -1:
460+
raise ValueError("Tuple should end with ')'")
461+
s = s[start + 1: end].strip()
462+
463+
size = s[: s.find(',')]
464+
try:
465+
size = int(size)
466+
except ValueError:
467+
raise ValueError("Cannot parse size %s." % size)
468+
469+
ind_start = s.find('[')
470+
if ind_start == -1:
471+
raise ValueError("Indices array should start with '['.")
472+
ind_end = s.find(']')
473+
if ind_end == -1:
474+
raise ValueError("Indices array should end with ']'")
475+
new_s = s[ind_start + 1: ind_end]
476+
ind_list = new_s.split(',')
477+
try:
478+
indices = [int(ind) for ind in ind_list]
479+
except ValueError:
480+
raise ValueError("Unable to parse indices from %s." % new_s)
481+
s = s[ind_end + 1:].strip()
482+
483+
val_start = s.find('[')
484+
if val_start == -1:
485+
raise ValueError("Values array should start with '['.")
486+
val_end = s.find(']')
487+
if val_end == -1:
488+
raise ValueError("Values array should end with ']'.")
489+
val_list = s[val_start + 1: val_end].split(',')
490+
try:
491+
values = [float(val) for val in val_list]
492+
except ValueError:
493+
raise ValueError("Unable to parse values from %s." % s)
494+
return SparseVector(size, indices, values)
392495

393496
def dot(self, other):
394497
"""
@@ -633,6 +736,49 @@ def stringify(vector):
633736
"""
634737
return str(vector)
635738

739+
@staticmethod
740+
def squared_distance(v1, v2):
741+
"""
742+
Squared distance between two vectors.
743+
a and b can be of type SparseVector, DenseVector, np.ndarray
744+
or array.array.
745+
746+
>>> a = Vectors.sparse(4, [(0, 1), (3, 4)])
747+
>>> b = Vectors.dense([2, 5, 4, 1])
748+
>>> a.squared_distance(b)
749+
51.0
750+
"""
751+
v1, v2 = _convert_to_vector(v1), _convert_to_vector(v2)
752+
return v1.squared_distance(v2)
753+
754+
@staticmethod
755+
def norm(vector, p):
756+
"""
757+
Find norm of the given vector.
758+
"""
759+
return _convert_to_vector(vector).norm(p)
760+
761+
@staticmethod
762+
def parse(s):
763+
"""Parse a string representation back into the Vector.
764+
765+
>>> Vectors.parse('[2,1,2 ]')
766+
DenseVector([2.0, 1.0, 2.0])
767+
>>> Vectors.parse(' ( 100, [0], [2])')
768+
SparseVector(100, {0: 2.0})
769+
"""
770+
if s.find('(') == -1 and s.find('[') != -1:
771+
return DenseVector.parse(s)
772+
elif s.find('(') != -1:
773+
return SparseVector.parse(s)
774+
else:
775+
raise ValueError(
776+
"Cannot find tokens '[' or '(' from the input string.")
777+
778+
@staticmethod
779+
def zeros(size):
780+
return DenseVector(np.zeros(size))
781+
636782

637783
class Matrix(object):
638784
"""

python/pyspark/mllib/tests.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import tempfile
2525
import array as pyarray
2626

27-
from numpy import array, array_equal, zeros
27+
from numpy import array, array_equal, zeros, inf
2828
from py4j.protocol import Py4JJavaError
2929

3030
if sys.version_info[:2] <= (2, 6):
@@ -220,6 +220,29 @@ def test_dense_matrix_is_transposed(self):
220220
self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5]))
221221
self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9]))
222222

223+
def test_parse_vector(self):
224+
a = DenseVector([3, 4, 6, 7])
225+
self.assertTrue(str(a), '[3.0,4.0,6.0,7.0]')
226+
self.assertTrue(Vectors.parse(str(a)), a)
227+
a = SparseVector(4, [0, 2], [3, 4])
228+
self.assertTrue(str(a), '(4,[0,2],[3.0,4.0])')
229+
self.assertTrue(Vectors.parse(str(a)), a)
230+
a = SparseVector(10, [0, 1], [4, 5])
231+
self.assertTrue(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
232+
233+
def test_norms(self):
234+
a = DenseVector([0, 2, 3, -1])
235+
self.assertAlmostEqual(a.norm(2), 3.742, 3)
236+
self.assertTrue(a.norm(1), 6)
237+
self.assertTrue(a.norm(inf), 3)
238+
a = SparseVector(4, [0, 2], [3, -4])
239+
self.assertAlmostEqual(a.norm(2), 5)
240+
self.assertTrue(a.norm(1), 7)
241+
self.assertTrue(a.norm(inf), 4)
242+
243+
tmp = SparseVector(4, [0, 2], [3, 0])
244+
self.assertEqual(tmp.numNonzeros(), 1)
245+
223246

224247
class ListTests(MLlibTestCase):
225248

0 commit comments

Comments
 (0)