diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 3a0b816c367ec..be1521154f042 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -44,6 +44,7 @@ import numpy as np from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros import inspect +import py4j from pyspark import keyword_only, SparkContext from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer, UnaryTransformer @@ -1859,8 +1860,9 @@ class ImageReaderTest2(PySparkTestCase): @classmethod def setUpClass(cls): - PySparkTestCase.setUpClass() + super(ImageReaderTest2, cls).setUpClass() # Note that here we enable Hive's support. + cls.spark = None try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: @@ -1873,8 +1875,10 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - PySparkTestCase.tearDownClass() - cls.spark.sparkSession.stop() + super(ImageReaderTest2, cls).tearDownClass() + if cls.spark is not None: + cls.spark.sparkSession.stop() + cls.spark = None def test_read_images_multiple_times(self): # This test case is to check if `ImageSchema.readImages` tries to