
Commit e94dde3

updated CountVectorizerModel to use common param base class
1 parent 01e5a4b commit e94dde3

2 files changed: 107 additions, 61 deletions


python/pyspark/ml/feature.py

Lines changed: 90 additions & 61 deletions
@@ -24,7 +24,7 @@
 from pyspark.ml.linalg import _convert_to_vector
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import JavaMLReadable, JavaMLWritable
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm
 from pyspark.ml.common import inherit_doc

 __all__ = ['Binarizer',
@@ -403,8 +403,69 @@ def getSplits(self):
         return self.getOrDefault(self.splits)


+class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol):
+    """
+    Params for :py:attr:`CountVectorizer` and :py:attr:`CountVectorizerModel`.
+    """
+
+    minTF = Param(
+        Params._dummy(), "minTF", "Filter to ignore rare words in" +
+        " a document. For each document, terms with frequency/count less than the given" +
+        " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
+        " times the term must appear in the document); if this is a double in [0,1), then this " +
+        "specifies a fraction (out of the document's token count). Note that the parameter is " +
+        "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
+        typeConverter=TypeConverters.toFloat)
+    minDF = Param(
+        Params._dummy(), "minDF", "Specifies the minimum number of" +
+        " different documents a term must appear in to be included in the vocabulary." +
+        " If this is an integer >= 1, this specifies the number of documents the term must" +
+        " appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
+        " Default 1.0", typeConverter=TypeConverters.toFloat)
+    vocabSize = Param(
+        Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
+        typeConverter=TypeConverters.toInt)
+    binary = Param(
+        Params._dummy(), "binary", "Binary toggle to control the output vector values." +
+        " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
+        " for discrete probabilistic models that model binary events rather than integer counts." +
+        " Default False", typeConverter=TypeConverters.toBoolean)
+
+    def __init__(self, *args):
+        super(_CountVectorizerParams, self).__init__(*args)
+        self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False)
+
+    @since("1.6.0")
+    def getMinTF(self):
+        """
+        Gets the value of minTF or its default value.
+        """
+        return self.getOrDefault(self.minTF)
+
+    @since("1.6.0")
+    def getMinDF(self):
+        """
+        Gets the value of minDF or its default value.
+        """
+        return self.getOrDefault(self.minDF)
+
+    @since("1.6.0")
+    def getVocabSize(self):
+        """
+        Gets the value of vocabSize or its default value.
+        """
+        return self.getOrDefault(self.vocabSize)
+
+    @since("2.0.0")
+    def getBinary(self):
+        """
+        Gets the value of binary or its default value.
+        """
+        return self.getOrDefault(self.binary)
+
+
 @inherit_doc
-class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, JavaMLWritable):
     """
     Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`.

@@ -437,7 +498,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
     >>> loadedModel = CountVectorizerModel.load(modelPath)
     >>> loadedModel.vocabulary == model.vocabulary
     True
-    >>> fromVocabModel = CountVectorizerModel.fromVocabulary(model.vocabulary,
+    >>> fromVocabModel = CountVectorizerModel.from_vocabulary(model.vocabulary,
     ...     inputCol="raw", outputCol="vectors")
     >>> fromVocabModel.transform(df).show(truncate=False)
     +-----+---------------+-------------------------+
@@ -451,29 +512,6 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
     .. versionadded:: 1.6.0
     """

-    minTF = Param(
-        Params._dummy(), "minTF", "Filter to ignore rare words in" +
-        " a document. For each document, terms with frequency/count less than the given" +
-        " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
-        " times the term must appear in the document); if this is a double in [0,1), then this " +
-        "specifies a fraction (out of the document's token count). Note that the parameter is " +
-        "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
-        typeConverter=TypeConverters.toFloat)
-    minDF = Param(
-        Params._dummy(), "minDF", "Specifies the minimum number of" +
-        " different documents a term must appear in to be included in the vocabulary." +
-        " If this is an integer >= 1, this specifies the number of documents the term must" +
-        " appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
-        " Default 1.0", typeConverter=TypeConverters.toFloat)
-    vocabSize = Param(
-        Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
-        typeConverter=TypeConverters.toInt)
-    binary = Param(
-        Params._dummy(), "binary", "Binary toggle to control the output vector values." +
-        " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
-        " for discrete probabilistic models that model binary events rather than integer counts." +
-        " Default False", typeConverter=TypeConverters.toBoolean)
-
     @keyword_only
     def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,
                  outputCol=None):
@@ -484,7 +522,6 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC
         super(CountVectorizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer",
                                             self.uid)
-        self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False)
         kwargs = self._input_kwargs
         self.setParams(**kwargs)

@@ -507,81 +544,59 @@ def setMinTF(self, value):
         """
         return self._set(minTF=value)

-    @since("1.6.0")
-    def getMinTF(self):
-        """
-        Gets the value of minTF or its default value.
-        """
-        return self.getOrDefault(self.minTF)
-
     @since("1.6.0")
     def setMinDF(self, value):
         """
         Sets the value of :py:attr:`minDF`.
         """
         return self._set(minDF=value)

-    @since("1.6.0")
-    def getMinDF(self):
-        """
-        Gets the value of minDF or its default value.
-        """
-        return self.getOrDefault(self.minDF)
-
     @since("1.6.0")
     def setVocabSize(self, value):
         """
         Sets the value of :py:attr:`vocabSize`.
         """
         return self._set(vocabSize=value)

-    @since("1.6.0")
-    def getVocabSize(self):
-        """
-        Gets the value of vocabSize or its default value.
-        """
-        return self.getOrDefault(self.vocabSize)
-
     @since("2.0.0")
     def setBinary(self, value):
         """
         Sets the value of :py:attr:`binary`.
         """
         return self._set(binary=value)

-    @since("2.0.0")
-    def getBinary(self):
-        """
-        Gets the value of binary or its default value.
-        """
-        return self.getOrDefault(self.binary)
-
     def _create_model(self, java_model):
         return CountVectorizerModel(java_model)


-class CountVectorizerModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+@inherit_doc
+class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, JavaMLWritable):
     """
     Model fitted by :py:class:`CountVectorizer`.

     .. versionadded:: 1.6.0
     """

     @classmethod
-    @since("2.2.0")
-    def fromVocabulary(cls, vocab, inputCol, outputCol=None):
+    @since("2.4.0")
+    def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None):
         """
-        Construct the model directly from a vocabulary list, requires
-        an active SparkContext.
+        Construct the model directly from a vocabulary list of strings,
+        requires an active SparkContext.
         """
         sc = SparkContext._active_spark_context
         java_class = sc._gateway.jvm.java.lang.String
-        jvocab = CountVectorizerModel._new_java_array(vocab, java_class)
+        jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class)
         model = CountVectorizerModel._create_from_java_class(
             "org.apache.spark.ml.feature.CountVectorizerModel", jvocab)
         model.setInputCol(inputCol)
         if outputCol is not None:
             model.setOutputCol(outputCol)
+        if minTF is not None:
+            model.setMinTF(minTF)
+        if binary is not None:
+            model.setBinary(binary)
+        model._set(vocabSize=len(vocabulary))
         return model

     @property
@@ -592,6 +607,20 @@ def vocabulary(self):
         """
         return self._call_java("vocabulary")

+    @since("2.4.0")
+    def setMinTF(self, value):
+        """
+        Sets the value of :py:attr:`minTF`.
+        """
+        return self._set(minTF=value)
+
+    @since("2.4.0")
+    def setBinary(self, value):
+        """
+        Sets the value of :py:attr:`binary`.
+        """
+        return self._set(binary=value)
+

 @inherit_doc
 class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
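
As context for the change above, here is a minimal usage sketch of the new from_vocabulary classmethod introduced in this commit. It assumes an active SparkSession/SparkContext; the column names and sample data are illustrative and not taken from the patch.

    from pyspark.ml.feature import CountVectorizerModel
    from pyspark.sql import SparkSession

    # Assumed setup: any active SparkSession works, since from_vocabulary only
    # needs a live SparkContext to build the Java-side model.
    spark = SparkSession.builder.getOrCreate()

    # Build the model directly from a known vocabulary instead of fitting a
    # CountVectorizer; minTF and binary are the optional params added here.
    cvm = CountVectorizerModel.from_vocabulary(
        ["a", "b", "c"], inputCol="words", outputCol="features",
        minTF=1.0, binary=True)

    # Illustrative input: terms outside the vocabulary ("d") are ignored, and
    # binary=True clips all remaining counts to 1.
    df = spark.createDataFrame([(["a", "b", "b", "d"],)], ["words"])
    cvm.transform(df).show(truncate=False)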

python/pyspark/ml/tests.py

Lines changed: 17 additions & 0 deletions
@@ -640,6 +640,23 @@ def test_count_vectorizer_with_binary(self):
             feature, expected = r
             self.assertEqual(feature, expected)

+    def test_count_vectorizer_from_vocab(self):
+        model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words",
+                                                     outputCol="features", minTF=2)
+        self.assertEqual(model.vocabulary, ["a", "b", "c"])
+        self.assertEqual(model.getMinTF(), 2)
+
+        dataset = self.spark.createDataFrame([
+            (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),),
+            (1, "a a".split(' '), SparseVector(3, {0: 2.0}),),
+            (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"])
+
+        transformed_list = model.transform(dataset).select("features", "expected").collect()
+
+        for r in transformed_list:
+            feature, expected = r
+            self.assertEqual(feature, expected)
+
     def test_rformula_force_index_label(self):
         df = self.spark.createDataFrame([
             (1.0, 1.0, "a"),
