-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-15009][PYTHON][ML] Construct a CountVectorizerModel from a vocabulary list #16770
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
01e5a4b
e94dde3
8860641
5220ff1
7e05da4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,12 +19,12 @@ | |
| if sys.version > '3': | ||
| basestring = str | ||
|
|
||
| from pyspark import since, keyword_only | ||
| from pyspark import since, keyword_only, SparkContext | ||
| from pyspark.rdd import ignore_unicode_prefix | ||
| from pyspark.ml.linalg import _convert_to_vector | ||
| from pyspark.ml.param.shared import * | ||
| from pyspark.ml.util import JavaMLReadable, JavaMLWritable | ||
| from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm | ||
| from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm | ||
| from pyspark.ml.common import inherit_doc | ||
|
|
||
| __all__ = ['Binarizer', | ||
|
|
@@ -403,8 +403,69 @@ def getSplits(self): | |
| return self.getOrDefault(self.splits) | ||
|
|
||
|
|
||
| class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol): | ||
| """ | ||
| Params for :py:attr:`CountVectorizer` and :py:attr:`CountVectorizerModel`. | ||
| """ | ||
|
|
||
| minTF = Param( | ||
| Params._dummy(), "minTF", "Filter to ignore rare words in" + | ||
| " a document. For each document, terms with frequency/count less than the given" + | ||
| " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + | ||
| " times the term must appear in the document); if this is a double in [0,1), then this " + | ||
| "specifies a fraction (out of the document's token count). Note that the parameter is " + | ||
| "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", | ||
| typeConverter=TypeConverters.toFloat) | ||
| minDF = Param( | ||
| Params._dummy(), "minDF", "Specifies the minimum number of" + | ||
| " different documents a term must appear in to be included in the vocabulary." + | ||
| " If this is an integer >= 1, this specifies the number of documents the term must" + | ||
| " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + | ||
| " Default 1.0", typeConverter=TypeConverters.toFloat) | ||
| vocabSize = Param( | ||
| Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", | ||
| typeConverter=TypeConverters.toInt) | ||
| binary = Param( | ||
| Params._dummy(), "binary", "Binary toggle to control the output vector values." + | ||
| " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + | ||
| " for discrete probabilistic models that model binary events rather than integer counts." + | ||
| " Default False", typeConverter=TypeConverters.toBoolean) | ||
|
|
||
| def __init__(self, *args): | ||
| super(_CountVectorizerParams, self).__init__(*args) | ||
| self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) | ||
|
|
||
| @since("1.6.0") | ||
| def getMinTF(self): | ||
| """ | ||
| Gets the value of minTF or its default value. | ||
| """ | ||
| return self.getOrDefault(self.minTF) | ||
|
|
||
| @since("1.6.0") | ||
| def getMinDF(self): | ||
| """ | ||
| Gets the value of minDF or its default value. | ||
| """ | ||
| return self.getOrDefault(self.minDF) | ||
|
|
||
| @since("1.6.0") | ||
| def getVocabSize(self): | ||
| """ | ||
| Gets the value of vocabSize or its default value. | ||
| """ | ||
| return self.getOrDefault(self.vocabSize) | ||
|
|
||
| @since("2.0.0") | ||
| def getBinary(self): | ||
| """ | ||
| Gets the value of binary or its default value. | ||
| """ | ||
| return self.getOrDefault(self.binary) | ||
|
|
||
|
|
||
| @inherit_doc | ||
| class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): | ||
| class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, JavaMLWritable): | ||
| """ | ||
| Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`. | ||
|
|
||
|
|
@@ -437,33 +498,20 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, | |
| >>> loadedModel = CountVectorizerModel.load(modelPath) | ||
| >>> loadedModel.vocabulary == model.vocabulary | ||
| True | ||
| >>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"], | ||
| ... inputCol="raw", outputCol="vectors") | ||
| >>> fromVocabModel.transform(df).show(truncate=False) | ||
| +-----+---------------+-------------------------+ | ||
| |label|raw |vectors | | ||
| +-----+---------------+-------------------------+ | ||
| |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])| | ||
| |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| | ||
| +-----+---------------+-------------------------+ | ||
| ... | ||
|
|
||
| .. versionadded:: 1.6.0 | ||
| """ | ||
|
|
||
| minTF = Param( | ||
| Params._dummy(), "minTF", "Filter to ignore rare words in" + | ||
| " a document. For each document, terms with frequency/count less than the given" + | ||
| " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + | ||
| " times the term must appear in the document); if this is a double in [0,1), then this " + | ||
| "specifies a fraction (out of the document's token count). Note that the parameter is " + | ||
| "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", | ||
| typeConverter=TypeConverters.toFloat) | ||
| minDF = Param( | ||
| Params._dummy(), "minDF", "Specifies the minimum number of" + | ||
| " different documents a term must appear in to be included in the vocabulary." + | ||
| " If this is an integer >= 1, this specifies the number of documents the term must" + | ||
| " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + | ||
| " Default 1.0", typeConverter=TypeConverters.toFloat) | ||
| vocabSize = Param( | ||
| Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", | ||
| typeConverter=TypeConverters.toInt) | ||
| binary = Param( | ||
| Params._dummy(), "binary", "Binary toggle to control the output vector values." + | ||
| " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + | ||
| " for discrete probabilistic models that model binary events rather than integer counts." + | ||
| " Default False", typeConverter=TypeConverters.toBoolean) | ||
|
|
||
| @keyword_only | ||
| def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, | ||
| outputCol=None): | ||
|
|
@@ -474,7 +522,6 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC | |
| super(CountVectorizer, self).__init__() | ||
| self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", | ||
| self.uid) | ||
| self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) | ||
| kwargs = self._input_kwargs | ||
| self.setParams(**kwargs) | ||
|
|
||
|
|
@@ -497,66 +544,61 @@ def setMinTF(self, value): | |
| """ | ||
| return self._set(minTF=value) | ||
|
|
||
| @since("1.6.0") | ||
| def getMinTF(self): | ||
| """ | ||
| Gets the value of minTF or its default value. | ||
| """ | ||
| return self.getOrDefault(self.minTF) | ||
|
|
||
| @since("1.6.0") | ||
| def setMinDF(self, value): | ||
| """ | ||
| Sets the value of :py:attr:`minDF`. | ||
| """ | ||
| return self._set(minDF=value) | ||
|
|
||
| @since("1.6.0") | ||
| def getMinDF(self): | ||
| """ | ||
| Gets the value of minDF or its default value. | ||
| """ | ||
| return self.getOrDefault(self.minDF) | ||
|
|
||
| @since("1.6.0") | ||
| def setVocabSize(self, value): | ||
| """ | ||
| Sets the value of :py:attr:`vocabSize`. | ||
| """ | ||
| return self._set(vocabSize=value) | ||
|
|
||
| @since("1.6.0") | ||
| def getVocabSize(self): | ||
| """ | ||
| Gets the value of vocabSize or its default value. | ||
| """ | ||
| return self.getOrDefault(self.vocabSize) | ||
|
|
||
| @since("2.0.0") | ||
| def setBinary(self, value): | ||
| """ | ||
| Sets the value of :py:attr:`binary`. | ||
| """ | ||
| return self._set(binary=value) | ||
|
|
||
| @since("2.0.0") | ||
| def getBinary(self): | ||
| """ | ||
| Gets the value of binary or its default value. | ||
| """ | ||
| return self.getOrDefault(self.binary) | ||
|
|
||
| def _create_model(self, java_model): | ||
| return CountVectorizerModel(java_model) | ||
|
|
||
|
|
||
| class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): | ||
| @inherit_doc | ||
| class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, JavaMLWritable): | ||
| """ | ||
| Model fitted by :py:class:`CountVectorizer`. | ||
|
|
||
| .. versionadded:: 1.6.0 | ||
| """ | ||
|
|
||
| @classmethod | ||
| @since("2.4.0") | ||
| def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None): | ||
| """ | ||
| Construct the model directly from a vocabulary list of strings, | ||
| requires an active SparkContext. | ||
| """ | ||
| sc = SparkContext._active_spark_context | ||
| java_class = sc._gateway.jvm.java.lang.String | ||
| jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class) | ||
| model = CountVectorizerModel._create_from_java_class( | ||
| "org.apache.spark.ml.feature.CountVectorizerModel", jvocab) | ||
| model.setInputCol(inputCol) | ||
| if outputCol is not None: | ||
| model.setOutputCol(outputCol) | ||
| if minTF is not None: | ||
| model.setMinTF(minTF) | ||
| if binary is not None: | ||
| model.setBinary(binary) | ||
| model._set(vocabSize=len(vocabulary)) | ||
| return model | ||
|
|
||
| @property | ||
| @since("1.6.0") | ||
| def vocabulary(self): | ||
|
|
@@ -565,6 +607,20 @@ def vocabulary(self): | |
| """ | ||
| return self._call_java("vocabulary") | ||
|
|
||
| @since("2.4.0") | ||
| def setMinTF(self, value): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we're going to have the setters in both the model and the estimator maybe we should consider putting it in the shared params class?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree but I was trying to match the Scala API. My only thought is it was done this way to leave it up to the implementations if they allow setting the params. What do you think?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds reasonable to me. |
||
| """ | ||
| Sets the value of :py:attr:`minTF`. | ||
| """ | ||
| return self._set(minTF=value) | ||
|
|
||
| @since("2.4.0") | ||
| def setBinary(self, value): | ||
| """ | ||
| Sets the value of :py:attr:`binary`. | ||
| """ | ||
| return self._set(binary=value) | ||
|
|
||
|
|
||
| @inherit_doc | ||
| class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -640,6 +640,34 @@ def test_count_vectorizer_with_binary(self): | |
| feature, expected = r | ||
| self.assertEqual(feature, expected) | ||
|
|
||
| def test_count_vectorizer_from_vocab(self): | ||
| model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good first test, I'd love to also see it with empty vocab, and also one that uses the default values.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, good idea, I'll add those
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| outputCol="features", minTF=2) | ||
| self.assertEqual(model.vocabulary, ["a", "b", "c"]) | ||
| self.assertEqual(model.getMinTF(), 2) | ||
|
|
||
| dataset = self.spark.createDataFrame([ | ||
| (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),), | ||
| (1, "a a".split(' '), SparseVector(3, {0: 2.0}),), | ||
| (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) | ||
|
|
||
| transformed_list = model.transform(dataset).select("features", "expected").collect() | ||
|
|
||
| for r in transformed_list: | ||
| feature, expected = r | ||
| self.assertEqual(feature, expected) | ||
|
|
||
| # Test an empty vocabulary | ||
| with QuietTest(self.sc): | ||
| with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"): | ||
| CountVectorizerModel.from_vocabulary([], inputCol="words") | ||
|
|
||
| # Test model with default settings can transform | ||
| model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words") | ||
| transformed_list = model_default.transform(dataset)\ | ||
| .select(model_default.getOrDefault(model_default.outputCol)).collect() | ||
| self.assertEqual(len(transformed_list), 3) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The doctest uses default values for all params except
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sgtm |
||
|
|
||
| def test_rformula_force_index_label(self): | ||
| df = self.spark.createDataFrame([ | ||
| (1.0, 1.0, "a"), | ||
|
|
@@ -1980,8 +2008,8 @@ def test_java_params(self): | |
| pyspark.ml.regression] | ||
| for module in modules: | ||
| for name, cls in inspect.getmembers(module, inspect.isclass): | ||
| if not name.endswith('Model') and issubclass(cls, JavaParams)\ | ||
| and not inspect.isabstract(cls): | ||
| if not name.endswith('Model') and not name.endswith('Params')\ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to make sure I've understood whats happening here, were avoiding doing the default params test on non-concrete classes like the base params shared between the model and the estimator and instead testing just the model and estimator on their own right?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that's pretty much right but this is only checking estimators and skips models also. We should have an explicit check for
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds reasonable. I look forward to us automatically catching models with missing params eventually as well.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm helping get |
||
| and issubclass(cls, JavaParams) and not inspect.isabstract(cls): | ||
| # NOTE: disable check_params_exist until there is parity with Scala API | ||
| ParamTests.check_params(self, cls(), check_params_exist=False) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any reason for _set rather than set?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The only difference is
setchecks to make sure the param is valid, which isn't really needed since this is internal.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sgtm