Skip to content

Commit da65f4b

Browse files
committed
Added a class method to construct a CountVectorizerModel from a vocabulary. Not yet working: the step that copies params from the estimator to the model (_copyValues) is missing.
1 parent 57d70d2 commit da65f4b

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

python/pyspark/ml/feature.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
if sys.version > '3':
2020
basestring = str
2121

22-
from pyspark import since, keyword_only
22+
from pyspark import since, keyword_only, SparkContext
2323
from pyspark.rdd import ignore_unicode_prefix
2424
from pyspark.ml.linalg import _convert_to_vector
2525
from pyspark.ml.param.shared import *
@@ -254,6 +254,16 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
254254
>>> loadedModel = CountVectorizerModel.load(modelPath)
255255
>>> loadedModel.vocabulary == model.vocabulary
256256
True
257+
>>> fromVocabModel = CountVectorizerModel.fromVocabulary(model.vocabulary,
258+
... inputCol="raw", outputCol="vectors")
259+
>>> fromVocabModel.transform(df).show(truncate=False)
260+
+-----+---------------+-------------------------+
261+
|label|raw |vectors |
262+
+-----+---------------+-------------------------+
263+
|0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])|
264+
|1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
265+
+-----+---------------+-------------------------+
266+
...
257267
258268
.. versionadded:: 1.6.0
259269
"""
@@ -367,13 +377,30 @@ def _create_model(self, java_model):
367377
return CountVectorizerModel(java_model)
368378

369379

370-
class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable):
380+
class CountVectorizerModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
371381
"""
372382
Model fitted by :py:class:`CountVectorizer`.
373383
374384
.. versionadded:: 1.6.0
375385
"""
376386

387+
@classmethod
@since("2.2.0")
def fromVocabulary(cls, vocab, inputCol, outputCol=None):
    """
    Construct the model directly from a vocabulary list, requires
    an active SparkContext.

    :param vocab: vocabulary terms used to build the model; converted
                  to a Java array of ``java.lang.String``.
    :param inputCol: name of the input column; always set on the model.
    :param outputCol: name of the output column; only set on the model
                      when it is not None.
    :return: a new model wrapping a JVM
             ``org.apache.spark.ml.feature.CountVectorizerModel``
             created from ``vocab``.
    :raises RuntimeError: if there is no active SparkContext.
    """
    sc = SparkContext._active_spark_context
    if sc is None:
        # Fail with a clear message instead of an AttributeError on None
        # when the gateway is dereferenced below.
        raise RuntimeError(
            "CountVectorizerModel.fromVocabulary requires an active SparkContext")

    # The vocabulary is handed to the JVM class as an array of Java strings.
    java_class = sc._gateway.jvm.java.lang.String
    jvocab = cls._new_java_array(vocab, java_class)
    model = cls._create_from_java_class(
        "org.apache.spark.ml.feature.CountVectorizerModel", jvocab)

    # NOTE(review): the accompanying commit message says this is not yet
    # working because params are not copied to the model; presumably the
    # setters below only update the Python-side params -- TODO confirm they
    # reach the JVM object before transform() is called.
    model.setInputCol(inputCol)
    if outputCol is not None:
        model.setOutputCol(outputCol)
    return model
403+
377404
@property
378405
@since("1.6.0")
379406
def vocabulary(self):

0 commit comments

Comments
 (0)