diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6b4376cbf14e..c2c4861e2aff 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -2096,6 +2096,11 @@ def test_java_params(self): # NOTE: disable check_params_exist until there is parity with Scala API ParamTests.check_params(self, cls(), check_params_exist=False) + # Additional classes that need explicit construction + from pyspark.ml.feature import CountVectorizerModel + ParamTests.check_params(self, CountVectorizerModel.from_vocabulary(['a'], 'input'), + check_params_exist=False) + def _squared_distance(a, b): if isinstance(a, Vector):