
Commit 257cde7

Lewuathe authored and mengxr committed
[SPARK-6421][MLLIB] _regression_train_wrapper does not test initialWeights correctly
Weight parameters must be initialized correctly even when a numpy array is passed as initial weights.

Author: lewuathe <[email protected]>

Closes apache#5101 from Lewuathe/SPARK-6421 and squashes the following commits:

7795201 [lewuathe] Fix lint-python errors
21d4fe3 [lewuathe] Fix init logic of weights
1 parent 11e0259 commit 257cde7
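
The failure this commit addresses can be reproduced without Spark. The old code used `initial_weights or [...]`, and `or` truth-tests its left operand; a numpy array with more than one element raises `ValueError` when truth-tested. The sketch below is a minimal illustration, not the PySpark code: `AmbiguousArray`, `default_with_or`, and `default_with_is_none` are hypothetical names, and `AmbiguousArray` is a stand-in that only imitates numpy's truth-testing behaviour so the example needs no third-party packages.

```python
class AmbiguousArray(object):
    """Hypothetical stand-in for a multi-element numpy array."""

    def __init__(self, values):
        self.values = list(values)

    def __len__(self):
        return len(self.values)

    def __bool__(self):
        # Imitates numpy, which raises ValueError when an array with
        # more than one element is used in a boolean context.
        raise ValueError("truth value of an array is ambiguous")

    __nonzero__ = __bool__  # Python 2 spelling of __bool__


def default_with_or(initial_weights, n_features):
    # Old logic: `or` truth-tests initial_weights before choosing a value.
    return initial_weights or [0.0] * n_features


def default_with_is_none(initial_weights, n_features):
    # New logic: only an explicit None triggers the all-zero default.
    if initial_weights is None:
        initial_weights = [0.0] * n_features
    return initial_weights
```

With these definitions, `default_with_or(AmbiguousArray([1.0, 1.0]), 2)` raises `ValueError`, while `default_with_is_none` returns the weights object unchanged. Both functions return `[0.0, 0.0]` when passed `None`.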

2 files changed: +9 -1 lines changed

python/pyspark/mllib/regression.py

Lines changed: 2 additions & 1 deletion
@@ -163,7 +163,8 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
     first = data.first()
     if not isinstance(first, LabeledPoint):
         raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
-    initial_weights = initial_weights or [0.0] * len(data.first().features)
+    if initial_weights is None:
+        initial_weights = [0.0] * len(data.first().features)
     weights, intercept = train_func(data, _convert_to_vector(initial_weights))
     return modelClass(weights, intercept)
 

python/pyspark/mllib/tests.py

Lines changed: 7 additions & 0 deletions
@@ -323,6 +323,13 @@ def test_regression(self):
         self.assertTrue(gbt_model.predict(features[2]) <= 0)
         self.assertTrue(gbt_model.predict(features[3]) > 0)
 
+        try:
+            LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
+            LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
+            RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
+        except ValueError:
+            self.fail()
+
 
 class StatTests(PySparkTestCase):
     # SPARK-4023
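
The added test uses a "must not raise" pattern: it calls each trainer with array-valued initial weights and fails only if a `ValueError` escapes. A self-contained sketch of that pattern with `unittest` might look like the following. `TruthAmbiguous` and `train_stub` are hypothetical stand-ins introduced for illustration (they are not PySpark or numpy APIs), and `train_stub` only mirrors the patched default logic rather than training a model.

```python
import unittest


class TruthAmbiguous(object):
    """Hypothetical stand-in: truth-testing raises, as it does for a
    multi-element numpy array."""

    def __init__(self, values):
        self.values = list(values)

    def __bool__(self):
        raise ValueError("ambiguous truth value")

    __nonzero__ = __bool__  # Python 2 spelling of __bool__


def train_stub(initial_weights=None, n_features=2):
    # Hypothetical stand-in for a train() call; mirrors the patched check.
    if initial_weights is None:
        initial_weights = [0.0] * n_features
    return initial_weights


class InitialWeightsTest(unittest.TestCase):
    def test_array_like_weights_do_not_raise(self):
        weights = TruthAmbiguous([1.0, 1.0])
        try:
            result = train_stub(initial_weights=weights)
        except ValueError:
            self.fail("initial weights were truth-tested")
        # The caller's weights should be passed through untouched.
        self.assertIs(result, weights)

    def test_none_falls_back_to_zeros(self):
        self.assertEqual(train_stub(), [0.0, 0.0])
```

Catching only `ValueError` keeps the check narrow: any other exception still surfaces as an error rather than being folded into the assertion.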
