diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 11864cb8f439..adcb2cca9211 100755
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.ml.feature
 
+import java.util.Locale
+
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
+import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, StringArrayParam}
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -73,7 +75,23 @@ class StopWordsRemover(override val uid: String)
   /** @group getParam */
   def getCaseSensitive: Boolean = $(caseSensitive)
 
-  setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), caseSensitive -> false)
+  /**
+   * Locale of the input for case insensitive matching. Ignored when [[caseSensitive]] is true.
+   * Default: English locale ("en")
+   * @see [[http://www.localeplanet.com/java/]]
+   * @group param
+   */
+  val locale: Param[String] = new Param[String](this, "locale",
+    "locale of the input for case insensitive matching")
+
+  /** @group setParam */
+  def setLocale(value: String): this.type = set(locale, value)
+
+  /** @group getParam */
+  def getLocale: String = $(locale)
+
+  setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"),
+    caseSensitive -> false, locale -> "en")
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
@@ -81,14 +99,14 @@ class StopWordsRemover(override val uid: String)
     val t = if ($(caseSensitive)) {
       val stopWordsSet = $(stopWords).toSet
       udf { terms: Seq[String] =>
-        terms.filter(s => !stopWordsSet.contains(s))
+        terms.filterNot(stopWordsSet.contains)
       }
     } else {
-      // TODO: support user locale (SPARK-15064)
-      val toLower = (s: String) => if (s != null) s.toLowerCase else s
+      val loadedLocale = StopWordsRemover.loadLocale($(locale))
+      val toLower = (s: String) => if (s != null) s.toLowerCase(loadedLocale) else s
       val lowerStopWords = $(stopWords).map(toLower(_)).toSet
       udf { terms: Seq[String] =>
-        terms.filter(s => !lowerStopWords.contains(toLower(s)))
+        terms.filterNot(term => lowerStopWords.contains(toLower(term)))
       }
     }
     val metadata = outputSchema($(outputCol)).metadata
@@ -109,6 +127,8 @@ class StopWordsRemover(override val uid: String)
 object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] {
 
+  private[feature] def loadLocale(value: String): Locale = new Locale(value)
+
   private[feature]
   val supportedLanguages = Set("danish", "dutch", "english", "finnish", "french", "german",
     "hungarian", "italian", "norwegian", "portuguese", "russian", "spanish", "swedish", "turkish")
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index 125ad02ebcc0..6ff6c5fc4e55 100755
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -98,6 +98,7 @@ class StopWordsRemoverSuite
       .setInputCol("raw")
       .setOutputCol("filtered")
       .setStopWords(stopWords)
+      .setLocale("tr")
     val dataSet = spark.createDataFrame(Seq(
       (Seq("acaba", "ama", "biri"), Seq()),
       (Seq("hep", "her", "scala"), Seq("scala"))
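Why a configurable locale matters here: `String.toLowerCase()` with no argument uses the JVM default locale, so case insensitive matching breaks for locale-sensitive alphabets such as Turkish, where the dotted capital 'İ' (U+0130) only lowercases to a plain 'i' under the Turkish locale. Below is a minimal, self-contained sketch of the `java.util.Locale` behavior that the new `loadLocale` helper exposes (the object name is illustrative, not part of the patch):

```scala
import java.util.Locale

object TurkishLowercaseDemo extends App {
  // Turkish rules: 'İ' (U+0130) lowercases to a plain 'i', so the
  // stop word "BİRİ" matches the input token "biri".
  println("BİRİ".toLowerCase(new Locale("tr")) == "biri")  // true

  // English rules: 'İ' lowercases to 'i' + combining dot above
  // (U+0069 U+0307), which does not equal the plain "biri".
  println("BİRİ".toLowerCase(Locale.ENGLISH) == "biri")    // false
}
```

This is exactly the asymmetry the Turkish-locale tests in this patch rely on.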
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 606a6e7c22b4..4505ccb91e06 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -1736,25 +1736,31 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
                       typeConverter=TypeConverters.toListString)
     caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
                           "comparison over the stop words", typeConverter=TypeConverters.toBoolean)
+    locale = Param(Params._dummy(), "locale", "locale of the input for case insensitive matching",
+                   typeConverter=TypeConverters.toString)
 
     @keyword_only
-    def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False):
+    def __init__(self, inputCol=None, outputCol=None, stopWords=None,
+                 caseSensitive=False, locale="en"):
         """
-        __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false)
+        __init__(self, inputCol=None, outputCol=None, stopWords=None, \
+                 caseSensitive=False, locale="en")
         """
         super(StopWordsRemover, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover",
                                             self.uid)
         self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"),
-                         caseSensitive=False)
+                         caseSensitive=False, locale="en")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.6.0")
-    def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False):
+    def setParams(self, inputCol=None, outputCol=None, stopWords=None,
+                  caseSensitive=False, locale="en"):
         """
-        setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false)
+        setParams(self, inputCol=None, outputCol=None, stopWords=None, \
+                  caseSensitive=False, locale="en")
         Sets params for this StopWordRemover.
         """
         kwargs = self.setParams._input_kwargs
@@ -1788,6 +1794,20 @@ def getCaseSensitive(self):
         """
         return self.getOrDefault(self.caseSensitive)
 
+    @since("2.0.0")
+    def setLocale(self, value):
+        """
+        Sets the value of :py:attr:`locale`.
+        """
+        return self._set(locale=value)
+
+    @since("2.0.0")
+    def getLocale(self):
+        """
+        Gets the value of :py:attr:`locale`.
+        """
+        return self.getOrDefault(self.locale)
+
     @staticmethod
     @since("2.0.0")
     def loadDefaultStopWords(language):
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 49d3a4a332fd..6ee8451640ba 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -433,6 +433,13 @@ def test_stopwordsremover(self):
         self.assertEqual(stopWordRemover.getStopWords(), stopwords)
         transformedDF = stopWordRemover.transform(dataset)
         self.assertEqual(transformedDF.head().output, [])
+        # case insensitive comparison under the Turkish locale
+        stopwords = ["BİRİ"]
+        dataset = self.spark.createDataFrame([Row(input=["biri"])])
+        stopWordRemover.setStopWords(stopwords).setLocale("tr")
+        self.assertEqual(stopWordRemover.getStopWords(), stopwords)
+        transformedDF = stopWordRemover.transform(dataset)
+        self.assertEqual(transformedDF.head().output, [])
 
     def test_count_vectorizer_with_binary(self):
         dataset = self.spark.createDataFrame([
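For completeness, a usage sketch of the new parameter from the Scala API, mirroring the Turkish-locale tests above. It assumes an active `SparkSession` named `spark`; the data and column names are illustrative:

```scala
import org.apache.spark.ml.feature.StopWordsRemover

// Illustrative input: one row mixing Turkish stop words and a keeper token.
val df = spark.createDataFrame(Seq(
  Tuple1(Seq("biri", "hep", "scala"))
)).toDF("raw")

val remover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")
  .setStopWords(Array("BİRİ", "HEP"))  // upper-cased Turkish stop words
  .setLocale("tr")                     // lowercase both sides with Turkish rules

// With locale "tr" (and the default caseSensitive=false), "BİRİ" and "HEP"
// lowercase to "biri" and "hep", so only "scala" survives in "filtered".
remover.transform(df).show(truncate = false)
```

With the default locale "en", the same stop words would lowercase to "bi̇ri̇" and "hep", so "biri" would incorrectly pass through the filter.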