Skip to content

Commit 63dc396

Browse files
committed
add loadLabeledPoints to pyspark
1 parent ea122b5 commit 63dc396

File tree

1 file changed

+43
-18
lines changed

1 file changed

+43
-18
lines changed

python/pyspark/mllib/util.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -106,24 +106,18 @@ def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=Non
106106
>>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect()
107107
>>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name, True).collect()
108108
>>> tempFile.close()
109-
>>> examples[0].label
110-
1.0
111-
>>> examples[0].features.size
112-
6
113-
>>> print examples[0].features
114-
[0: 1.0, 2: 2.0, 4: 3.0]
115-
>>> examples[1].label
116-
0.0
117-
>>> examples[1].features.size
118-
6
119-
>>> print examples[1].features
120-
[]
121-
>>> examples[2].label
122-
0.0
123-
>>> examples[2].features.size
124-
6
125-
>>> print examples[2].features
126-
[1: 4.0, 3: 5.0, 5: 6.0]
109+
>>> type(examples[0]) == LabeledPoint
110+
True
111+
>>> print examples[0]
112+
(1.0,(6,[0,2,4],[1.0,2.0,3.0]))
113+
>>> type(examples[1]) == LabeledPoint
114+
True
115+
>>> print examples[1]
116+
(0.0,(6,[],[]))
117+
>>> type(examples[2]) == LabeledPoint
118+
True
119+
>>> print examples[2]
120+
(0.0,(6,[1,3,5],[4.0,5.0,6.0]))
127121
>>> multiclass_examples[1].label
128122
-1.0
129123
"""
@@ -160,6 +154,37 @@ def saveAsLibSVMFile(data, dir):
160154
lines.saveAsTextFile(dir)
161155

162156

157+
@staticmethod
def loadLabeledPoints(sc, path, minPartitions=None):
    """
    Load labeled points saved using RDD.saveAsTextFile.

    Each text line is parsed back into a LabeledPoint via
    LabeledPoint.parse, so this is the inverse of saving an RDD of
    LabeledPoint with saveAsTextFile.

    @param sc: Spark context
    @param path: file or directory path in any Hadoop-supported file
                 system URI
    @param minPartitions: min number of partitions
    @return: labeled data stored as an RDD of LabeledPoint

    >>> from tempfile import NamedTemporaryFile
    >>> from pyspark.mllib.util import MLUtils
    >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \
                    LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))]
    >>> tempFile = NamedTemporaryFile(delete=True)
    >>> tempFile.close()
    >>> sc.parallelize(examples, 1).saveAsTextFile(tempFile.name)
    >>> loaded = MLUtils.loadLabeledPoints(sc, tempFile.name).collect()
    >>> type(loaded[0]) == LabeledPoint
    True
    >>> print loaded[0]
    (1.1,(3,[0,2],[1.23,4.56]))
    >>> type(loaded[1]) == LabeledPoint
    True
    >>> print loaded[1]
    (0.0,[1.01,2.02,3.03])
    """
    # Read the text file lazily and parse each line into a LabeledPoint;
    # minPartitions=None lets textFile fall back to its default parallelism.
    return sc.textFile(path, minPartitions).map(LabeledPoint.parse)
187+
163188
def _test():
164189
import doctest
165190
from pyspark.context import SparkContext

0 commit comments

Comments
 (0)