|
19 | 19 | if sys.version > '3': |
20 | 20 | basestring = str |
21 | 21 |
|
22 | | -from pyspark import since, keyword_only |
| 22 | +from pyspark import since, keyword_only, SparkContext |
23 | 23 | from pyspark.rdd import ignore_unicode_prefix |
24 | 24 | from pyspark.ml.linalg import _convert_to_vector |
25 | 25 | from pyspark.ml.param.shared import * |
@@ -254,6 +254,16 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, |
254 | 254 | >>> loadedModel = CountVectorizerModel.load(modelPath) |
255 | 255 | >>> loadedModel.vocabulary == model.vocabulary |
256 | 256 | True |
| 257 | + >>> fromVocabModel = CountVectorizerModel.fromVocabulary(model.vocabulary, |
| 258 | + ... inputCol="raw", outputCol="vectors") |
| 259 | + >>> fromVocabModel.transform(df).show(truncate=False) |
| 260 | + +-----+---------------+-------------------------+ |
| 261 | + |label|raw |vectors | |
| 262 | + +-----+---------------+-------------------------+ |
| 263 | + |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])| |
| 264 | + |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| |
| 265 | + +-----+---------------+-------------------------+ |
| 266 | + ... |
257 | 267 |
|
258 | 268 | .. versionadded:: 1.6.0 |
259 | 269 | """ |
@@ -367,13 +377,30 @@ def _create_model(self, java_model): |
367 | 377 | return CountVectorizerModel(java_model) |
368 | 378 |
|
369 | 379 |
|
370 | | -class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): |
| 380 | +class CountVectorizerModel(JavaModel, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): |
371 | 381 | """ |
372 | 382 | Model fitted by :py:class:`CountVectorizer`. |
373 | 383 |
|
374 | 384 | .. versionadded:: 1.6.0 |
375 | 385 | """ |
376 | 386 |
|
| 387 | + @classmethod |
| 388 | + @since("2.2.0") |
| 389 | + def fromVocabulary(cls, vocab, inputCol, outputCol=None): |
| 390 | + """ |
| 391 | + Construct the model directly from a vocabulary list of strings. |
| 392 | + Requires an active SparkContext. |
| 393 | + """ |
| 394 | + sc = SparkContext._active_spark_context |
| 395 | + java_class = sc._gateway.jvm.java.lang.String |
| 396 | + jvocab = CountVectorizerModel._new_java_array(vocab, java_class) |
| 397 | + model = CountVectorizerModel._create_from_java_class( |
| 398 | + "org.apache.spark.ml.feature.CountVectorizerModel", jvocab) |
| 399 | + model.setInputCol(inputCol) |
| 400 | + if outputCol is not None: |
| 401 | + model.setOutputCol(outputCol) |
| 402 | + return model |
| 403 | + |
377 | 404 | @property |
378 | 405 | @since("1.6.0") |
379 | 406 | def vocabulary(self): |
|
0 commit comments