# Bind ``basestring`` to ``str`` on Python 3 and later.
# The major version is compared numerically: comparing ``sys.version`` as a
# string is lexicographic, so a version such as "10.0" would sort before "3".
if sys.version_info[0] >= 3:
    basestring = str
2121
22- from pyspark import since , keyword_only
22+ from pyspark import since , keyword_only , SparkContext
2323from pyspark .rdd import ignore_unicode_prefix
2424from pyspark .ml .linalg import _convert_to_vector
2525from pyspark .ml .param .shared import *
2626from pyspark .ml .util import JavaMLReadable , JavaMLWritable
27- from pyspark .ml .wrapper import JavaEstimator , JavaModel , JavaTransformer , _jvm
27+ from pyspark .ml .wrapper import JavaEstimator , JavaModel , JavaParams , JavaTransformer , _jvm
2828from pyspark .ml .common import inherit_doc
2929
3030__all__ = ['Binarizer' ,
@@ -403,8 +403,69 @@ def getSplits(self):
403403 return self .getOrDefault (self .splits )
404404
405405
class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol):
    """
    Params shared by :py:class:`CountVectorizer` and :py:class:`CountVectorizerModel`.

    Declares the ``minTF``, ``minDF``, ``vocabSize`` and ``binary`` params,
    registers their default values in ``__init__`` and provides their getters.
    """

    # Each doc string below is assembled from adjacent literals; the resulting
    # text is what gets passed to Param as the param description.
    minTF = Param(
        Params._dummy(),
        "minTF",
        ("Filter to ignore rare words in"
         " a document. For each document, terms with frequency/count less than the given"
         " threshold are ignored. If this is an integer >= 1, then this specifies a count (of"
         " times the term must appear in the document); if this is a double in [0,1), then this "
         "specifies a fraction (out of the document's token count). Note that the parameter is "
         "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0"),
        typeConverter=TypeConverters.toFloat)

    minDF = Param(
        Params._dummy(),
        "minDF",
        ("Specifies the minimum number of"
         " different documents a term must appear in to be included in the vocabulary."
         " If this is an integer >= 1, this specifies the number of documents the term must"
         " appear in; if this is a double in [0,1), then this specifies the fraction of documents."
         " Default 1.0"),
        typeConverter=TypeConverters.toFloat)

    vocabSize = Param(
        Params._dummy(),
        "vocabSize",
        "max size of the vocabulary. Default 1 << 18.",
        typeConverter=TypeConverters.toInt)

    binary = Param(
        Params._dummy(),
        "binary",
        ("Binary toggle to control the output vector values."
         " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful"
         " for discrete probabilistic models that model binary events rather than integer counts."
         " Default False"),
        typeConverter=TypeConverters.toBoolean)

    def __init__(self, *args):
        """
        Forward `args` to the parent initializer, then register the defaults
        for minTF, minDF, vocabSize and binary.
        """
        super(_CountVectorizerParams, self).__init__(*args)
        defaults = {"minTF": 1.0, "minDF": 1.0, "vocabSize": 1 << 18, "binary": False}
        self._setDefault(**defaults)

    @since("1.6.0")
    def getMinTF(self):
        """
        Gets the value of minTF or its default value.
        """
        param = self.minTF
        return self.getOrDefault(param)

    @since("1.6.0")
    def getMinDF(self):
        """
        Gets the value of minDF or its default value.
        """
        param = self.minDF
        return self.getOrDefault(param)

    @since("1.6.0")
    def getVocabSize(self):
        """
        Gets the value of vocabSize or its default value.
        """
        param = self.vocabSize
        return self.getOrDefault(param)

    @since("2.0.0")
    def getBinary(self):
        """
        Gets the value of binary or its default value.
        """
        param = self.binary
        return self.getOrDefault(param)
466+
406467@inherit_doc
407- class CountVectorizer (JavaEstimator , HasInputCol , HasOutputCol , JavaMLReadable , JavaMLWritable ):
468+ class CountVectorizer (JavaEstimator , _CountVectorizerParams , JavaMLReadable , JavaMLWritable ):
408469 """
409470 Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`.
410471
@@ -437,33 +498,20 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
437498 >>> loadedModel = CountVectorizerModel.load(modelPath)
438499 >>> loadedModel.vocabulary == model.vocabulary
439500 True
501+ >>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"],
502+ ... inputCol="raw", outputCol="vectors")
503+ >>> fromVocabModel.transform(df).show(truncate=False)
504+ +-----+---------------+-------------------------+
505+ |label|raw |vectors |
506+ +-----+---------------+-------------------------+
507+ |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])|
508+ |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
509+ +-----+---------------+-------------------------+
510+ ...
440511
441512 .. versionadded:: 1.6.0
442513 """
443514
444- minTF = Param (
445- Params ._dummy (), "minTF" , "Filter to ignore rare words in" +
446- " a document. For each document, terms with frequency/count less than the given" +
447- " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
448- " times the term must appear in the document); if this is a double in [0,1), then this " +
449- "specifies a fraction (out of the document's token count). Note that the parameter is " +
450- "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0" ,
451- typeConverter = TypeConverters .toFloat )
452- minDF = Param (
453- Params ._dummy (), "minDF" , "Specifies the minimum number of" +
454- " different documents a term must appear in to be included in the vocabulary." +
455- " If this is an integer >= 1, this specifies the number of documents the term must" +
456- " appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
457- " Default 1.0" , typeConverter = TypeConverters .toFloat )
458- vocabSize = Param (
459- Params ._dummy (), "vocabSize" , "max size of the vocabulary. Default 1 << 18." ,
460- typeConverter = TypeConverters .toInt )
461- binary = Param (
462- Params ._dummy (), "binary" , "Binary toggle to control the output vector values." +
463- " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
464- " for discrete probabilistic models that model binary events rather than integer counts." +
465- " Default False" , typeConverter = TypeConverters .toBoolean )
466-
467515 @keyword_only
468516 def __init__ (self , minTF = 1.0 , minDF = 1.0 , vocabSize = 1 << 18 , binary = False , inputCol = None ,
469517 outputCol = None ):
@@ -474,7 +522,6 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC
474522 super (CountVectorizer , self ).__init__ ()
475523 self ._java_obj = self ._new_java_obj ("org.apache.spark.ml.feature.CountVectorizer" ,
476524 self .uid )
477- self ._setDefault (minTF = 1.0 , minDF = 1.0 , vocabSize = 1 << 18 , binary = False )
478525 kwargs = self ._input_kwargs
479526 self .setParams (** kwargs )
480527
@@ -497,66 +544,61 @@ def setMinTF(self, value):
497544 """
498545 return self ._set (minTF = value )
499546
500- @since ("1.6.0" )
501- def getMinTF (self ):
502- """
503- Gets the value of minTF or its default value.
504- """
505- return self .getOrDefault (self .minTF )
506-
@since("1.6.0")
def setMinDF(self, value):
    """
    Assign `value` to the :py:attr:`minDF` param.

    Delegates to ``self._set`` and returns its result unchanged.
    """
    assignment = {"minDF": value}
    return self._set(**assignment)
513553
514- @since ("1.6.0" )
515- def getMinDF (self ):
516- """
517- Gets the value of minDF or its default value.
518- """
519- return self .getOrDefault (self .minDF )
520-
@since("1.6.0")
def setVocabSize(self, value):
    """
    Assign `value` to the :py:attr:`vocabSize` param.

    Delegates to ``self._set`` and returns its result unchanged.
    """
    assignment = {"vocabSize": value}
    return self._set(**assignment)
527560
528- @since ("1.6.0" )
529- def getVocabSize (self ):
530- """
531- Gets the value of vocabSize or its default value.
532- """
533- return self .getOrDefault (self .vocabSize )
534-
@since("2.0.0")
def setBinary(self, value):
    """
    Assign `value` to the :py:attr:`binary` param.

    Delegates to ``self._set`` and returns its result unchanged.
    """
    assignment = {"binary": value}
    return self._set(**assignment)
541567
542- @since ("2.0.0" )
543- def getBinary (self ):
544- """
545- Gets the value of binary or its default value.
546- """
547- return self .getOrDefault (self .binary )
548-
def _create_model(self, java_model):
    """
    Wrap `java_model` in a :py:class:`CountVectorizerModel` and return it.
    """
    wrapped = CountVectorizerModel(java_model)
    return wrapped
551570
552571
553- class CountVectorizerModel (JavaModel , JavaMLReadable , JavaMLWritable ):
572+ @inherit_doc
573+ class CountVectorizerModel (JavaModel , _CountVectorizerParams , JavaMLReadable , JavaMLWritable ):
554574 """
555575 Model fitted by :py:class:`CountVectorizer`.
556576
557577 .. versionadded:: 1.6.0
558578 """
559579
@classmethod
@since("2.4.0")
def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None):
    """
    Construct the model directly from a vocabulary list of strings,
    requires an active SparkContext.

    :param vocabulary: list of strings used as the model vocabulary; its
                       length is also stored as the ``vocabSize`` param.
    :param inputCol: value passed to ``setInputCol`` (always applied).
    :param outputCol: value passed to ``setOutputCol``; skipped when None.
    :param minTF: value passed to ``setMinTF``; skipped when None.
    :param binary: value passed to ``setBinary``; skipped when None.
    :return: the newly created model.
    :raises RuntimeError: if there is no active SparkContext.
    """
    sc = SparkContext._active_spark_context
    # Fail with a clear message instead of an AttributeError on None when
    # no context is active.
    if sc is None:
        raise RuntimeError(
            "CountVectorizerModel.from_vocabulary requires an active SparkContext.")
    # java.lang.String is the element class handed to _new_java_array.
    java_class = sc._gateway.jvm.java.lang.String
    jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class)
    model = CountVectorizerModel._create_from_java_class(
        "org.apache.spark.ml.feature.CountVectorizerModel", jvocab)
    model.setInputCol(inputCol)
    # Optional params are only set when the caller supplied a value.
    if outputCol is not None:
        model.setOutputCol(outputCol)
    if minTF is not None:
        model.setMinTF(minTF)
    if binary is not None:
        model.setBinary(binary)
    # Record the vocabulary length as the vocabSize param.
    model._set(vocabSize=len(vocabulary))
    return model
601+
560602 @property
561603 @since ("1.6.0" )
562604 def vocabulary (self ):
@@ -565,6 +607,20 @@ def vocabulary(self):
565607 """
566608 return self ._call_java ("vocabulary" )
567609
@since("2.4.0")
def setMinTF(self, value):
    """
    Assign `value` to the :py:attr:`minTF` param.

    Delegates to ``self._set`` and returns its result unchanged.
    """
    assignment = {"minTF": value}
    return self._set(**assignment)
616+
@since("2.4.0")
def setBinary(self, value):
    """
    Assign `value` to the :py:attr:`binary` param.

    Delegates to ``self._set`` and returns its result unchanged.
    """
    assignment = {"binary": value}
    return self._set(**assignment)
623+
568624
569625@inherit_doc
570626class DCT (JavaTransformer , HasInputCol , HasOutputCol , JavaMLReadable , JavaMLWritable ):
0 commit comments