Skip to content

Commit 351be99

Browse files
committed
More tests, better comments
1 parent 2946659 commit 351be99

File tree

5 files changed

+58
-36
lines changed

5 files changed

+58
-36
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,20 +1457,19 @@ def all_of_(xs):
14571457
if isinstance(to_replace, (float, int, long, basestring)):
14581458
to_replace = [to_replace]
14591459

1460-
if isinstance(value, (float, int, long, basestring)) or value is None:
1461-
value = [value for _ in range(len(to_replace))]
1462-
14631460
if isinstance(to_replace, dict):
14641461
rep_dict = to_replace
14651462
if value is not None:
14661463
warnings.warn("to_replace is a dict and value is not None. value will be ignored.")
14671464
else:
1465+
if isinstance(value, (float, int, long, basestring)) or value is None:
1466+
value = [value for _ in range(len(to_replace))]
14681467
rep_dict = dict(zip(to_replace, value))
14691468

14701469
if isinstance(subset, basestring):
14711470
subset = [subset]
14721471

1473-
# Verify we were not passed in mixed type generics."
1472+
# Verify we were not passed in mixed type generics.
14741473
if not any(all_of_type(rep_dict.keys())
14751474
and all_of_type(x for x in rep_dict.values() if x is not None)
14761475
for all_of_type in [all_of_bool, all_of_str, all_of_numeric]):

python/pyspark/sql/tests.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1964,12 +1964,17 @@ def test_replace(self):
19641964
.replace(False, True).first())
19651965
self.assertTupleEqual(row, (True, True))
19661966

1967-
# replace with None
1967+
# replace a list of values when value is not given (defaults to None)
19681968
row = self.spark.createDataFrame(
1969-
[(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).first()
1969+
[(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first()
19701970
self.assertTupleEqual(row, (None, 10, 80.0))
19711971

1972-
# replace with numerics and None
1972+
# replace string with None and then drop None rows
1973+
row = self.spark.createDataFrame(
1974+
[(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna()
1975+
self.assertEqual(row.count(), 0)
1976+
1977+
# replace with number and None
19731978
row = self.spark.createDataFrame(
19741979
[(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first()
19751980
self.assertTupleEqual(row, (u'Alice', 20, None))

sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ class DataTypeSuite extends SparkFunSuite {
145145
val message = intercept[SparkException] {
146146
left.merge(right)
147147
}.getMessage
148-
assert(message === "Failed to merge fields 'b' and 'b'. " +
149-
"Failed to merge incompatible data types FloatType and LongType")
148+
assert(message.equals("Failed to merge fields 'b' and 'b'. " +
149+
"Failed to merge incompatible data types FloatType and LongType"))
150150
}
151151

152152
test("existsRecursively") {

sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
260260

261261
/**
262262
* Replaces values matching keys in `replacement` map with the corresponding values.
263-
* Key and value of `replacement` map must have the same type, and
264-
* can only be doubles, strings or booleans.
265-
* `replacement` map value can have null.
266-
* If `col` is "*", then the replacement is applied on all string columns or numeric columns.
267263
*
268264
* {{{
269265
* import com.google.common.collect.ImmutableMap;
@@ -278,8 +274,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
278274
* df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
279275
* }}}
280276
*
281-
* @param col name of the column to apply the value replacement
282-
* @param replacement value replacement map, as explained above
277+
* @param col name of the column to apply the value replacement. If `col` is "*",
278+
* replacement is applied on all string, numeric or boolean columns.
279+
* @param replacement value replacement map. Key and value of `replacement` map must have
280+
* the same type, and can only be doubles, strings or booleans.
281+
* The map value can have nulls.
283282
*
284283
* @since 1.3.1
285284
*/
@@ -289,9 +288,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
289288

290289
/**
291290
* Replaces values matching keys in `replacement` map with the corresponding values.
292-
* Key and value of `replacement` map must have the same type, and
293-
* can only be doubles, strings or booleans.
294-
* `replacement` map value can have null.
295291
*
296292
* {{{
297293
* import com.google.common.collect.ImmutableMap;
@@ -303,8 +299,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
303299
* df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed"));
304300
* }}}
305301
*
306-
* @param cols list of columns to apply the value replacement
307-
* @param replacement value replacement map, as explained above
302+
* @param cols list of columns to apply the value replacement. If `cols` is "*",
303+
* replacement is applied on all string, numeric or boolean columns.
304+
* @param replacement value replacement map. Key and value of `replacement` map must have
305+
* the same type, and can only be doubles, strings or booleans.
306+
* The map value can have nulls.
308307
*
309308
* @since 1.3.1
310309
*/
@@ -314,11 +313,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
314313

315314
/**
316315
* (Scala-specific) Replaces values matching keys in `replacement` map.
317-
* Key and value of `replacement` map must have the same type, and
318-
* can only be doubles, strings or booleans.
319-
* `replacement` map value can have null.
320-
* If `col` is "*",
321-
* then the replacement is applied on all string columns , numeric columns or boolean columns.
322316
*
323317
* {{{
324318
* // Replaces all occurrences of 1.0 with 2.0 in column "height".
@@ -331,8 +325,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
331325
* df.na.replace("*", Map("UNKNOWN" -> "unnamed"));
332326
* }}}
333327
*
334-
* @param col name of the column to apply the value replacement
335-
* @param replacement value replacement map, as explained above
328+
* @param col name of the column to apply the value replacement. If `col` is "*",
329+
* replacement is applied on all string, numeric or boolean columns.
330+
* @param replacement value replacement map. Key and value of `replacement` map must have
331+
* the same type, and can only be doubles, strings or booleans.
332+
* The map value can have nulls.
336333
*
337334
* @since 1.3.1
338335
*/
@@ -346,9 +343,6 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
346343

347344
/**
348345
* (Scala-specific) Replaces values matching keys in `replacement` map.
349-
* Key and value of `replacement` map must have the same type, and
350-
* can only be doubles, strings or booleans.
351-
* `replacement` map value can have null.
352346
*
353347
* {{{
354348
* // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
@@ -358,8 +352,11 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
358352
* df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"));
359353
* }}}
360354
*
361-
* @param cols list of columns to apply the value replacement
362-
* @param replacement value replacement map, as explained above
355+
* @param cols list of columns to apply the value replacement. If `cols` is "*",
356+
* replacement is applied on all string, numeric or boolean columns.
357+
* @param replacement value replacement map. Key and value of `replacement` map must have
358+
* the same type, and can only be doubles, strings or booleans.
359+
* The map value can have nulls.
363360
*
364361
* @since 1.3.1
365362
*/
@@ -370,8 +367,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
370367
return df
371368
}
372369

373-
// replacementMap is either Map[String, String], Map[Double, Double], Map[Boolean,Boolean]
374-
// while value can have null
370+
// Convert the NumericType in replacement map to DoubleType,
371+
// while leaving StringType, BooleanType and null untouched.
375372
val replacementMap: Map[_, _] = replacement.map {
376373
case (k, v: String) => (k, v)
377374
case (k, v: Boolean) => (k, v)
@@ -381,7 +378,9 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
381378
case (k, v) => (convertToDouble(k), convertToDouble(v))
382379
}
383380

384-
// targetColumnType is either DoubleType or StringType or BooleanType
381+
// targetColumnType is either DoubleType, StringType or BooleanType,
382+
// depending on the type of the first key in the replacement map.
383+
// Replacement is applied only to fields of targetColumnType.
385384
val targetColumnType = replacement.head._1 match {
386385
case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType
387386
case _: jl.Boolean => BooleanType

sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
262262
assert(out1(4) === Row("Amy", null, null))
263263
assert(out1(5) === Row(null, null, null))
264264

265-
// Replace with null
265+
// Replace String with String and null
266266
val out2 = input.na.replace("name", Map(
267267
"Bob" -> "Bravo",
268268
"Alice" -> null
@@ -274,5 +274,24 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
274274
assert(out2(3).get(2).asInstanceOf[Double].isNaN)
275275
assert(out2(4) === Row("Amy", null, null))
276276
assert(out2(5) === Row(null, null, null))
277+
278+
// Replace Double with null
279+
val out3 = input.na.replace("age", Map[Any, Any](
280+
16 -> null
281+
)).collect()
282+
283+
assert(out3(0) === Row("Bob", null, 176.5))
284+
assert(out3(1) === Row("Alice", null, 164.3))
285+
assert(out3(2) === Row("David", 60, null))
286+
assert(out3(3).get(2).asInstanceOf[Double].isNaN)
287+
assert(out3(4) === Row("Amy", null, null))
288+
assert(out3(5) === Row(null, null, null))
289+
290+
// Replace String with null and then drop rows containing null
291+
checkAnswer(
292+
input.na.replace("name", Map(
293+
"Bob" -> null
294+
)).na.drop("name" :: Nil).select("name"),
295+
Row("Alice") :: Row("David") :: Row("Nina") :: Row("Amy") :: Nil)
277296
}
278297
}

0 commit comments

Comments
 (0)