Skip to content

Commit cd6c78f

Browse files
committed
add __str__ and parse to LabeledPoint
1 parent a7a178e commit cd6c78f

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

python/pyspark/mllib/linalg.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def dense(elements):
233233
"""
234234
return array(elements, dtype=float64)
235235

236+
236237
@staticmethod
237238
def parse(s):
238239
"""
@@ -245,6 +246,7 @@ def parse(s):
245246
"""
246247
return Vectors._parse_structured(eval(s))
247248

249+
248250
@staticmethod
249251
def _parse_structured(data):
250252
if type(data) == list:
@@ -254,6 +256,7 @@ def _parse_structured(data):
254256
else:
255257
raise SyntaxError("Cannot recognize " + data)
256258

259+
257260
@staticmethod
258261
def stringify(vector):
259262
"""

python/pyspark/mllib/regression.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
_serialize_double_vector, _deserialize_double_vector, \
2424
_get_initial_weights, _serialize_rating, _regression_train_wrapper, \
2525
_linear_predictor_typecheck, _have_scipy, _scipy_issparse
26-
from pyspark.mllib.linalg import SparseVector
26+
from pyspark.mllib.linalg import SparseVector, Vectors
2727

2828

2929
class LabeledPoint(object):
@@ -44,6 +44,26 @@ def __init__(self, label, features):
4444
else:
4545
raise TypeError("Expected NumPy array, list, SparseVector, or scipy.sparse matrix")
4646

47+
def __str__(self):
48+
return "(" + ",".join((str(self.label), Vectors.stringify(self.features))) + ")"
49+
50+
51+
@staticmethod
52+
def parse(s):
53+
"""
54+
Parses a string resulted from str() to a LabeledPoint.
55+
56+
>>> print LabeledPoint.parse("(1.0,[0.0,1.0])")
57+
(1.0,[0.0,1.0])
58+
>>> print LabeledPoint.parse("(1.0,(2,[1],[1.0]))")
59+
(1.0,(2,[1],[1.0]))
60+
"""
61+
return LabeledPoint._parse_structured(eval(s))
62+
63+
64+
@staticmethod
65+
def _parse_structured(data):
66+
return LabeledPoint(data[0], Vectors._parse_structured(data[1]))
4767

4868
class LinearModel(object):
4969
"""A linear model that has a vector of coefficients and an intercept."""

0 commit comments

Comments
 (0)