Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,24 +256,33 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
vocabSize = Param(
Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
typeConverter=TypeConverters.toInt)
binary = Param(
Params._dummy(), "binary", "Binary toggle to control the output vector values." +
" If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
" for discrete probabilistic models that model binary events rather than integer counts." +
" Default False", typeConverter=TypeConverters.toBoolean)

@keyword_only
def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None):
def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,
outputCol=None):
"""
__init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None)
__init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\
outputCol=None)
"""
super(CountVectorizer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer",
self.uid)
self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18)
self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@keyword_only
@since("1.6.0")
def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None):
def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,
outputCol=None):
"""
setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None)
setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\
outputCol=None)
Set the params for the CountVectorizer
"""
kwargs = self.setParams._input_kwargs
Expand Down Expand Up @@ -324,6 +333,21 @@ def getVocabSize(self):
"""
return self.getOrDefault(self.vocabSize)

@since("2.0.0")
def setBinary(self, value):
"""
Sets the value of :py:attr:`binary`.
"""
self._paramMap[self.binary] = value
return self

@since("2.0.0")
def getBinary(self):
"""
Gets the value of binary or its default value.
"""
return self.getOrDefault(self.binary)

def _create_model(self, java_model):
return CountVectorizerModel(java_model)

Expand Down
16 changes: 16 additions & 0 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,22 @@ def test_stopwordsremover(self):
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, ["a"])

def test_count_vectorizer_with_binary(self):
sqlContext = SQLContext(self.sc)
dataset = sqlContext.createDataFrame([
(0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),),
(1, "a a".split(' '), SparseVector(3, {0: 1.0}),),
(2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),),
(3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"])
cv = CountVectorizer(binary=True, inputCol="words", outputCol="features")
model = cv.fit(dataset)

transformedList = model.transform(dataset).select("features", "expected").collect()

for r in transformedList:
feature, expected = r
self.assertEqual(feature, expected)


class HasInducedError(Params):

Expand Down