Skip to content

Commit 269c789

Browse files
BryanCutler and mstewart141
authored and committed
[SPARK-15009][PYTHON][ML] Construct a CountVectorizerModel from a vocabulary list
## What changes were proposed in this pull request?

Added a class method to construct CountVectorizerModel from a list of vocabulary strings, equivalent to the Scala version. Introduced a common param base class `_CountVectorizerParams` to allow the Python model to also own the parameters. This now matches the Scala class hierarchy.

## How was this patch tested?

Added to CountVectorizer doctests to do a transform on a model constructed from vocab, and a unit test to verify params and vocab are constructed correctly.

Author: Bryan Cutler <[email protected]>

Closes apache#16770 from BryanCutler/pyspark-CountVectorizerModel-vocab_ctor-SPARK-15009.
1 parent 43d5f0f commit 269c789

File tree

2 files changed

+142
-58
lines changed

2 files changed

+142
-58
lines changed

python/pyspark/ml/feature.py

Lines changed: 112 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
if sys.version > '3':
2020
basestring = str
2121

22-
from pyspark import since, keyword_only
22+
from pyspark import since, keyword_only, SparkContext
2323
from pyspark.rdd import ignore_unicode_prefix
2424
from pyspark.ml.linalg import _convert_to_vector
2525
from pyspark.ml.param.shared import *
2626
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
27-
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
27+
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm
2828
from pyspark.ml.common import inherit_doc
2929

3030
__all__ = ['Binarizer',
@@ -403,8 +403,69 @@ def getSplits(self):
403403
return self.getOrDefault(self.splits)
404404

405405

406+
class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol):
407+
"""
408+
Params for :py:attr:`CountVectorizer` and :py:attr:`CountVectorizerModel`.
409+
"""
410+
411+
minTF = Param(
412+
Params._dummy(), "minTF", "Filter to ignore rare words in" +
413+
" a document. For each document, terms with frequency/count less than the given" +
414+
" threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
415+
" times the term must appear in the document); if this is a double in [0,1), then this " +
416+
"specifies a fraction (out of the document's token count). Note that the parameter is " +
417+
"only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
418+
typeConverter=TypeConverters.toFloat)
419+
minDF = Param(
420+
Params._dummy(), "minDF", "Specifies the minimum number of" +
421+
" different documents a term must appear in to be included in the vocabulary." +
422+
" If this is an integer >= 1, this specifies the number of documents the term must" +
423+
" appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
424+
" Default 1.0", typeConverter=TypeConverters.toFloat)
425+
vocabSize = Param(
426+
Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
427+
typeConverter=TypeConverters.toInt)
428+
binary = Param(
429+
Params._dummy(), "binary", "Binary toggle to control the output vector values." +
430+
" If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
431+
" for discrete probabilistic models that model binary events rather than integer counts." +
432+
" Default False", typeConverter=TypeConverters.toBoolean)
433+
434+
def __init__(self, *args):
435+
super(_CountVectorizerParams, self).__init__(*args)
436+
self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False)
437+
438+
@since("1.6.0")
439+
def getMinTF(self):
440+
"""
441+
Gets the value of minTF or its default value.
442+
"""
443+
return self.getOrDefault(self.minTF)
444+
445+
@since("1.6.0")
446+
def getMinDF(self):
447+
"""
448+
Gets the value of minDF or its default value.
449+
"""
450+
return self.getOrDefault(self.minDF)
451+
452+
@since("1.6.0")
453+
def getVocabSize(self):
454+
"""
455+
Gets the value of vocabSize or its default value.
456+
"""
457+
return self.getOrDefault(self.vocabSize)
458+
459+
@since("2.0.0")
460+
def getBinary(self):
461+
"""
462+
Gets the value of binary or its default value.
463+
"""
464+
return self.getOrDefault(self.binary)
465+
466+
406467
@inherit_doc
407-
class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
468+
class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, JavaMLWritable):
408469
"""
409470
Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`.
410471
@@ -437,33 +498,20 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
437498
>>> loadedModel = CountVectorizerModel.load(modelPath)
438499
>>> loadedModel.vocabulary == model.vocabulary
439500
True
501+
>>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"],
502+
... inputCol="raw", outputCol="vectors")
503+
>>> fromVocabModel.transform(df).show(truncate=False)
504+
+-----+---------------+-------------------------+
505+
|label|raw |vectors |
506+
+-----+---------------+-------------------------+
507+
|0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])|
508+
|1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
509+
+-----+---------------+-------------------------+
510+
...
440511
441512
.. versionadded:: 1.6.0
442513
"""
443514

444-
minTF = Param(
445-
Params._dummy(), "minTF", "Filter to ignore rare words in" +
446-
" a document. For each document, terms with frequency/count less than the given" +
447-
" threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
448-
" times the term must appear in the document); if this is a double in [0,1), then this " +
449-
"specifies a fraction (out of the document's token count). Note that the parameter is " +
450-
"only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
451-
typeConverter=TypeConverters.toFloat)
452-
minDF = Param(
453-
Params._dummy(), "minDF", "Specifies the minimum number of" +
454-
" different documents a term must appear in to be included in the vocabulary." +
455-
" If this is an integer >= 1, this specifies the number of documents the term must" +
456-
" appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
457-
" Default 1.0", typeConverter=TypeConverters.toFloat)
458-
vocabSize = Param(
459-
Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
460-
typeConverter=TypeConverters.toInt)
461-
binary = Param(
462-
Params._dummy(), "binary", "Binary toggle to control the output vector values." +
463-
" If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
464-
" for discrete probabilistic models that model binary events rather than integer counts." +
465-
" Default False", typeConverter=TypeConverters.toBoolean)
466-
467515
@keyword_only
468516
def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,
469517
outputCol=None):
@@ -474,7 +522,6 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC
474522
super(CountVectorizer, self).__init__()
475523
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer",
476524
self.uid)
477-
self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False)
478525
kwargs = self._input_kwargs
479526
self.setParams(**kwargs)
480527

@@ -497,66 +544,61 @@ def setMinTF(self, value):
497544
"""
498545
return self._set(minTF=value)
499546

500-
@since("1.6.0")
501-
def getMinTF(self):
502-
"""
503-
Gets the value of minTF or its default value.
504-
"""
505-
return self.getOrDefault(self.minTF)
506-
507547
@since("1.6.0")
508548
def setMinDF(self, value):
509549
"""
510550
Sets the value of :py:attr:`minDF`.
511551
"""
512552
return self._set(minDF=value)
513553

514-
@since("1.6.0")
515-
def getMinDF(self):
516-
"""
517-
Gets the value of minDF or its default value.
518-
"""
519-
return self.getOrDefault(self.minDF)
520-
521554
@since("1.6.0")
522555
def setVocabSize(self, value):
523556
"""
524557
Sets the value of :py:attr:`vocabSize`.
525558
"""
526559
return self._set(vocabSize=value)
527560

528-
@since("1.6.0")
529-
def getVocabSize(self):
530-
"""
531-
Gets the value of vocabSize or its default value.
532-
"""
533-
return self.getOrDefault(self.vocabSize)
534-
535561
@since("2.0.0")
536562
def setBinary(self, value):
537563
"""
538564
Sets the value of :py:attr:`binary`.
539565
"""
540566
return self._set(binary=value)
541567

542-
@since("2.0.0")
543-
def getBinary(self):
544-
"""
545-
Gets the value of binary or its default value.
546-
"""
547-
return self.getOrDefault(self.binary)
548-
549568
def _create_model(self, java_model):
550569
return CountVectorizerModel(java_model)
551570

552571

553-
class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable):
572+
@inherit_doc
573+
class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, JavaMLWritable):
554574
"""
555575
Model fitted by :py:class:`CountVectorizer`.
556576
557577
.. versionadded:: 1.6.0
558578
"""
559579

580+
@classmethod
581+
@since("2.4.0")
582+
def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None):
583+
"""
584+
Construct the model directly from a vocabulary list of strings,
585+
requires an active SparkContext.
586+
"""
587+
sc = SparkContext._active_spark_context
588+
java_class = sc._gateway.jvm.java.lang.String
589+
jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class)
590+
model = CountVectorizerModel._create_from_java_class(
591+
"org.apache.spark.ml.feature.CountVectorizerModel", jvocab)
592+
model.setInputCol(inputCol)
593+
if outputCol is not None:
594+
model.setOutputCol(outputCol)
595+
if minTF is not None:
596+
model.setMinTF(minTF)
597+
if binary is not None:
598+
model.setBinary(binary)
599+
model._set(vocabSize=len(vocabulary))
600+
return model
601+
560602
@property
561603
@since("1.6.0")
562604
def vocabulary(self):
@@ -565,6 +607,20 @@ def vocabulary(self):
565607
"""
566608
return self._call_java("vocabulary")
567609

610+
@since("2.4.0")
611+
def setMinTF(self, value):
612+
"""
613+
Sets the value of :py:attr:`minTF`.
614+
"""
615+
return self._set(minTF=value)
616+
617+
@since("2.4.0")
618+
def setBinary(self, value):
619+
"""
620+
Sets the value of :py:attr:`binary`.
621+
"""
622+
return self._set(binary=value)
623+
568624

569625
@inherit_doc
570626
class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):

python/pyspark/ml/tests.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,34 @@ def test_count_vectorizer_with_binary(self):
679679
feature, expected = r
680680
self.assertEqual(feature, expected)
681681

682+
def test_count_vectorizer_from_vocab(self):
683+
model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words",
684+
outputCol="features", minTF=2)
685+
self.assertEqual(model.vocabulary, ["a", "b", "c"])
686+
self.assertEqual(model.getMinTF(), 2)
687+
688+
dataset = self.spark.createDataFrame([
689+
(0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),),
690+
(1, "a a".split(' '), SparseVector(3, {0: 2.0}),),
691+
(2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"])
692+
693+
transformed_list = model.transform(dataset).select("features", "expected").collect()
694+
695+
for r in transformed_list:
696+
feature, expected = r
697+
self.assertEqual(feature, expected)
698+
699+
# Test an empty vocabulary
700+
with QuietTest(self.sc):
701+
with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"):
702+
CountVectorizerModel.from_vocabulary([], inputCol="words")
703+
704+
# Test model with default settings can transform
705+
model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words")
706+
transformed_list = model_default.transform(dataset)\
707+
.select(model_default.getOrDefault(model_default.outputCol)).collect()
708+
self.assertEqual(len(transformed_list), 3)
709+
682710
def test_rformula_force_index_label(self):
683711
df = self.spark.createDataFrame([
684712
(1.0, 1.0, "a"),
@@ -2019,8 +2047,8 @@ def test_java_params(self):
20192047
pyspark.ml.regression]
20202048
for module in modules:
20212049
for name, cls in inspect.getmembers(module, inspect.isclass):
2022-
if not name.endswith('Model') and issubclass(cls, JavaParams)\
2023-
and not inspect.isabstract(cls):
2050+
if not name.endswith('Model') and not name.endswith('Params')\
2051+
and issubclass(cls, JavaParams) and not inspect.isabstract(cls):
20242052
# NOTE: disable check_params_exist until there is parity with Scala API
20252053
ParamTests.check_params(self, cls(), check_params_exist=False)
20262054

0 commit comments

Comments (0)