Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

package org.apache.spark.ml.feature

import java.util.Locale

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, StringArrayParam}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
Expand Down Expand Up @@ -73,22 +75,38 @@ class StopWordsRemover(override val uid: String)
/** @group getParam */
def getCaseSensitive: Boolean = $(caseSensitive)

setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), caseSensitive -> false)
/**
* Locale for doing a case sensitive comparison
* Default: English locale ("en")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we list what're the available options, or provide some reference here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's done

* @see [[http://www.localeplanet.com/java/]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please link to the official Java doc: https://docs.oracle.com/javase/8/docs/api/java/util/Locale.html or the Locale class.

* @group param
*/
val locale: Param[String] = new Param[String](this, "locale",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, shouldn't all this perhaps be linked to the stopwords set? if you loaded the French stopwords you'd want the French locale always?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but, How can we know that users loaded the French stopwords? User can load stopwords by
StopWordsRemover.loadDefaultStopWords("french")
and setting is
new StopWordsRemover().setStopWords(stopWords)
. Do you have any suggestion about that case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For supported languages, we can know the appropriate locale and maintain an internal mapping. So "french" is known to map to Locale.FRENCH. For loading an arbitrary list, we don't know, but you could provide an overload where you provide a Locale.

"locale for doing a case sensitive comparison")

/** @group setParam */
def setLocale(value: String): this.type = set(locale, value)
Copy link
Contributor

@hhbyyh hhbyyh May 9, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Myabe add parameter check here or in transformSchema, to help detect error before pipeline executes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good.


/** @group getParam */
def getLocale: String = $(locale)

setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
caseSensitive -> false, locale -> "en")
Copy link
Contributor

@hhbyyh hhbyyh May 9, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comparing with EN, it perhaps better to use Locale.default (original behavior)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the English set is loaded by default the locale should match, rather than use the platform default

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But, in any event, if 'stopwords' is not set, English list will be loaded. I think, It is better to use English locale as default. If users want to change locale, they can simply change by setLocale.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, En Locale is better.


@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema)
val t = if ($(caseSensitive)) {
val stopWordsSet = $(stopWords).toSet
udf { terms: Seq[String] =>
terms.filter(s => !stopWordsSet.contains(s))
terms.filterNot(stopWordsSet.contains)
}
} else {
// TODO: support user locale (SPARK-15064)
val toLower = (s: String) => if (s != null) s.toLowerCase else s
val loadedLocale = StopWordsRemover.loadLocale($(locale))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just new Locale($(locale))

val toLower = (s: String) => if (s != null) s.toLowerCase(loadedLocale) else s
val lowerStopWords = $(stopWords).map(toLower(_)).toSet
udf { terms: Seq[String] =>
terms.filter(s => !lowerStopWords.contains(toLower(s)))
terms.filterNot(term => lowerStopWords.contains(toLower(term)))
}
}
val metadata = outputSchema($(outputCol)).metadata
Expand All @@ -109,6 +127,7 @@ class StopWordsRemover(override val uid: String)
object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] {

private[feature]
def loadLocale(value : String): java.util.Locale = new Locale(value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just new Locale(value). Shall we remove this method?

val supportedLanguages = Set("danish", "dutch", "english", "finnish", "french", "german",
"hungarian", "italian", "norwegian", "portuguese", "russian", "spanish", "swedish", "turkish")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class StopWordsRemoverSuite
.setInputCol("raw")
.setOutputCol("filtered")
.setStopWords(stopWords)
.setLocale("tr")
Copy link
Contributor

@hhbyyh hhbyyh May 9, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe something more specific to test that Locale setter is working. I would suggest adding a new ut.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't use special charset because of styles check, but I did it in Python's test.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe consider to use // scalastyle:off as necessary

val dataSet = spark.createDataFrame(Seq(
(Seq("acaba", "ama", "biri"), Seq()),
(Seq("hep", "her", "scala"), Seq("scala"))
Expand Down
30 changes: 25 additions & 5 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -1736,25 +1736,31 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
typeConverter=TypeConverters.toListString)
caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
"comparison over the stop words", typeConverter=TypeConverters.toBoolean)
locale = Param(Params._dummy(), "locale", "locale for doing a case sensitive comparison",
typeConverter=TypeConverters.toString)

@keyword_only
def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False):
def __init__(self, inputCol=None, outputCol=None, stopWords=None,
caseSensitive=False, locale="en"):
"""
__init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false)
__init__(self, inputCol=None, outputCol=None, stopWords=None,
caseSensitive=false, locale="en")
"""
super(StopWordsRemover, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
self.uid)
self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"),
caseSensitive=False)
caseSensitive=False, locale="en")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@keyword_only
@since("1.6.0")
def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False):
def setParams(self, inputCol=None, outputCol=None, stopWords=None,
caseSensitive=False, locale="en"):
"""
setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false)
setParams(self, inputCol=None, outputCol=None, stopWords=None,
caseSensitive=false, locale="en")
Sets params for this StopWordRemover.
"""
kwargs = self.setParams._input_kwargs
Expand Down Expand Up @@ -1788,6 +1794,20 @@ def getCaseSensitive(self):
"""
return self.getOrDefault(self.caseSensitive)

@since("2.0.0")
def setLocale(self, value):
"""
Sets the value of :py:attr:`locale`.
"""
return self._set(locale=value)

@since("2.0.0")
def getLocale(self):
"""
Gets the value of :py:attr:`locale`.
"""
return self.getOrDefault(self.locale)

@staticmethod
@since("2.0.0")
def loadDefaultStopWords(language):
Expand Down
7 changes: 7 additions & 0 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,13 @@ def test_stopwordsremover(self):
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, [])
# with locale
stopwords = ["BİRİ"]
dataset = sqlContext.createDataFrame([Row(input=["biri"])])
stopWordRemover.setStopWords(stopwords).setLocale("tr")
self.assertEqual(stopWordRemover.getStopWords(), stopwords)
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, [])

def test_count_vectorizer_with_binary(self):
dataset = self.spark.createDataFrame([
Expand Down