Skip to content

Commit 21d4fe3

Browse files
committed
Fix init logic of weights
1 parent 0745a30 commit 21d4fe3

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

python/pyspark/mllib/regression.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
135 135       first = data.first()
136 136       if not isinstance(first, LabeledPoint):
137 137           raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
138     -     initial_weights = initial_weights or [0.0] * len(data.first().features)
    138 +     if initial_weights == None:
    139 +         initial_weights = [0.0] * len(data.first().features)
139 140       weights, intercept = train_func(data, _convert_to_vector(initial_weights))
140 141       return modelClass(weights, intercept)
141 142

python/pyspark/mllib/tests.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,12 @@ def test_regression(self):
323 323           self.assertTrue(gbt_model.predict(features[2]) <= 0)
324 324           self.assertTrue(gbt_model.predict(features[3]) > 0)
325 325
    326 +         try:
    327 +             LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
    328 +             LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
    329 +             RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]))
    330 +         except ValueError:
    331 +             self.fail()
326 332
327 333   class StatTests(PySparkTestCase):
328 334       # SPARK-4023

0 commit comments

Comments
 (0)