-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-6263][MLLIB] Python MLlib API missing items: Utils #5707
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
2980569
c728046
64f72ad
44295c2
a353354
454c73d
62a9c7e
d254be7
b8b5ef7
c345a44
5d555b1
d6bd416
1502d13
7ec04db
25d3c9d
e32eb40
b29e2bc
1d4714b
3a12a2d
9c329d8
d2aa2a0
6084e9c
3fc27e7
16863ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -71,6 +71,14 @@ private[python] class PythonMLLibAPI extends Serializable { | |
| minPartitions: Int): JavaRDD[LabeledPoint] = | ||
| MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions) | ||
|
|
||
| def appendBias(data: org.apache.spark.mllib.linalg.Vector) | ||
| : org.apache.spark.mllib.linalg.Vector | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Scala style: Please follow [https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide] Also, is "Vector" not already imported? |
||
| = MLUtils.appendBias(data) | ||
|
|
||
| def loadVectors(jsc: JavaSparkContext, path: String) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add a little doc and fix Scala style here too |
||
| : RDD[org.apache.spark.mllib.linalg.Vector] | ||
| = MLUtils.loadVectors(jsc.sc, path) | ||
|
|
||
| private def trainRegressionModel( | ||
| learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], | ||
| data: JavaRDD[LabeledPoint], | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -46,6 +46,7 @@ | |
| from pyspark.mllib.feature import Word2Vec | ||
| from pyspark.mllib.feature import IDF | ||
| from pyspark.mllib.feature import StandardScaler | ||
| from pyspark.mllib.util import MLUtils | ||
| from pyspark.serializers import PickleSerializer | ||
| from pyspark.sql import SQLContext | ||
|
|
||
|
|
@@ -789,6 +790,29 @@ def test_model_transform(self): | |
| self.assertEqual(model.transform([1.0, 2.0, 3.0]), DenseVector([1.0, 2.0, 3.0])) | ||
|
|
||
|
|
||
| class MLUtilsTests(MLlibTestCase): | ||
| def test_append_bias(self): | ||
| data = [1.0, 2.0, 3.0] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. First value should not be 1.0 since it could be confused with the appended bias |
||
| ret = MLUtils.appendBias(data) | ||
| self.assertEqual(ret[3], 1.0) | ||
|
|
||
| def test_load_vectors(self): | ||
| import shutil | ||
| data = [ | ||
| [1.0, 2.0, 3.0], | ||
| [1.0, 2.0, 3.0] | ||
| ] | ||
| try: | ||
| self.sc.parallelize(data).saveAsTextFile("test_load_vectors") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use a temp directory, as in ListTests.test_classification: |
||
| ret_rdd = MLUtils.loadVectors(self.sc, "test_load_vectors") | ||
| ret = ret_rdd.collect() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Order of collect() is not guaranteed, so please sort "ret" and "data" and then compare to make the test robust.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oops, I guess this didn't matter; I didn't notice the vectors were identical. Fine to keep it sorted though. |
||
| self.assertEqual(len(ret), 2) | ||
| self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0])) | ||
| self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0])) | ||
| finally: | ||
| shutil.rmtree("test_load_vectors") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| if not _have_scipy: | ||
| print("NOTE: Skipping SciPy tests as it does not seem to be installed") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -169,6 +169,14 @@ def loadLabeledPoints(sc, path, minPartitions=None): | |
| minPartitions = minPartitions or min(sc.defaultParallelism, 2) | ||
| return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) | ||
|
|
||
| @staticmethod | ||
| def appendBias(data): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the Scala version only operates on individual vectors, this one should not be a wrapper; it should do everything in Python. The reason is that callMLlibFunc requires the SparkContext and needs to operate on the driver. But since appendBias operates per-Row, it needs to be called on workers. Also, please add doc. Feel free to copy from Scala doc. |
||
| return callMLlibFunc("appendBias", _convert_to_vector(data)) | ||
|
|
||
| @staticmethod | ||
| def loadVectors(sc, path): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add doc. Feel free to copy from Scala doc |
||
| return callMLlibFunc("loadVectors", sc, path) | ||
|
|
||
|
|
||
| class Saveable(object): | ||
| """ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(not needed: see comment for appendBias in util.py)