python/pyspark/ml/feature.py (168 changes: 112 additions & 56 deletions)

@@ -19,12 +19,12 @@
 if sys.version > '3':
     basestring = str

-from pyspark import since, keyword_only
+from pyspark import since, keyword_only, SparkContext
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.ml.linalg import _convert_to_vector
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import JavaMLReadable, JavaMLWritable
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm
 from pyspark.ml.common import inherit_doc

 __all__ = ['Binarizer',
@@ -403,8 +403,69 @@ def getSplits(self):
         return self.getOrDefault(self.splits)


+class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol):
+    """
+    Params for :py:attr:`CountVectorizer` and :py:attr:`CountVectorizerModel`.
+    """
+
+    minTF = Param(
+        Params._dummy(), "minTF", "Filter to ignore rare words in" +
+        " a document. For each document, terms with frequency/count less than the given" +
+        " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
+        " times the term must appear in the document); if this is a double in [0,1), then this " +
+        "specifies a fraction (out of the document's token count). Note that the parameter is " +
+        "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
+        typeConverter=TypeConverters.toFloat)
+    minDF = Param(
+        Params._dummy(), "minDF", "Specifies the minimum number of" +
+        " different documents a term must appear in to be included in the vocabulary." +
+        " If this is an integer >= 1, this specifies the number of documents the term must" +
+        " appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
+        " Default 1.0", typeConverter=TypeConverters.toFloat)
+    vocabSize = Param(
+        Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
+        typeConverter=TypeConverters.toInt)
+    binary = Param(
+        Params._dummy(), "binary", "Binary toggle to control the output vector values." +
+        " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
+        " for discrete probabilistic models that model binary events rather than integer counts." +
+        " Default False", typeConverter=TypeConverters.toBoolean)
+
+    def __init__(self, *args):
+        super(_CountVectorizerParams, self).__init__(*args)
+        self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False)
+
+    @since("1.6.0")
+    def getMinTF(self):
+        """
+        Gets the value of minTF or its default value.
+        """
+        return self.getOrDefault(self.minTF)
+
+    @since("1.6.0")
+    def getMinDF(self):
+        """
+        Gets the value of minDF or its default value.
+        """
+        return self.getOrDefault(self.minDF)
+
+    @since("1.6.0")
+    def getVocabSize(self):
+        """
+        Gets the value of vocabSize or its default value.
+        """
+        return self.getOrDefault(self.vocabSize)
+
+    @since("2.0.0")
+    def getBinary(self):
+        """
+        Gets the value of binary or its default value.
+        """
+        return self.getOrDefault(self.binary)
+
+
 @inherit_doc
-class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, JavaMLWritable):
     """
     Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`.

Expand Down Expand Up @@ -437,33 +498,20 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
     >>> loadedModel = CountVectorizerModel.load(modelPath)
     >>> loadedModel.vocabulary == model.vocabulary
     True
+    >>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"],
+    ...     inputCol="raw", outputCol="vectors")
+    >>> fromVocabModel.transform(df).show(truncate=False)
+    +-----+---------------+-------------------------+
+    |label|raw            |vectors                  |
+    +-----+---------------+-------------------------+
+    |0    |[a, b, c]      |(3,[0,1,2],[1.0,1.0,1.0])|
+    |1    |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
+    +-----+---------------+-------------------------+
+    ...

     .. versionadded:: 1.6.0
     """

-    minTF = Param(
-        Params._dummy(), "minTF", "Filter to ignore rare words in" +
-        " a document. For each document, terms with frequency/count less than the given" +
-        " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
-        " times the term must appear in the document); if this is a double in [0,1), then this " +
-        "specifies a fraction (out of the document's token count). Note that the parameter is " +
-        "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
-        typeConverter=TypeConverters.toFloat)
-    minDF = Param(
-        Params._dummy(), "minDF", "Specifies the minimum number of" +
-        " different documents a term must appear in to be included in the vocabulary." +
-        " If this is an integer >= 1, this specifies the number of documents the term must" +
-        " appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
-        " Default 1.0", typeConverter=TypeConverters.toFloat)
-    vocabSize = Param(
-        Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
-        typeConverter=TypeConverters.toInt)
-    binary = Param(
-        Params._dummy(), "binary", "Binary toggle to control the output vector values." +
-        " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
-        " for discrete probabilistic models that model binary events rather than integer counts." +
-        " Default False", typeConverter=TypeConverters.toBoolean)
-
     @keyword_only
     def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,
                  outputCol=None):
@@ -474,7 +522,6 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC
         super(CountVectorizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer",
                                             self.uid)
-        self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False)
         kwargs = self._input_kwargs
         self.setParams(**kwargs)

@@ -497,66 +544,61 @@ def setMinTF(self, value):
"""
return self._set(minTF=value)

@since("1.6.0")
def getMinTF(self):
"""
Gets the value of minTF or its default value.
"""
return self.getOrDefault(self.minTF)

@since("1.6.0")
def setMinDF(self, value):
"""
Sets the value of :py:attr:`minDF`.
"""
return self._set(minDF=value)

@since("1.6.0")
def getMinDF(self):
"""
Gets the value of minDF or its default value.
"""
return self.getOrDefault(self.minDF)

@since("1.6.0")
def setVocabSize(self, value):
"""
Sets the value of :py:attr:`vocabSize`.
"""
return self._set(vocabSize=value)

@since("1.6.0")
def getVocabSize(self):
"""
Gets the value of vocabSize or its default value.
"""
return self.getOrDefault(self.vocabSize)

@since("2.0.0")
def setBinary(self, value):
"""
Sets the value of :py:attr:`binary`.
"""
return self._set(binary=value)

@since("2.0.0")
def getBinary(self):
"""
Gets the value of binary or its default value.
"""
return self.getOrDefault(self.binary)

def _create_model(self, java_model):
return CountVectorizerModel(java_model)


-class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable):
+@inherit_doc
+class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, JavaMLWritable):
     """
     Model fitted by :py:class:`CountVectorizer`.

     .. versionadded:: 1.6.0
     """

+    @classmethod
+    @since("2.4.0")
+    def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None):
+        """
+        Construct the model directly from a vocabulary list of strings,
+        requires an active SparkContext.
+        """
+        sc = SparkContext._active_spark_context
+        java_class = sc._gateway.jvm.java.lang.String
+        jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class)
+        model = CountVectorizerModel._create_from_java_class(
+            "org.apache.spark.ml.feature.CountVectorizerModel", jvocab)
+        model.setInputCol(inputCol)
+        if outputCol is not None:
+            model.setOutputCol(outputCol)
+        if minTF is not None:
+            model.setMinTF(minTF)
+        if binary is not None:
+            model.setBinary(binary)
+        model._set(vocabSize=len(vocabulary))
> Contributor: Any reason for _set rather than set?
> Member (Author): The only difference is set checks to make sure the param is valid, which isn't really needed since this is internal.
> Contributor: sgtm
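For context on the exchange above, a toy sketch of the distinction (the real pyspark.ml.param.Params is more involved; the ParamsSketch class below is invented purely for illustration):

    class ParamsSketch(object):
        """Toy stand-in for pyspark.ml.param.Params (illustrative only)."""

        def __init__(self, *param_names):
            self._param_names = set(param_names)
            self._paramMap = {}

        def set(self, name, value):
            # Public setter: checks that the param actually belongs to this
            # instance, raising ValueError for typos or foreign params.
            if name not in self._param_names:
                raise ValueError("Param %s does not belong to this instance" % name)
            self._paramMap[name] = value
            return self

        def _set(self, **kwargs):
            # Internal setter: skips the ownership check, which is acceptable
            # in from_vocabulary above because the method controls the names.
            self._paramMap.update(kwargs)
            return self

So ParamsSketch("minTF").set("vocabSize", 3) raises, while ParamsSketch("minTF")._set(vocabSize=3) goes straight into the param map; the ownership validation is all that set adds.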

+        return model

     @property
     @since("1.6.0")
     def vocabulary(self):
@@ -565,6 +607,20 @@ def vocabulary(self):
"""
return self._call_java("vocabulary")

@since("2.4.0")
def setMinTF(self, value):
> Contributor: If we're going to have the setters in both the model and the estimator, maybe we should consider putting them in the shared params class?
> Member (Author): I agree, but I was trying to match the Scala API. My only thought is it was done this way to leave it up to the implementations whether they allow setting the params. What do you think?
> Contributor: Sounds reasonable to me.

"""
Sets the value of :py:attr:`minTF`.
"""
return self._set(minTF=value)

@since("2.4.0")
def setBinary(self, value):
"""
Sets the value of :py:attr:`binary`.
"""
return self._set(binary=value)


 @inherit_doc
 class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
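The thread on setMinTF above concerns the pattern this diff adopts: getters and defaults live on the shared _CountVectorizerParams mixin, while setters are repeated on the concrete estimator and model. A miniature, hypothetical sketch of that shape (the _HasMinTF and Toy* names are invented for illustration):

    class _HasMinTF(object):
        # Shared params mixin: the default is declared once, and both the
        # estimator and the model read the param the same way.
        def __init__(self):
            self._paramMap = {"minTF": 1.0}

        def getMinTF(self):
            return self._paramMap["minTF"]


    class ToyCountVectorizer(_HasMinTF):
        # Each concrete class opts in to mutability by defining its own
        # setter; a class that should stay read-only simply omits it.
        def setMinTF(self, value):
            self._paramMap["minTF"] = float(value)
            return self

This mirrors the author's reading of the Scala API: the shared class fixes how params are read, and each implementation decides whether they can be set.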
python/pyspark/ml/tests.py (32 changes: 30 additions & 2 deletions)

@@ -640,6 +640,34 @@ def test_count_vectorizer_with_binary(self):
             feature, expected = r
             self.assertEqual(feature, expected)

+    def test_count_vectorizer_from_vocab(self):
+        model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words",
> Contributor: Good first test. I'd love to also see it with an empty vocab, and one that uses the default values.
> Member (Author): Yeah, good idea, I'll add those.
> Member (Author): done

outputCol="features", minTF=2)
self.assertEqual(model.vocabulary, ["a", "b", "c"])
self.assertEqual(model.getMinTF(), 2)

dataset = self.spark.createDataFrame([
(0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),),
(1, "a a".split(' '), SparseVector(3, {0: 2.0}),),
(2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"])

transformed_list = model.transform(dataset).select("features", "expected").collect()

for r in transformed_list:
feature, expected = r
self.assertEqual(feature, expected)

# Test an empty vocabulary
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"):
CountVectorizerModel.from_vocabulary([], inputCol="words")

# Test model with default settings can transform
model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words")
transformed_list = model_default.transform(dataset)\
.select(model_default.getOrDefault(model_default.outputCol)).collect()
self.assertEqual(len(transformed_list), 3)
> Member (Author): The doctest uses default values for all params except outputCol and checks the transformed values, so this is really just testing that nothing fails when all param default values are used, including outputCol.
> Contributor: sgtm


     def test_rformula_force_index_label(self):
         df = self.spark.createDataFrame([
             (1.0, 1.0, "a"),

@@ -1980,8 +2008,8 @@ def test_java_params(self):
                    pyspark.ml.regression]
         for module in modules:
             for name, cls in inspect.getmembers(module, inspect.isclass):
-                if not name.endswith('Model') and issubclass(cls, JavaParams)\
-                        and not inspect.isabstract(cls):
+                if not name.endswith('Model') and not name.endswith('Params')\
> Contributor: Just to make sure I've understood what's happening here: we're avoiding running the default-params test on non-concrete classes, like the base params class shared between the model and the estimator, and instead testing just the model and estimator on their own, right?
> Member (Author): Yes, that's pretty much right, but this only checks estimators and skips models as well. We should have an explicit check for CountVectorizerModel.from_vocabulary here too, since that is possible. Unfortunately, a new param, maxDF, was recently added to Scala, and the param check would fail on it. Once that is in Python, we can add the check here.
> Contributor: Sounds reasonable. I look forward to us automatically catching models with missing params eventually as well.
> Member (Author): I'm helping get maxDF into Python now, so once that's done I'll make a follow-up to add this.

+                        and issubclass(cls, JavaParams) and not inspect.isabstract(cls):
                     # NOTE: disable check_params_exist until there is parity with Scala API
                     ParamTests.check_params(self, cls(), check_params_exist=False)

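A sketch of the explicit from_vocabulary check the author mentions as a follow-up. The test method below is hypothetical; it assumes only the ParamTests.check_params helper already used in the loop above:

    def test_java_params_from_vocabulary(self):
        # The generic loop above skips Model classes, so exercise one
        # directly: a CountVectorizerModel built straight from a vocabulary.
        model = CountVectorizerModel.from_vocabulary(["a", "b"], inputCol="words")
        # Keep check_params_exist=False until the Python params reach parity
        # with Scala (e.g. the newly added maxDF).
        ParamTests.check_params(self, model, check_params_exist=False)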