24 changes: 22 additions & 2 deletions python/pyspark/ml/feature.py
@@ -512,14 +512,19 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
.. versionadded:: 1.3.0
"""

binary = Param(Params._dummy(), "binary", "If True, all non-zero counts are set to 1. " +
"This is useful for discrete probabilistic models that model binary events " +
"rather than integer counts. Default False.",
typeConverter=TypeConverters.toBoolean)

@keyword_only
-   def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
+   def __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None):
"""
__init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None)
"""
super(HashingTF, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid)
-   self._setDefault(numFeatures=1 << 18)
+   self._setDefault(numFeatures=1 << 18, binary=False)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@@ -533,6 +538,21 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)

@since("2.0.0")
def setBinary(self, value):
"""
Sets the value of :py:attr:`binary`.
"""
return self._set(binary=value)

@since("2.0.0")
def getBinary(self):
"""
Gets the value of binary or its default value.
"""
return self.getOrDefault(self.binary)


@inherit_doc
class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
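
As a usage sketch (not part of the patch — the session setup, data, and numFeatures value are illustrative), the new binary flag on the DataFrame-based HashingTF can be exercised like this:

from pyspark.ml.feature import HashingTF
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# One document with repeated terms. With binary set, every term that
# appears at least once maps to 1.0 rather than its raw count.
df = spark.createDataFrame([(0, ["a", "a", "b"])], ["id", "words"])
htf = HashingTF(numFeatures=100, inputCol="words", outputCol="features").setBinary(True)
htf.transform(df).select("features").show(truncate=False)
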
19 changes: 19 additions & 0 deletions python/pyspark/ml/tests.py
@@ -831,6 +831,25 @@ def test_logistic_regression_summary(self):
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)


class HashingTFTest(PySparkTestCase):

def test_apply_binary_term_freqs(self):
sqlContext = SQLContext(self.sc)

df = sqlContext.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"])
n = 100
hashingTF = HashingTF()
hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True)
output = hashingTF.transform(df)
features = output.select("features").first().features.toArray()
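# The expected indices below assume the hash of a single-character string
# equals its ord() value (true for the String.hashCode-based hashing in use
# at the time), so term "a" lands in bucket ord("a") % n.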
expected = Vectors.sparse(n, {(ord("a") % n): 1.0,
(ord("b") % n): 1.0,
(ord("c") % n): 1.0}).toArray()
for i in range(0, n):
self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) +
": expected " + str(expected[i]) + ", got " + str(features[i]))


if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner:
13 changes: 12 additions & 1 deletion python/pyspark/mllib/feature.py
@@ -379,6 +379,17 @@ class HashingTF(object):
"""
def __init__(self, numFeatures=1 << 20):
self.numFeatures = numFeatures
self.binary = False

@since("2.0.0")
def setBinary(self, value):
"""
If True, the term frequency vector will be binary: all non-zero
term counts are set to 1
(default: False)
"""
self.binary = value
return self

@since('1.2.0')
def indexOf(self, term):
@@ -398,7 +409,7 @@ def transform(self, document):
freq = {}
for term in document:
i = self.indexOf(term)
-   freq[i] = freq.get(i, 0) + 1.0
+   freq[i] = 1.0 if self.binary else freq.get(i, 0) + 1.0
return Vectors.sparse(self.numFeatures, freq.items())


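
For the RDD-based API, a minimal sketch along the same lines (the document contents are illustrative, mirroring the new test):

from pyspark.mllib.feature import HashingTF

tf = HashingTF(numFeatures=100).setBinary(True)
doc = "a a b c c c".split(" ")
# With the flag set, repeated terms ("a", "c") map to 1.0 rather than 2.0 and 3.0.
print(tf.transform(doc))
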
16 changes: 16 additions & 0 deletions python/pyspark/mllib/tests.py
@@ -58,6 +58,7 @@
from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
from pyspark.mllib.feature import HashingTF
from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import IDF
from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
@@ -1583,6 +1584,21 @@ def test_als_ratings_id_long_error(self):
self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r)))


class HashingTFTest(MLlibTestCase):

def test_binary_term_freqs(self):
hashingTF = HashingTF(100).setBinary(True)
doc = "a a b c c c".split(" ")
n = hashingTF.numFeatures
output = hashingTF.transform(doc).toArray()
expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0,
hashingTF.indexOf("b"): 1.0,
hashingTF.indexOf("c"): 1.0}).toArray()
for i in range(0, n):
self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) +
": expected " + str(expected[i]) + ", got " + str(output[i]))


if __name__ == "__main__":
from pyspark.mllib.tests import *
if not _have_scipy: