2424from pyspark .ml .linalg import _convert_to_vector
2525from pyspark .ml .param .shared import *
2626from pyspark .ml .util import JavaMLReadable , JavaMLWritable
27- from pyspark .ml .wrapper import JavaEstimator , JavaModel , JavaTransformer , _jvm
27+ from pyspark .ml .wrapper import JavaEstimator , JavaModel , JavaParams , JavaTransformer , _jvm
2828from pyspark .ml .common import inherit_doc
2929
3030__all__ = ['Binarizer' ,
@@ -403,8 +403,69 @@ def getSplits(self):
403403 return self .getOrDefault (self .splits )
404404
405405
class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol):
    """
    Params shared by :py:class:`CountVectorizer` and :py:class:`CountVectorizerModel`.

    Declares the ``minTF``, ``minDF``, ``vocabSize`` and ``binary`` params,
    registers their default values on construction and exposes their getters.
    """

    minTF = Param(
        Params._dummy(),
        "minTF",
        "Filter to ignore rare words in"
        " a document. For each document, terms with frequency/count less than the given"
        " threshold are ignored. If this is an integer >= 1, then this specifies a count (of"
        " times the term must appear in the document); if this is a double in [0,1), then this "
        "specifies a fraction (out of the document's token count). Note that the parameter is "
        "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
        typeConverter=TypeConverters.toFloat)
    minDF = Param(
        Params._dummy(),
        "minDF",
        "Specifies the minimum number of"
        " different documents a term must appear in to be included in the vocabulary."
        " If this is an integer >= 1, this specifies the number of documents the term must"
        " appear in; if this is a double in [0,1), then this specifies the fraction of documents."
        " Default 1.0",
        typeConverter=TypeConverters.toFloat)
    vocabSize = Param(
        Params._dummy(),
        "vocabSize",
        "max size of the vocabulary. Default 1 << 18.",
        typeConverter=TypeConverters.toInt)
    binary = Param(
        Params._dummy(),
        "binary",
        "Binary toggle to control the output vector values."
        " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful"
        " for discrete probabilistic models that model binary events rather than integer counts."
        " Default False",
        typeConverter=TypeConverters.toBoolean)

    def __init__(self, *args):
        """
        Forward ``args`` to the parent initializer, then register the
        default values of the params declared on this class.
        """
        super(_CountVectorizerParams, self).__init__(*args)
        # Defaults live here so that the estimator and the model share them.
        defaults = {
            "minTF": 1.0,
            "minDF": 1.0,
            "vocabSize": 1 << 18,
            "binary": False,
        }
        self._setDefault(**defaults)

    @since("1.6.0")
    def getMinTF(self):
        """
        Gets the value currently held by :py:attr:`minTF`, or its default value.
        """
        return self.getOrDefault(self.minTF)

    @since("1.6.0")
    def getMinDF(self):
        """
        Gets the value currently held by :py:attr:`minDF`, or its default value.
        """
        return self.getOrDefault(self.minDF)

    @since("1.6.0")
    def getVocabSize(self):
        """
        Gets the value currently held by :py:attr:`vocabSize`, or its default value.
        """
        return self.getOrDefault(self.vocabSize)

    @since("2.0.0")
    def getBinary(self):
        """
        Gets the value currently held by :py:attr:`binary`, or its default value.
        """
        return self.getOrDefault(self.binary)
465+
466+
406467@inherit_doc
407- class CountVectorizer (JavaEstimator , HasInputCol , HasOutputCol , JavaMLReadable , JavaMLWritable ):
468+ class CountVectorizer (JavaEstimator , _CountVectorizerParams , JavaMLReadable , JavaMLWritable ):
408469 """
409470 Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`.
410471
@@ -437,7 +498,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
437498 >>> loadedModel = CountVectorizerModel.load(modelPath)
438499 >>> loadedModel.vocabulary == model.vocabulary
439500 True
440- >>> fromVocabModel = CountVectorizerModel.fromVocabulary (model.vocabulary,
501+ >>> fromVocabModel = CountVectorizerModel.from_vocabulary (model.vocabulary,
441502 ... inputCol="raw", outputCol="vectors")
442503 >>> fromVocabModel.transform(df).show(truncate=False)
443504 +-----+---------------+-------------------------+
@@ -451,29 +512,6 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
451512 .. versionadded:: 1.6.0
452513 """
453514
454- minTF = Param (
455- Params ._dummy (), "minTF" , "Filter to ignore rare words in" +
456- " a document. For each document, terms with frequency/count less than the given" +
457- " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
458- " times the term must appear in the document); if this is a double in [0,1), then this " +
459- "specifies a fraction (out of the document's token count). Note that the parameter is " +
460- "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0" ,
461- typeConverter = TypeConverters .toFloat )
462- minDF = Param (
463- Params ._dummy (), "minDF" , "Specifies the minimum number of" +
464- " different documents a term must appear in to be included in the vocabulary." +
465- " If this is an integer >= 1, this specifies the number of documents the term must" +
466- " appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
467- " Default 1.0" , typeConverter = TypeConverters .toFloat )
468- vocabSize = Param (
469- Params ._dummy (), "vocabSize" , "max size of the vocabulary. Default 1 << 18." ,
470- typeConverter = TypeConverters .toInt )
471- binary = Param (
472- Params ._dummy (), "binary" , "Binary toggle to control the output vector values." +
473- " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
474- " for discrete probabilistic models that model binary events rather than integer counts." +
475- " Default False" , typeConverter = TypeConverters .toBoolean )
476-
477515 @keyword_only
478516 def __init__ (self , minTF = 1.0 , minDF = 1.0 , vocabSize = 1 << 18 , binary = False , inputCol = None ,
479517 outputCol = None ):
@@ -484,7 +522,6 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC
484522 super (CountVectorizer , self ).__init__ ()
485523 self ._java_obj = self ._new_java_obj ("org.apache.spark.ml.feature.CountVectorizer" ,
486524 self .uid )
487- self ._setDefault (minTF = 1.0 , minDF = 1.0 , vocabSize = 1 << 18 , binary = False )
488525 kwargs = self ._input_kwargs
489526 self .setParams (** kwargs )
490527
@@ -507,81 +544,59 @@ def setMinTF(self, value):
507544 """
508545 return self ._set (minTF = value )
509546
510- @since ("1.6.0" )
511- def getMinTF (self ):
512- """
513- Gets the value of minTF or its default value.
514- """
515- return self .getOrDefault (self .minTF )
516-
@since("1.6.0")
def setMinDF(self, value):
    """
    Assigns ``value`` to :py:attr:`minDF` and returns the result of the
    underlying ``_set`` call.
    """
    return self._set(minDF=value)
523553
524- @since ("1.6.0" )
525- def getMinDF (self ):
526- """
527- Gets the value of minDF or its default value.
528- """
529- return self .getOrDefault (self .minDF )
530-
@since("1.6.0")
def setVocabSize(self, value):
    """
    Assigns ``value`` to :py:attr:`vocabSize` and returns the result of the
    underlying ``_set`` call.
    """
    return self._set(vocabSize=value)
537560
538- @since ("1.6.0" )
539- def getVocabSize (self ):
540- """
541- Gets the value of vocabSize or its default value.
542- """
543- return self .getOrDefault (self .vocabSize )
544-
@since("2.0.0")
def setBinary(self, value):
    """
    Assigns ``value`` to :py:attr:`binary` and returns the result of the
    underlying ``_set`` call.
    """
    return self._set(binary=value)
551567
552- @since ("2.0.0" )
553- def getBinary (self ):
554- """
555- Gets the value of binary or its default value.
556- """
557- return self .getOrDefault (self .binary )
558-
def _create_model(self, java_model):
    """
    Wraps ``java_model`` in a :py:class:`CountVectorizerModel` and returns it.
    """
    return CountVectorizerModel(java_model)
561570
562571
563- class CountVectorizerModel (JavaModel , HasInputCol , HasOutputCol , JavaMLReadable , JavaMLWritable ):
572+ @inherit_doc
573+ class CountVectorizerModel (JavaModel , _CountVectorizerParams , JavaMLReadable , JavaMLWritable ):
564574 """
565575 Model fitted by :py:class:`CountVectorizer`.
566576
567577 .. versionadded:: 1.6.0
568578 """
569579
@classmethod
@since("2.4.0")
def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None):
    """
    Construct the model directly from a vocabulary list of strings,
    requires an active SparkContext.

    :param vocabulary: list of strings used as the model vocabulary; its
                       length is also stored as the ``vocabSize`` param.
    :param inputCol: name of the input column; always set on the model.
    :param outputCol: name of the output column; only set when not None.
    :param minTF: value for the ``minTF`` param; only set when not None.
    :param binary: value for the ``binary`` param; only set when not None.
    :return: the newly created :py:class:`CountVectorizerModel`.
    """
    sc = SparkContext._active_spark_context
    # The vocabulary is handed to the JVM constructor as a Java String array.
    java_class = sc._gateway.jvm.java.lang.String
    jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class)
    model = CountVectorizerModel._create_from_java_class(
        "org.apache.spark.ml.feature.CountVectorizerModel", jvocab)
    model.setInputCol(inputCol)
    # Optional settings are applied only when given, so the param defaults
    # stay in effect otherwise.
    if outputCol is not None:
        model.setOutputCol(outputCol)
    if minTF is not None:
        model.setMinTF(minTF)
    if binary is not None:
        model.setBinary(binary)
    # Keep vocabSize in line with the vocabulary that was actually supplied.
    model._set(vocabSize=len(vocabulary))
    return model
586601
587602 @property
@@ -592,6 +607,20 @@ def vocabulary(self):
592607 """
593608 return self ._call_java ("vocabulary" )
594609
@since("2.4.0")
def setMinTF(self, value):
    """
    Assigns ``value`` to :py:attr:`minTF` and returns the result of the
    underlying ``_set`` call.
    """
    return self._set(minTF=value)
616+
@since("2.4.0")
def setBinary(self, value):
    """
    Assigns ``value`` to :py:attr:`binary` and returns the result of the
    underlying ``_set`` call.
    """
    return self._set(binary=value)
623+
595624
596625@inherit_doc
597626class DCT (JavaTransformer , HasInputCol , HasOutputCol , JavaMLReadable , JavaMLWritable ):
0 commit comments