Skip to content

Commit 7706eea

Browse files
Yogesh Garg authored and jkbradley committed
[SPARK-18630][PYTHON][ML] Move del method from JavaParams to JavaWrapper; add tests
The `__del__` method that explicitly detaches the object was moved from the `JavaParams` class to the `JavaWrapper` class, so that model summaries can also be garbage collected in Java. A test case was added to make sure that the relevant error messages are thrown after the objects are deleted. I ran the pyspark tests against the `pyspark-ml` module: `./python/run-tests --python-executables=$(which python) --modules=pyspark-ml` Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com> Closes #20724 from yogeshg/java_wrapper_memory.
1 parent 5085739 commit 7706eea

File tree

2 files changed

+43
-4
lines changed

2 files changed

+43
-4
lines changed

python/pyspark/ml/tests.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -173,6 +173,45 @@ class MockModel(MockTransformer, Model, HasFake):
173173
pass
174174

175175

176+
class JavaWrapperMemoryTests(SparkSessionTestCase):
177+
178+
def test_java_object_gets_detached(self):
179+
df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
180+
(0.0, 2.0, Vectors.sparse(1, [], []))],
181+
["label", "weight", "features"])
182+
lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight",
183+
fitIntercept=False)
184+
185+
model = lr.fit(df)
186+
summary = model.summary
187+
188+
self.assertIsInstance(model, JavaWrapper)
189+
self.assertIsInstance(summary, JavaWrapper)
190+
self.assertIsInstance(model, JavaParams)
191+
self.assertNotIsInstance(summary, JavaParams)
192+
193+
error_no_object = 'Target Object ID does not exist for this gateway'
194+
195+
self.assertIn("LinearRegression_", model._java_obj.toString())
196+
self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
197+
198+
model.__del__()
199+
200+
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
201+
model._java_obj.toString()
202+
self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
203+
204+
try:
205+
summary.__del__()
206+
except:
207+
pass
208+
209+
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
210+
model._java_obj.toString()
211+
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
212+
summary._java_obj.toString()
213+
214+
176215
class ParamTypeConversionTests(PySparkTestCase):
177216
"""
178217
Test that param type conversion happens.

python/pyspark/ml/wrapper.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,10 @@ def __init__(self, java_obj=None):
3636
super(JavaWrapper, self).__init__()
3737
self._java_obj = java_obj
3838

39+
def __del__(self):
40+
if SparkContext._active_spark_context and self._java_obj is not None:
41+
SparkContext._active_spark_context._gateway.detach(self._java_obj)
42+
3943
@classmethod
4044
def _create_from_java_class(cls, java_class, *args):
4145
"""
@@ -100,10 +104,6 @@ class JavaParams(JavaWrapper, Params):
100104

101105
__metaclass__ = ABCMeta
102106

103-
def __del__(self):
104-
if SparkContext._active_spark_context:
105-
SparkContext._active_spark_context._gateway.detach(self._java_obj)
106-
107107
def _make_java_param_pair(self, param, value):
108108
"""
109109
Makes a Java param pair.

0 commit comments

Comments
 (0)