20 changes: 19 additions & 1 deletion python/pyspark/ml/tests.py
@@ -369,7 +369,7 @@ def test_property(self):
         raise RuntimeError("Test property to raise error when invoked")
 
 
-class ParamTests(PySparkTestCase):
+class ParamTests(SparkSessionTestCase):
 
     def test_copy_new_parent(self):
         testParams = TestParams()
@@ -514,6 +514,24 @@ def test_logistic_regression_check_thresholds(self):
             LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
         )
 
+    def test_preserve_set_state(self):
+        dataset = self.spark.createDataFrame([(0.5,)], ["data"])
+        binarizer = Binarizer(inputCol="data")
+        self.assertFalse(binarizer.isSet("threshold"))
+        binarizer.transform(dataset)
+        binarizer._transfer_params_from_java()
+        self.assertFalse(binarizer.isSet("threshold"),
+                         "Params not explicitly set should remain unset after transform")
+
+    def test_default_params_transferred(self):
+        dataset = self.spark.createDataFrame([(0.5,)], ["data"])
+        binarizer = Binarizer(inputCol="data")
+        # intentionally change the pyspark default, but don't set it
+        binarizer._defaultParamMap[binarizer.outputCol] = "my_default"
+        result = binarizer.transform(dataset).select("my_default").collect()
+        self.assertFalse(binarizer.isSet(binarizer.outputCol))
+        self.assertEqual(result[0][0], 1.0)
+
     @staticmethod
     def check_params(test_self, py_stage, check_params_exist=True):
         """
13 changes: 10 additions & 3 deletions python/pyspark/ml/wrapper.py
@@ -118,11 +118,18 @@ def _transfer_params_to_java(self):
         """
         Transforms the embedded params to the companion Java object.
         """
-        paramMap = self.extractParamMap()
+        pair_defaults = []
         for param in self.params:
-            if param in paramMap:
-                pair = self._make_java_param_pair(param, paramMap[param])
+            if self.isSet(param):
+                pair = self._make_java_param_pair(param, self._paramMap[param])
                 self._java_obj.set(pair)
+            if self.hasDefault(param):
+                pair = self._make_java_param_pair(param, self._defaultParamMap[param])
+                pair_defaults.append(pair)
+        if len(pair_defaults) > 0:
+            sc = SparkContext._active_spark_context
+            pair_defaults_seq = sc._jvm.PythonUtils.toSeq(pair_defaults)
+            self._java_obj.setDefault(pair_defaults_seq)
@viirya (Member) commented on Mar 22, 2018:
If the default params are the same on the Java side and the Python side, do we still need to set default params on the Java object? Aren't they already set in the Java object if they are default params?

The PR author (Member) replied:
My take is that while they should be the same, it's still possible they might not be: users could extend their own classes, and the defaults are quite easy to change on the Python side. Although we don't really support this, if there were a mismatch the user would probably just get bad results and it would be really hard to figure out why; from the Python API it would look like one value was in effect while Scala was actually using another.

If you all think it's overly cautious to do this, I can take it out. I just thought it would be cheap insurance to set these values regardless.

A contributor replied:
I think this is reasonable; a few extra lines to avoid potentially surprising the user is worth it.
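
To make the mismatch scenario discussed in this thread concrete, here is a minimal standalone sketch (not part of the patch) that mirrors the new test_default_params_transferred test. It assumes a local SparkSession and a pyspark build with this change applied: a Python-side default that was never explicitly set is picked up by transform(), while isSet() still reports the param as unset.

from pyspark.sql import SparkSession
from pyspark.ml.feature import Binarizer

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(0.5,)], ["data"])

binarizer = Binarizer(inputCol="data")
# Change the Python-side default for outputCol without calling set(),
# so the param stays unset and only the default differs from Scala's.
binarizer._defaultParamMap[binarizer.outputCol] = "my_default"

out = binarizer.transform(df)
print(out.columns)                            # ['data', 'my_default']
print(binarizer.isSet(binarizer.outputCol))   # False: still only a default
print(out.select("my_default").first()[0])    # 1.0, since 0.5 > default threshold 0.0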


     def _transfer_param_map_to_java(self, pyParamMap):
         """
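
As background for the wrapper change above, a short sketch of the Params bookkeeping it branches on, using a hypothetical MyParams class rather than code from this PR: explicitly set values live in _paramMap and make isSet() return True, while defaults live in _defaultParamMap and are reported through hasDefault() and getOrDefault().

from pyspark.ml.param import Param, Params, TypeConverters

class MyParams(Params):
    # Hypothetical holder with a single param, only to illustrate the API.
    threshold = Param(Params._dummy(), "threshold", "binarization threshold",
                      typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(MyParams, self).__init__()
        self._setDefault(threshold=0.0)

p = MyParams()
print(p.hasDefault(p.threshold))    # True: 0.0 is recorded in _defaultParamMap
print(p.isSet(p.threshold))         # False: nothing explicitly set yet
p._set(threshold=0.7)
print(p.isSet(p.threshold))         # True: 0.7 now lives in _paramMap
print(p.getOrDefault(p.threshold))  # 0.7, the explicit value wins over the default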