Skip to content

Commit ce3e53e

Browse files
committed
Add error message for parser
1 parent 1bd3c04 commit ce3e53e

File tree

2 files changed

+77
-32
lines changed

2 files changed

+77
-32
lines changed

python/pyspark/mllib/linalg.py

Lines changed: 75 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -209,17 +209,26 @@ def __init__(self, ar):
209209
self.array = ar
210210

211211
@staticmethod
212-
def parse(vectorString):
212+
def parse(s):
213213
"""
214214
Parse string representation back into the DenseVector.
215215
216216
>>> DenseVector.parse(' [ 0.0,1.0,2.0, 3.0]')
217217
DenseVector([0.0, 1.0, 2.0, 3.0])
218218
"""
219-
start = vectorString.find('[')
220-
end = vectorString.find(']')
221-
vectorString = vectorString[start + 1: end]
222-
return DenseVector([float(val) for val in vectorString.split(',')])
219+
start = s.find('[')
220+
if start == -1:
221+
raise ValueError("Array should start with '['")
222+
end = s.find(']')
223+
if end == -1:
224+
raise ValueError("Array should end with ']")
225+
s = s[start + 1: end]
226+
227+
try:
228+
values = [float(val) for val in s.split(',')]
229+
except ValueError:
230+
raise ValueError("Unable to parse values.")
231+
return DenseVector(values)
223232

224233
def __reduce__(self):
225234
return DenseVector, (self.array.tostring(),)
@@ -436,28 +445,51 @@ def __reduce__(self):
436445
(self.size, self.indices.tostring(), self.values.tostring()))
437446

438447
@staticmethod
439-
def parse(vectorString):
448+
def parse(s):
440449
"""
441450
Parse string representation back into the DenseVector.
442451
443452
>>> SparseVector.parse(' (4, [0,1 ],[ 4.0,5.0] )')
444453
SparseVector(4, {0: 4.0, 1: 5.0})
445454
"""
446-
start = vectorString.find('(')
447-
end = vectorString.find(')')
448-
vectorString = vectorString[start+1: end].strip()
449-
size = int(vectorString[0])
450-
451-
ind_start = vectorString.find('[')
452-
ind_end = vectorString.find(']')
453-
ind_list = vectorString[ind_start + 1: ind_end].split(',')
454-
indices = [int(ind) for ind in ind_list]
455-
vectorString = vectorString[ind_end + 1:].strip()
456-
457-
val_start = vectorString.find('[')
458-
val_end = vectorString.find(']')
459-
val_list = vectorString[val_start + 1: val_end].split(',')
460-
values = [float(val) for val in val_list]
455+
start = s.find('(')
456+
if start == -1:
457+
raise ValueError("Tuple should start with '('")
458+
end = s.find(')')
459+
if start == -1:
460+
raise ValueError("Tuple should end with ')'")
461+
s = s[start + 1: end].strip()
462+
463+
size = s[: s.find(',')]
464+
try:
465+
size = int(size)
466+
except ValueError:
467+
raise ValueError("Cannot parse size %s." % size)
468+
469+
ind_start = s.find('[')
470+
if ind_start == -1:
471+
raise ValueError("Indices array should start with '('.")
472+
ind_end = s.find(']')
473+
if ind_end == -1:
474+
raise ValueError("Indices array should end with ')'")
475+
ind_list = s[ind_start + 1: ind_end].split(',')
476+
try:
477+
indices = [int(ind) for ind in ind_list]
478+
except ValueError:
479+
raise ValueError("Unabel to parse indices.")
480+
s = s[ind_end + 1:].strip()
481+
482+
val_start = s.find('[')
483+
if val_start == -1:
484+
raise ValueError("Values array should start with '('.")
485+
val_end = s.find(']')
486+
if val_end == -1:
487+
raise ValueError("Values array should end with ')'.")
488+
val_list = s[val_start + 1: val_end].split(',')
489+
try:
490+
values = [float(val) for val in val_list]
491+
except ValueError:
492+
raise ValueError("Unable to parse values.")
461493
return SparseVector(size, indices, values)
462494

463495
def dot(self, other):
@@ -704,7 +736,7 @@ def stringify(vector):
704736
return str(vector)
705737

706738
@staticmethod
707-
def squared_distance(a, b):
739+
def squared_distance(v1, v2):
708740
"""
709741
Squared distance between two vectors.
710742
a and b can be of type SparseVector, DenseVector, np.ndarray
@@ -715,25 +747,36 @@ def squared_distance(a, b):
715747
>>> a.squared_distance(b)
716748
51.0
717749
"""
718-
a, b = _convert_to_vector(a), _convert_to_vector(b)
719-
return a.squared_distance(b)
750+
v1, v2 = _convert_to_vector(v1), _convert_to_vector(v2)
751+
return v1.squared_distance(v2)
720752

721753
@staticmethod
722-
def norm(vec, p):
754+
def norm(vector, p):
723755
"""
724756
Find norm of the given vector.
725757
"""
726-
return _convert_to_vector(vec).norm(p)
758+
return _convert_to_vector(vector).norm(p)
727759

728760
@staticmethod
729-
def parse(vectorString):
730-
if vectorString.find('(') == -1:
731-
return DenseVector.parse(vectorString)
732-
return SparseVector.parse(vectorString)
761+
def parse(s):
762+
"""Parse a string representation back into the Vector.
763+
764+
>>> Vectors.parse('[2,1,2 ]')
765+
DenseVector([2.0, 1.0, 2.0])
766+
>>> Vectors.parse(' ( 100, [0], [2])')
767+
SparseVector(100, {0: 2.0})
768+
"""
769+
if s.find('(') == -1 and s.find('[') != -1:
770+
return DenseVector.parse(s)
771+
elif s.find('(') != -1:
772+
return SparseVector.parse(s)
773+
else:
774+
raise ValueError(
775+
"Cannot find tokens '[' or '(' from the input string.")
733776

734777
@staticmethod
735-
def zeros(num):
736-
return DenseVector(np.zeros(num))
778+
def zeros(size):
779+
return DenseVector(np.zeros(size))
737780

738781

739782
class Matrix(object):

python/pyspark/mllib/tests.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@ def test_parse_vector(self):
227227
a = SparseVector(4, [0, 2], [3, 4])
228228
self.assertTrue(str(a), '(4,[0,2],[3.0,4.0])')
229229
self.assertTrue(Vectors.parse(str(a)), a)
230+
a = SparseVector(10, [0, 1], [4, 5])
231+
self.assertTrue(SparseVector.parse(' (10, [0,1 ],[ 4.0,5.0] )'), a)
230232

231233
def test_norms(self):
232234
a = DenseVector([0, 2, 3, -1])

0 commit comments

Comments
 (0)