Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,45 @@ class MockModel(MockTransformer, Model, HasFake):
pass


class JavaWrapperMemoryTests(SparkSessionTestCase):

def test_java_object_gets_detached(self):
df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
(0.0, 2.0, Vectors.sparse(1, [], []))],
["label", "weight", "features"])
lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight",
fitIntercept=False)

model = lr.fit(df)
summary = model.summary

self.assertIsInstance(model, JavaWrapper)
self.assertIsInstance(summary, JavaWrapper)
self.assertIsInstance(model, JavaParams)
self.assertNotIsInstance(summary, JavaParams)

error_no_object = 'Target Object ID does not exist for this gateway'

self.assertIn("LinearRegression_", model._java_obj.toString())
self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())

model.__del__()

with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
model._java_obj.toString()
self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())

try:
summary.__del__()
except:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why need try...catch here ?
Does it mean when call summary.__del__ after model.__del__ will raise exception ?
Will summary be destroyed at the same time when model.__del__() ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__del__ is not a method of the object class. This test throws an error with earlier code (when __del__ is in JavaParams) because the LinearRegressionSummary class did not inherit the del method from any of its ancestors (JavaWrapper and object). After moving the del method to JavaWrapper this line executes. If I remove the try method, then we are testing the condition that "__del__ method exists && __del__ method releases memory".

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clarified offline: This was needed to run the test before the fix in wrapper.py to verify that it failed before the fix.

pass

with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
model._java_obj.toString()
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
summary._java_obj.toString()


class ParamTypeConversionTests(PySparkTestCase):
"""
Test that param type conversion happens.
Expand Down
8 changes: 4 additions & 4 deletions python/pyspark/ml/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def __init__(self, java_obj=None):
super(JavaWrapper, self).__init__()
self._java_obj = java_obj

def __del__(self):
if SparkContext._active_spark_context and self._java_obj is not None:
SparkContext._active_spark_context._gateway.detach(self._java_obj)

@classmethod
def _create_from_java_class(cls, java_class, *args):
"""
Expand Down Expand Up @@ -100,10 +104,6 @@ class JavaParams(JavaWrapper, Params):

__metaclass__ = ABCMeta

def __del__(self):
if SparkContext._active_spark_context:
SparkContext._active_spark_context._gateway.detach(self._java_obj)

def _make_java_param_pair(self, param, value):
"""
Makes a Java param pair.
Expand Down