diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 944739bcd207..edc7ca6f5146 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1362,8 +1362,8 @@ def replace(self, to_replace, value=None, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. - Values to_replace and value should contain either all numerics, all booleans, - or all strings. When replacing, the new value will be cast + Values to_replace and value must have the same type and can only be numerics, booleans, + or strings. Value can have None. When replacing, the new value will be cast to the type of the existing column. For numeric replacements all values to be replaced should have unique floating point representation. In case of conflicts (for example with `{42: -1, 42.0: 1}`) @@ -1373,8 +1373,8 @@ def replace(self, to_replace, value=None, subset=None): Value to be replaced. If the value is a dict, then `value` is ignored and `to_replace` must be a mapping between a value and a replacement. - :param value: int, long, float, string, or list. - The replacement value must be an int, long, float, or string. If `value` is a + :param value: bool, int, long, float, string, list or None. + The replacement value must be a bool, int, long, float, string or None. If `value` is a list, `value` should be of the same length and type as `to_replace`. If `value` is a scalar and `to_replace` is a sequence, then `value` is used as a replacement for each item in `to_replace`. 
@@ -1393,6 +1393,16 @@ def replace(self, to_replace, value=None, subset=None): |null| null| null| +----+------+-----+ + >>> df4.na.replace('Alice', None).show() + +----+------+----+ + | age|height|name| + +----+------+----+ + | 10| 80|null| + | 5| null| Bob| + |null| null| Tom| + |null| null|null| + +----+------+----+ + >>> df4.na.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() +----+------+----+ | age|height|name| @@ -1425,12 +1435,13 @@ def all_of_(xs): valid_types = (bool, float, int, long, basestring, list, tuple) if not isinstance(to_replace, valid_types + (dict, )): raise ValueError( - "to_replace should be a float, int, long, string, list, tuple, or dict. " + "to_replace should be a bool, float, int, long, string, list, tuple, or dict. " "Got {0}".format(type(to_replace))) - if not isinstance(value, valid_types) and not isinstance(to_replace, dict): + if not isinstance(value, valid_types) and value is not None \ + and not isinstance(to_replace, dict): raise ValueError("If to_replace is not a dict, value should be " - "a float, int, long, string, list, or tuple. " + "a bool, float, int, long, string, list, tuple or None. " "Got {0}".format(type(value))) if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)): @@ -1446,21 +1457,21 @@ def all_of_(xs): if isinstance(to_replace, (float, int, long, basestring)): to_replace = [to_replace] - if isinstance(value, (float, int, long, basestring)): - value = [value for _ in range(len(to_replace))] - if isinstance(to_replace, dict): rep_dict = to_replace if value is not None: warnings.warn("to_replace is a dict and value is not None. value will be ignored.") else: + if isinstance(value, (float, int, long, basestring)) or value is None: + value = [value for _ in range(len(to_replace))] rep_dict = dict(zip(to_replace, value)) if isinstance(subset, basestring): subset = [subset] - # Verify we were not passed in mixed type generics." 
- if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values()) + # Verify we were not passed in mixed type generics. + if not any(all_of_type(rep_dict.keys()) + and all_of_type(x for x in rep_dict.values() if x is not None) for all_of_type in [all_of_bool, all_of_str, all_of_numeric]): raise ValueError("Mixed type replacements are not supported") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cfd9c558ff67..cf2c473a1645 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1964,6 +1964,21 @@ def test_replace(self): .replace(False, True).first()) self.assertTupleEqual(row, (True, True)) + # replace list while value is not given (default to None) + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first() + self.assertTupleEqual(row, (None, 10, 80.0)) + + # replace string with None and then drop None rows + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna() + self.assertEqual(row.count(), 0) + + # replace with number and None + row = self.spark.createDataFrame( + [(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first() + self.assertTupleEqual(row, (u'Alice', 20, None)) + # should fail if subset is not list, tuple or None with self.assertRaises(ValueError): self.spark.createDataFrame( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 871fff71e553..e068df3586f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -260,9 +260,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Replaces values matching keys in `replacement` map with the corresponding values. 
- * Key and value of `replacement` map must have the same type, and - * can only be doubles, strings or booleans. - * If `col` is "*", then the replacement is applied on all string columns or numeric columns. * * {{{ * import com.google.common.collect.ImmutableMap; @@ -277,8 +274,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * - * @param col name of the column to apply the value replacement - * @param replacement value replacement map, as explained above + * @param col name of the column to apply the value replacement. If `col` is "*", + * replacement is applied on all string, numeric or boolean columns. + * @param replacement value replacement map. Key and value of `replacement` map must have + * the same type, and can only be doubles, strings or booleans. + * Values in the map may be null. * * @since 1.3.1 */ @@ -288,8 +288,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Replaces values matching keys in `replacement` map with the corresponding values. - * Key and value of `replacement` map must have the same type, and - * can only be doubles, strings or booleans. * * {{{ * import com.google.common.collect.ImmutableMap; @@ -301,8 +299,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * - * @param cols list of columns to apply the value replacement - * @param replacement value replacement map, as explained above + * @param cols list of columns to apply the value replacement. If `col` is "*", + * replacement is applied on all string, numeric or boolean columns. + * @param replacement value replacement map. Key and value of `replacement` map must have + * the same type, and can only be doubles, strings or booleans. + * Values in the map may be null. 
* * @since 1.3.1 */ @@ -312,10 +313,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must have the same type, and - * can only be doubles, strings or booleans. - * If `col` is "*", - * then the replacement is applied on all string columns , numeric columns or boolean columns. * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height". @@ -328,8 +325,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * df.na.replace("*", Map("UNKNOWN" -> "unnamed")); * }}} * - * @param col name of the column to apply the value replacement - * @param replacement value replacement map, as explained above + * @param col name of the column to apply the value replacement. If `col` is "*", + * replacement is applied on all string, numeric or boolean columns. + * @param replacement value replacement map. Key and value of `replacement` map must have + * the same type, and can only be doubles, strings or booleans. + * The map value can have nulls. * * @since 1.3.1 */ @@ -343,8 +343,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Replaces values matching keys in `replacement` map. - * Key and value of `replacement` map must have the same type, and - * can only be doubles , strings or booleans. * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". @@ -354,8 +352,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed")); * }}} * - * @param cols list of columns to apply the value replacement - * @param replacement value replacement map, as explained above + * @param cols list of columns to apply the value replacement. If `col` is "*", + * replacement is applied on all string, numeric or boolean columns. + * @param replacement value replacement map. 
Key and value of `replacement` map must have + * the same type, and can only be doubles, strings or booleans. + * The map value can have nulls. * * @since 1.3.1 */ @@ -366,14 +367,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { return df } - // replacementMap is either Map[String, String] or Map[Double, Double] or Map[Boolean,Boolean] - val replacementMap: Map[_, _] = replacement.head._2 match { - case v: String => replacement - case v: Boolean => replacement - case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) } + // Convert the NumericType in replacement map to DoubleType, + // while leaving StringType, BooleanType and null untouched. + val replacementMap: Map[_, _] = replacement.map { + case (k, v: String) => (k, v) + case (k, v: Boolean) => (k, v) + case (k: String, null) => (k, null) + case (k: Boolean, null) => (k, null) + case (k, null) => (convertToDouble(k), null) + case (k, v) => (convertToDouble(k), convertToDouble(v)) } - // targetColumnType is either DoubleType or StringType or BooleanType + // targetColumnType is either DoubleType, StringType or BooleanType, + // depending on the type of first key in replacement map. + // Only fields of targetColumnType will perform replacement. 
val targetColumnType = replacement.head._1 match { case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType case _: jl.Boolean => BooleanType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 47c9ba5847a4..e6983b6be555 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -262,4 +262,47 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { assert(out1(4) === Row("Amy", null, null)) assert(out1(5) === Row(null, null, null)) } + + test("replace with null") { + val input = Seq[(String, java.lang.Double, java.lang.Boolean)]( + ("Bob", 176.5, true), + ("Alice", 164.3, false), + ("David", null, true) + ).toDF("name", "height", "married") + + // Replace String with String and null + checkAnswer( + input.na.replace("name", Map( + "Bob" -> "Bravo", + "Alice" -> null + )), + Row("Bravo", 176.5, true) :: + Row(null, 164.3, false) :: + Row("David", null, true) :: Nil) + + // Replace Double with null + checkAnswer( + input.na.replace("height", Map[Any, Any]( + 164.3 -> null + )), + Row("Bob", 176.5, true) :: + Row("Alice", null, false) :: + Row("David", null, true) :: Nil) + + // Replace Boolean with null + checkAnswer( + input.na.replace("*", Map[Any, Any]( + false -> null + )), + Row("Bob", 176.5, true) :: + Row("Alice", 164.3, null) :: + Row("David", null, true) :: Nil) + + // Replace String with null and then drop rows containing null + checkAnswer( + input.na.replace("name", Map( + "Bob" -> null + )).na.drop("name" :: Nil).select("name"), + Row("Alice") :: Row("David") :: Nil) + } }