@@ -106,24 +106,18 @@ def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=None):
106106 >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect()
107107 >>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name, True).collect()
108108 >>> tempFile.close()
109- >>> examples[0].label
110- 1.0
111- >>> examples[0].features.size
112- 6
113- >>> print examples[0].features
114- [0: 1.0, 2: 2.0, 4: 3.0]
115- >>> examples[1].label
116- 0.0
117- >>> examples[1].features.size
118- 6
119- >>> print examples[1].features
120- []
121- >>> examples[2].label
122- 0.0
123- >>> examples[2].features.size
124- 6
125- >>> print examples[2].features
126- [1: 4.0, 3: 5.0, 5: 6.0]
109+ >>> type(examples[0]) == LabeledPoint
110+ True
111+ >>> print examples[0]
112+ (1.0,(6,[0,2,4],[1.0,2.0,3.0]))
113+ >>> type(examples[1]) == LabeledPoint
114+ True
115+ >>> print examples[1]
116+ (0.0,(6,[],[]))
117+ >>> type(examples[2]) == LabeledPoint
118+ True
119+ >>> print examples[2]
120+ (0.0,(6,[1,3,5],[4.0,5.0,6.0]))
127121 >>> multiclass_examples[1].label
128122 -1.0
129123 """
@@ -160,6 +154,37 @@ def saveAsLibSVMFile(data, dir):
160154 lines.saveAsTextFile(dir)
161155
162156
157+ @staticmethod
158+ def loadLabeledPoints(sc, path, minPartitions=None):
159+ """
160+ Load labeled points saved using RDD.saveAsTextFile.
161+
162+ @param sc: Spark context
163+ @param path: file or directory path in any Hadoop-supported file
164+ system URI
165+ @param minPartitions: min number of partitions
166+ @return: labeled data stored as an RDD of LabeledPoint
167+
168+ >>> from tempfile import NamedTemporaryFile
169+ >>> from pyspark.mllib.util import MLUtils
170+ >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \
171+ LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))]
172+ >>> tempFile = NamedTemporaryFile(delete=True)
173+ >>> tempFile.close()
174+ >>> sc.parallelize(examples, 1).saveAsTextFile(tempFile.name)
175+ >>> loaded = MLUtils.loadLabeledPoints(sc, tempFile.name).collect()
176+ >>> type(loaded[0]) == LabeledPoint
177+ True
178+ >>> print loaded[0]
179+ (1.1,(3,[0,2],[1.23,4.56]))
180+ >>> type(loaded[1]) == LabeledPoint
181+ True
182+ >>> print loaded[1]
183+ (0.0,[1.01,2.02,3.03])
184+
185+ """
186+ return sc.textFile(path, minPartitions).map(LabeledPoint.parse)
187+
163188def _test():
164189 import doctest
165190 from pyspark.context import SparkContext
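Taken together with saveAsLibSVMFile above, the new method closes the plain-text round trip: points written with RDD.saveAsTextFile serialize via str(LabeledPoint) and come back through LabeledPoint.parse. A sketch of that round trip, assuming an existing SparkContext sc and a hypothetical output path:

    from pyspark.mllib.linalg import Vectors
    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.util import MLUtils

    points = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])),
              LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))]
    # Each point is written as its string form, e.g. (1.1,(3,[0,2],[1.23,4.56])).
    sc.parallelize(points, 1).saveAsTextFile("/tmp/points")  # hypothetical path
    loaded = MLUtils.loadLabeledPoints(sc, "/tmp/points").collect()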