diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 116885969345..6dee6938d891 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -173,6 +173,45 @@ class MockModel(MockTransformer, Model, HasFake): pass +class JavaWrapperMemoryTests(SparkSessionTestCase): + + def test_java_object_gets_detached(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + + model = lr.fit(df) + summary = model.summary + + self.assertIsInstance(model, JavaWrapper) + self.assertIsInstance(summary, JavaWrapper) + self.assertIsInstance(model, JavaParams) + self.assertNotIsInstance(summary, JavaParams) + + error_no_object = 'Target Object ID does not exist for this gateway' + + self.assertIn("LinearRegression_", model._java_obj.toString()) + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + model.__del__() + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + try: + summary.__del__() + except: + pass + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + summary._java_obj.toString() + + class ParamTypeConversionTests(PySparkTestCase): """ Test that param type conversion happens. diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 0f846fbc5b5e..5061f6434794 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -36,6 +36,10 @@ def __init__(self, java_obj=None): super(JavaWrapper, self).__init__() self._java_obj = java_obj + def __del__(self): + if SparkContext._active_spark_context and self._java_obj is not None: + SparkContext._active_spark_context._gateway.detach(self._java_obj) + @classmethod def _create_from_java_class(cls, java_class, *args): """ @@ -100,10 +104,6 @@ class JavaParams(JavaWrapper, Params): __metaclass__ = ABCMeta - def __del__(self): - if SparkContext._active_spark_context: - SparkContext._active_spark_context._gateway.detach(self._java_obj) - def _make_java_param_pair(self, param, value): """ Makes a Java param pair.