From 2653750762c12cf5d17b8a2ac2a7ee9f8d55bfec Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Thu, 8 Dec 2016 21:22:31 -0800 Subject: [PATCH 1/7] [SPARK-14932][SQL] Allow DataFrame.replace() to replace values with None --- python/pyspark/sql/dataframe.py | 22 ++++++++++++++----- .../spark/sql/DataFrameNaFunctions.scala | 13 ++++++----- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b9d90384e3e2..b6997127d4fe 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1277,10 +1277,10 @@ def replace(self, to_replace, value, subset=None): If the value is a dict, then `value` is ignored and `to_replace` must be a mapping from column name (string) to replacement value. The value to be replaced must be an int, long, float, or string. - :param value: int, long, float, string, or list. + :param value: int, long, float, string, list or None. Value to use to replace holes. - The replacement value must be an int, long, float, or string. If `value` is a - list or tuple, `value` should be of the same length with `to_replace`. + The replacement value must be an int, long, float, string or None. If `value` + is a list or tuple, `value` should be of the same length with `to_replace`. :param subset: optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if `value` is a string, and subset contains a non-string column, @@ -1296,6 +1296,16 @@ def replace(self, to_replace, value, subset=None): |null| null| null| +----+------+-----+ + >>> df4.na.replace('Alice', None).show() + +----+------+----+ + | age|height|name| + +----+------+----+ + | 10| 80|null| + | 5| null| Bob| + |null| null| Tom| + |null| null|null| + +----+------+----+ + >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| @@ -1310,8 +1320,8 @@ def replace(self, to_replace, value, subset=None): raise ValueError( "to_replace should be a float, int, long, string, list, tuple, or dict") - if not isinstance(value, (float, int, long, basestring, list, tuple)): - raise ValueError("value should be a float, int, long, string, list, or tuple") + if value is not None and not isinstance(value, (float, int, long, basestring, list, tuple)): + raise ValueError("value should be a float, int, long, string, list, tuple or None") rep_dict = dict() @@ -1328,7 +1338,7 @@ def replace(self, to_replace, value, subset=None): if len(to_replace) != len(value): raise ValueError("to_replace and value lists should be of the same length") rep_dict = dict(zip(to_replace, value)) - elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)): + elif isinstance(to_replace, list) and (value is None or isinstance(value, (float, int, long, basestring))): rep_dict = dict([(tr, value) for tr in to_replace]) elif isinstance(to_replace, dict): rep_dict = to_replace diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 28820681cd3a..777e0ca49dd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -342,11 +342,14 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean] - val replacementMap: Map[_, 
_] = replacement.head._2 match { - case v: String => replacement - case v: Boolean => replacement - case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } - } + val replacementMap: Map[_, _] = + if (replacement.head._2 == null) + replacement + else replacement.head._2 match { + case v: String => replacement + case v: Boolean => replacement + case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } + } // targetColumnType is either DoubleType or StringType or BooleanType val targetColumnType = replacement.head._1 match { From 2eac8b9070f12fd7dade103857682669d3017587 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Fri, 9 Dec 2016 09:35:35 -0800 Subject: [PATCH 2/7] Scala test for df.replace with null --- .../spark/sql/DataFrameNaFunctionsSuite.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index fd829846ac33..79a5aaca140f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -208,16 +208,16 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out(4) === Row("Amy", null, null)) assert(out(5) === Row(null, null, null)) - // Replace only the age column - val out1 = input.na.replace("age", Map( - 16 -> 61, - 60 -> 6, - 164.3 -> 461.3 // Alice is really tall + // Replace only the name column + val out1 = input.na.replace("name", Map( + "Bob" -> "Bravo", + "Alice" -> "Jessie", + "David" -> null )).collect() - assert(out1(0) === Row("Bob", 61, 176.5)) - assert(out1(1) === Row("Alice", null, 164.3)) - assert(out1(2) === Row("David", 6, null)) + assert(out1(0) === Row("Bravo", 16, 176.5)) + assert(out1(1) === Row("Jessie", null, 164.3)) + assert(out1(2) === Row(null, 60, null)) assert(out1(3).get(2).asInstanceOf[Double].isNaN) assert(out1(4) === Row("Amy", null, null)) assert(out1(5) === Row(null, null, null)) From 79492924bfad40f36aa132e58376da98b0727eac Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Fri, 9 Dec 2016 09:44:39 -0800 Subject: [PATCH 3/7] Use pattern matching for null case --- .../scala/org/apache/spark/sql/DataFrameNaFunctions.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 777e0ca49dd8..687e588f8450 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -342,10 +342,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean] - val replacementMap: Map[_, _] = - if (replacement.head._2 == null) - replacement - else replacement.head._2 match { + val replacementMap: Map[_, _] = replacement.head._2 match { + case null => replacement case v: String => replacement case v: Boolean => replacement case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } From 0b15c8f1d64754432c6585bb88e20d17738d4bbd Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Fri, 9 Dec 2016 09:45:51 -0800 Subject: [PATCH 4/7] Fix indentation --- .../org/apache/spark/sql/DataFrameNaFunctions.scala | 
10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 687e588f8450..46e14bdbf949 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -343,11 +343,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean] val replacementMap: Map[_, _] = replacement.head._2 match { - case null => replacement - case v: String => replacement - case v: Boolean => replacement - case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } - } + case null => replacement + case v: String => replacement + case v: Boolean => replacement + case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } + } // targetColumnType is either DoubleType or StringType or BooleanType val targetColumnType = replacement.head._1 match { From 2c532c3781087ec8f0c36d5176837c9de568d7ec Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Sat, 10 Dec 2016 09:42:44 -0800 Subject: [PATCH 5/7] Fix Python style check --- python/pyspark/sql/dataframe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b6997127d4fe..7fcf13700c16 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1338,7 +1338,8 @@ def replace(self, to_replace, value, subset=None): if len(to_replace) != len(value): raise ValueError("to_replace and value lists should be of the same length") rep_dict = dict(zip(to_replace, value)) - elif isinstance(to_replace, list) and (value is None or isinstance(value, (float, int, long, basestring))): + elif (isinstance(to_replace, list) and + (value is None or isinstance(value, (float, int, long, basestring)))): rep_dict = dict([(tr, value) for tr in to_replace]) elif isinstance(to_replace, dict): rep_dict = to_replace From 43fb6bd56802f2c20cdd28f7ea384e472183cdc4 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Tue, 23 May 2017 17:46:19 -0700 Subject: [PATCH 6/7] Improve scala doc and pyspark test --- python/pyspark/sql/tests.py | 5 +++++ .../org/apache/spark/sql/DataFrameNaFunctions.scala | 9 ++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index acea9113ee85..509463837a7a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1851,6 +1851,11 @@ def test_replace(self): .replace(False, True).first()) self.assertTupleEqual(row, (True, True)) + # replace with None + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace((10, 80), None).first() + self.assertTupleEqual(row, (u'Alice', None, None)) + # should fail if subset is not list, tuple or None with self.assertRaises(ValueError): self.spark.createDataFrame( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 6e7e1cd17bca..8e2b01417a6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -319,8 +319,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) 
Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must have the same type, and - * can only be doubles , strings or booleans. + * Key and value of `replacement` map must satisfy one of: + * 1. keys are String, values are mix of String and null + * 2. keys are Boolean, values are mix of Boolean and null + * 3. keys are Double, values are either all Double or all null * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". @@ -342,7 +344,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { return df } - // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean] + // replacementMap is either Map[String, String], Map[Double, Double], Map[Boolean,Boolean] + // or value being null val replacementMap: Map[_, _] = replacement.head._2 match { case null => replacement case v: String => replacement From b5424d9fea56d2e0fb57ebc27d3d35054da6d22b Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Tue, 23 May 2017 19:56:07 -0700 Subject: [PATCH 7/7] Fix python3 dict.values() syntax --- python/pyspark/sql/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index f2675fd7339f..b57bb97e4e84 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1436,7 +1436,7 @@ def all_of_(xs): # Verify we were not passed in mixed type generics." if not any(all_of_type(rep_dict.keys()) and (all_of_type(rep_dict.values()) - or rep_dict.values().count(None) == len(rep_dict)) + or list(rep_dict.values()).count(None) == len(rep_dict)) for all_of_type in [all_of_bool, all_of_str, all_of_numeric]): raise ValueError("Mixed type replacements are not supported")
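
Taken together, the seven patches let both the PySpark and Scala replace() APIs map existing values to null. The sketch below is a minimal, hypothetical usage example of the Python API as patched; it assumes a local SparkSession, and the column names and rows are illustrative rather than taken from the patch's tests. The two call forms mirror the doctest added in patch 1 and the list-plus-None branch it introduces.

# Hypothetical usage sketch of DataFrame.replace() with a None replacement value,
# as enabled by this patch series; data and session setup are illustrative.
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("replace-with-none-sketch").getOrCreate()

df = spark.createDataFrame(
    [("Alice", 10, 80.0), ("Bob", 5, 64.5)],
    ["name", "age", "height"])

# Replace a single value with null; with no subset, only columns of the matching
# (string) type are affected, as in the new doctest for df4.
df.na.replace("Alice", None).show()

# Replace several values with null in one column; patch 1's list branch expands
# this to dict([(tr, None) for tr in to_replace]) before calling into the JVM.
df.na.replace(["Alice", "Bob"], None, "name").show()

spark.stop()

On the Scala side the same behavior flows through a replacement map whose values may be null, as in patch 2's test, which maps "David" -> null alongside ordinary string replacements; the null case added to the match on replacement.head._2 in patches 3 and 4 keeps a map whose first value is null from being run through convertToDouble.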