Commit c8b16ca

jkbradley authored and mengxr committed
[SPARK-2850] [SPARK-2626] [mllib] MLlib stats examples + small fixes
Added examples for statistical summarization:
* Scala: StatisticalSummary.scala
  ** Tests: correlation, MultivariateOnlineSummarizer
* python: statistical_summary.py
  ** Tests: correlation (since MultivariateOnlineSummarizer has no Python API)

Added examples for random and sampled RDDs:
* Scala: RandomAndSampledRDDs.scala
* python: random_and_sampled_rdds.py
* Both test:
  ** RandomRDDGenerators.normalRDD, normalVectorRDD
  ** RDD.sample, takeSample, sampleByKey

Added sc.stop() to all examples.

CorrelationSuite.scala
* Added 1 test for RDDs with only 1 value

RowMatrix.scala
* numCols(): Added check for numRows = 0, with error message.
* computeCovariance(): Added check for numRows <= 1, with error message.

Python SparseVector (pyspark/mllib/linalg.py)
* Added toDense() function

python/run-tests script
* Added stat.py (doc test)

CC: mengxr dorx

Main changes were examples to show usage across APIs.

Author: Joseph K. Bradley <[email protected]>

Closes apache#1878 from jkbradley/mllib-stats-api-check and squashes the following commits:

ea5c047 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check
dafebe2 [Joseph K. Bradley] Bug fixes for examples SampledRDDs.scala and sampled_rdds.py: Check for division by 0 and for missing key in maps.
8d1e555 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check
60c72d9 [Joseph K. Bradley] Fixed stat.py doc test to work for Python versions printing nan or NaN.
b20d90a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check
4e5d15e [Joseph K. Bradley] Changed pyspark/mllib/stat.py doc tests to use NaN instead of nan.
32173b7 [Joseph K. Bradley] Stats examples update.
c8c20dc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check
cf70b07 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check
0b7cec3 [Joseph K. Bradley] Small updates based on code review. Renamed statistical_summary.py to correlations.py
ab48f6e [Joseph K. Bradley] RowMatrix.scala: numCols(): Added check for numRows = 0, with error message. computeCovariance(): Added check for numRows <= 1, with error message.
65e4ebc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check
8195c78 [Joseph K. Bradley] Added examples for random and sampled RDDs: Scala: RandomAndSampledRDDs.scala; python: random_and_sampled_rdds.py. Both test: RandomRDDGenerators.normalRDD, normalVectorRDD; RDD.sample, takeSample, sampleByKey
064985b [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into mllib-stats-api-check
ee918e9 [Joseph K. Bradley] Added examples for statistical summarization: Scala: StatisticalSummary.scala (tests: correlation, MultivariateOnlineSummarizer); python: statistical_summary.py (tests: correlation, since MultivariateOnlineSummarizer has no Python API)
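The commit message mentions a new toDense() helper on the Python SparseVector, but the pyspark/mllib/linalg.py diff is not shown in this capture. A minimal stand-in sketch of what densification does, assuming the usual (size, indices, values) sparse representation — the function name and signature here are illustrative, not the actual MLlib implementation:

```python
def to_dense(size, indices, values):
    # Expand a sparse (size, indices, values) triple into a dense list:
    # each position listed in `indices` receives the matching entry from
    # `values`; every other position stays 0.0.
    dense = [0.0] * size
    for idx, val in zip(indices, values):
        dense[idx] = val
    return dense

print(to_dense(4, [1, 3], [2.0, 5.0]))  # [0.0, 2.0, 0.0, 5.0]
```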
1 parent 115eeb3 commit c8b16ca

29 files changed: +664, −20 lines

examples/src/main/python/als.py

Lines changed: 2 additions & 0 deletions

@@ -97,3 +97,5 @@ def update(i, vec, mat, ratings):
 error = rmse(R, ms, us)
 print "Iteration %d:" % i
 print "\nRMSE: %5.4f\n" % error
+
+sc.stop()

examples/src/main/python/cassandra_inputformat.py

Lines changed: 2 additions & 0 deletions

@@ -77,3 +77,5 @@
 output = cass_rdd.collect()
 for (k, v) in output:
     print (k, v)
+
+sc.stop()

examples/src/main/python/cassandra_outputformat.py

Lines changed: 2 additions & 0 deletions

@@ -81,3 +81,5 @@
 conf=conf,
 keyConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLKeyConverter",
 valueConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLValueConverter")
+
+sc.stop()

examples/src/main/python/hbase_inputformat.py

Lines changed: 2 additions & 0 deletions

@@ -71,3 +71,5 @@
 output = hbase_rdd.collect()
 for (k, v) in output:
     print (k, v)
+
+sc.stop()

examples/src/main/python/hbase_outputformat.py

Lines changed: 2 additions & 0 deletions

@@ -63,3 +63,5 @@
 conf=conf,
 keyConverter="org.apache.spark.examples.pythonconverters.StringToImmutableBytesWritableConverter",
 valueConverter="org.apache.spark.examples.pythonconverters.StringListToPutConverter")
+
+sc.stop()

examples/src/main/python/kmeans.py

Lines changed: 2 additions & 0 deletions

@@ -77,3 +77,5 @@ def closestPoint(p, centers):
 kPoints[x] = y
 
 print "Final centers: " + str(kPoints)
+
+sc.stop()

examples/src/main/python/logistic_regression.py

Lines changed: 2 additions & 0 deletions

@@ -80,3 +80,5 @@ def add(x, y):
 w -= points.map(lambda m: gradient(m, w)).reduce(add)
 
 print "Final w: " + str(w)
+
+sc.stop()
Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Correlations using MLlib.
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.stat import Statistics
+from pyspark.mllib.util import MLUtils
+
+
+if __name__ == "__main__":
+    if len(sys.argv) not in [1,2]:
+        print >> sys.stderr, "Usage: correlations (<file>)"
+        exit(-1)
+    sc = SparkContext(appName="PythonCorrelations")
+    if len(sys.argv) == 2:
+        filepath = sys.argv[1]
+    else:
+        filepath = 'data/mllib/sample_linear_regression_data.txt'
+    corrType = 'pearson'
+
+    points = MLUtils.loadLibSVMFile(sc, filepath)\
+        .map(lambda lp: LabeledPoint(lp.label, lp.features.toArray()))
+
+    print
+    print 'Summary of data file: ' + filepath
+    print '%d data points' % points.count()
+
+    # Statistics (correlations)
+    print
+    print 'Correlation (%s) between label and each feature' % corrType
+    print 'Feature\tCorrelation'
+    numFeatures = points.take(1)[0].features.size
+    labelRDD = points.map(lambda lp: lp.label)
+    for i in range(numFeatures):
+        featureRDD = points.map(lambda lp: lp.features[i])
+        corr = Statistics.corr(labelRDD, featureRDD, corrType)
+        print '%d\t%g' % (i, corr)
+    print
+
+    sc.stop()
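For reference, a plain-Python sketch of the Pearson coefficient that Statistics.corr computes for each (label, feature) pair in the correlations example above — illustrative only; the real computation runs distributed over the RDDs:

```python
import math

def pearson(xs, ys):
    # Pearson correlation: covariance of the two samples divided by
    # the product of their standard deviations.
    n = len(xs)
    mean_x = sum(xs) / float(n)
    mean_y = sum(ys) / float(n)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)

# A perfectly linear relationship gives correlation 1.0.
print(pearson([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]))  # 1.0
```

This also makes clear why correlation is undefined for an RDD with a single value (the variance terms are zero), the edge case covered by the new CorrelationSuite.scala test.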

examples/src/main/python/mllib/decision_tree_runner.py

Lines changed: 5 additions & 0 deletions

@@ -17,6 +17,8 @@
 
 """
 Decision tree classification and regression using MLlib.
+
+This example requires NumPy (http://www.numpy.org/).
 """
 
 import numpy, os, sys
@@ -117,6 +119,7 @@ def usage():
 if len(sys.argv) == 2:
     dataPath = sys.argv[1]
 if not os.path.isfile(dataPath):
+    sc.stop()
     usage()
 points = MLUtils.loadLibSVMFile(sc, dataPath)
@@ -133,3 +136,5 @@ def usage():
 print "  Model depth: %d\n" % model.depth()
 print "  Training accuracy: %g\n" % getAccuracy(model, reindexedData)
 print model
+
+sc.stop()

examples/src/main/python/mllib/kmeans.py

Lines changed: 1 addition & 0 deletions

@@ -42,3 +42,4 @@ def parseVector(line):
 k = int(sys.argv[2])
 model = KMeans.train(data, k)
 print "Final centers: " + str(model.clusterCenters)
+sc.stop()
