From 02c601538b508effba116deb1ca894bb4bbf544d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 27 Dec 2018 14:44:30 +0800 Subject: [PATCH 1/2] Revert "[SPARK-26021][2.4][SQL][FOLLOWUP] only deal with NaN and -0.0 in UnsafeWriter" This reverts commit 33460c58a9274e22bd662858c71292275ae4aa24. --- .../org/apache/spark/unsafe/Platform.java | 10 ++++++ .../spark/unsafe/PlatformUtilSuite.java | 14 ++++++++ .../expressions/codegen/UnsafeWriter.java | 35 ------------------- .../codegen/UnsafeRowWriterSuite.scala | 20 ----------- .../apache/spark/sql/DataFrameJoinSuite.scala | 12 ------- .../sql/DataFrameWindowFunctionsSuite.scala | 14 -------- 6 files changed, 24 insertions(+), 81 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index aca6fca00c48b..bc94f2171228a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -120,6 +120,11 @@ public static float getFloat(Object object, long offset) { } public static void putFloat(Object object, long offset, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } else if (value == -0.0f) { + value = 0.0f; + } _UNSAFE.putFloat(object, offset, value); } @@ -128,6 +133,11 @@ public static double getDouble(Object object, long offset) { } public static void putDouble(Object object, long offset, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } else if (value == -0.0d) { + value = 0.0d; + } _UNSAFE.putDouble(object, offset, value); } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 3ad9ac7b4de9c..ab34324eb54cc 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -157,4 +157,18 @@ public void heapMemoryReuse() { Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); Assert.assertEquals(obj3, onheap4.getBaseObject()); } + + @Test + // SPARK-26021 + public void writeMinusZeroIsReplacedWithZero() { + byte[] doubleBytes = new byte[Double.BYTES]; + byte[] floatBytes = new byte[Float.BYTES]; + Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d); + Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f); + double doubleFromPlatform = Platform.getDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET); + float floatFromPlatform = Platform.getFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET); + + Assert.assertEquals(Double.doubleToLongBits(0.0d), Double.doubleToLongBits(doubleFromPlatform)); + Assert.assertEquals(Float.floatToIntBits(0.0f), Float.floatToIntBits(floatFromPlatform)); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 7553ab8cf7000..95263a0da95a8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -198,46 +198,11 @@ protected final void writeLong(long offset, long value) { Platform.putLong(getBuffer(), offset, value); } - // We need to take care of NaN and -0.0 in several places: - // 1. When compare values, different NaNs should be treated as same, `-0.0` and `0.0` should be - // treated as same. - // 2. In GROUP BY, different NaNs should belong to the same group, -0.0 and 0.0 should belong - // to the same group. - // 3. As join keys, different NaNs should be treated as same, `-0.0` and `0.0` should be - // treated as same. - // 4. As window partition keys, different NaNs should be treated as same, `-0.0` and `0.0` - // should be treated as same. - // - // Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we - // recursively compare the fields/elements, so it's also fine. - // - // Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different - // NaNs have different binary representation, and the same thing happens for -0.0 and 0.0. - // - // Here we normalize NaN and -0.0, so that `UnsafeProjection` will normalize them when writing - // float/double columns and nested fields to `UnsafeRow`. - // - // Note that, we must do this for all the `UnsafeProjection`s, not only the ones that extract - // join/grouping/window partition keys. `UnsafeProjection` copies unsafe data directly for complex - // types, so nested float/double may not be normalized. We need to make sure that all the unsafe - // data(`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`) will have flat/double normalized during - // creation. protected final void writeFloat(long offset, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } else if (value == -0.0f) { - value = 0.0f; - } Platform.putFloat(getBuffer(), offset, value); } - // See comments for `writeFloat`. protected final void writeDouble(long offset, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } else if (value == -0.0d) { - value = 0.0d; - } Platform.putDouble(getBuffer(), offset, value); } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala index 22e1fa6dfed4f..fb651b76fc16d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala @@ -50,24 +50,4 @@ class UnsafeRowWriterSuite extends SparkFunSuite { assert(res1 == res2) } - test("SPARK-26021: normalize float/double NaN and -0.0") { - val unsafeRowWriter1 = new UnsafeRowWriter(4) - unsafeRowWriter1.resetRowWriter() - unsafeRowWriter1.write(0, Float.NaN) - unsafeRowWriter1.write(1, Double.NaN) - unsafeRowWriter1.write(2, 0.0f) - unsafeRowWriter1.write(3, 0.0) - val res1 = unsafeRowWriter1.getRow - - val unsafeRowWriter2 = new UnsafeRowWriter(4) - unsafeRowWriter2.resetRowWriter() - unsafeRowWriter2.write(0, 0.0f/0.0f) - unsafeRowWriter2.write(1, 0.0/0.0) - unsafeRowWriter2.write(2, -0.0f) - unsafeRowWriter2.write(3, -0.0) - val res2 = unsafeRowWriter2.getRow - - // The two rows should be the equal - assert(res1 == res2) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index c9f41ab1c0179..e6b30f9956daf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -295,16 +295,4 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan } } - - test("NaN and -0.0 in join keys") { - val df1 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d") - val df2 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d") - val joined = df1.join(df2, Seq("f", "d")) - checkAnswer(joined, Seq( - Row(Float.NaN, Double.NaN), - Row(0.0f, 0.0), - Row(0.0f, 0.0), - Row(0.0f, 0.0), - Row(0.0f, 0.0))) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index bbeb1d10ba7ec..97a843978f0bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -658,18 +658,4 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { |GROUP BY a |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin)) } - - test("NaN and -0.0 in window partition keys") { - val df = Seq( - (Float.NaN, Double.NaN, 1), - (0.0f/0.0f, 0.0/0.0, 1), - (0.0f, 0.0, 1), - (-0.0f, -0.0, 1)).toDF("f", "d", "i") - val result = df.select($"f", count("i").over(Window.partitionBy("f", "d"))) - checkAnswer(result, Seq( - Row(Float.NaN, 2), - Row(Float.NaN, 2), - Row(0.0f, 2), - Row(0.0f, 2))) - } } From 9be77a86efd521c4f89c79c273445ebcaa87cc29 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 27 Dec 2018 14:50:49 +0800 Subject: [PATCH 2/2] Revert "[SPARK-26021][SQL] replace minus zero with zero in Platform.putDouble/Float" This reverts commit d63ab5a4f5aeecfa227edc84aa38e866446f5238. --- .../java/org/apache/spark/unsafe/Platform.java | 10 ---------- .../org/apache/spark/unsafe/PlatformUtilSuite.java | 14 -------------- .../spark/sql/catalyst/expressions/UnsafeRow.java | 6 ++++++ .../catalyst/expressions/codegen/UnsafeWriter.java | 6 ++++++ .../apache/spark/sql/DataFrameAggregateSuite.scala | 14 -------------- .../scala/org/apache/spark/sql/QueryTest.scala | 5 +---- 6 files changed, 13 insertions(+), 42 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index bc94f2171228a..aca6fca00c48b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -120,11 +120,6 @@ public static float getFloat(Object object, long offset) { } public static void putFloat(Object object, long offset, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } else if (value == -0.0f) { - value = 0.0f; - } _UNSAFE.putFloat(object, offset, value); } @@ -133,11 +128,6 @@ public static double getDouble(Object object, long offset) { } public static void putDouble(Object object, long offset, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } else if (value == -0.0d) { - value = 0.0d; - } _UNSAFE.putDouble(object, offset, value); } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index ab34324eb54cc..3ad9ac7b4de9c 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -157,18 +157,4 @@ public void heapMemoryReuse() { Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); Assert.assertEquals(obj3, onheap4.getBaseObject()); } - - @Test - // SPARK-26021 - public void writeMinusZeroIsReplacedWithZero() { - byte[] doubleBytes = new byte[Double.BYTES]; - byte[] floatBytes = new byte[Float.BYTES]; - Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d); - Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f); - double doubleFromPlatform = Platform.getDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET); - float floatFromPlatform = Platform.getFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET); - - Assert.assertEquals(Double.doubleToLongBits(0.0d), Double.doubleToLongBits(doubleFromPlatform)); - Assert.assertEquals(Float.floatToIntBits(0.0f), Float.floatToIntBits(floatFromPlatform)); - } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 9bf9452855f5f..a76e6ef8c91c1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -224,6 +224,9 @@ public void setLong(int ordinal, long value) { public void setDouble(int ordinal, double value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); + if (Double.isNaN(value)) { + value = Double.NaN; + } Platform.putDouble(baseObject, getFieldOffset(ordinal), value); } @@ -252,6 +255,9 @@ public void setByte(int ordinal, byte value) { public void setFloat(int ordinal, float value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); + if (Float.isNaN(value)) { + value = Float.NaN; + } Platform.putFloat(baseObject, getFieldOffset(ordinal), value); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 95263a0da95a8..2781655002000 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -199,10 +199,16 @@ protected final void writeLong(long offset, long value) { } protected final void writeFloat(long offset, float value) { + if (Float.isNaN(value)) { + value = Float.NaN; + } Platform.putFloat(getBuffer(), offset, value); } protected final void writeDouble(long offset, double value) { + if (Double.isNaN(value)) { + value = Double.NaN; + } Platform.putDouble(getBuffer(), offset, value); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 41dc72de49be5..d0106c44b7db2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -727,18 +727,4 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { "grouping expressions: [current_date(None)], value: [key: int, value: string], " + "type: GroupBy]")) } - - test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") { - val colName = "i" - val doubles = Seq(0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().collect() - val floats = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect() - - assert(doubles.length == 1) - assert(floats.length == 1) - // using compare since 0.0 == -0.0 is true - assert(java.lang.Double.compare(doubles(0).getDouble(0), 0.0d) == 0) - assert(java.lang.Float.compare(floats(0).getFloat(0), 0.0f) == 0) - assert(doubles(0).getLong(1) == 3) - assert(floats(0).getLong(1) == 3) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8ba67239fb907..baca9c1cfb9a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -289,7 +289,7 @@ object QueryTest { def prepareRow(row: Row): Row = { Row.fromSeq(row.toSeq.map { case null => null - case bd: java.math.BigDecimal => BigDecimal(bd) + case d: java.math.BigDecimal => BigDecimal(d) // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+ case seq: Seq[_] => seq.map { case b: java.lang.Byte => b.byteValue @@ -303,9 +303,6 @@ object QueryTest { // Convert array to Seq for easy equality check. case b: Array[_] => b.toSeq case r: Row => prepareRow(r) - // spark treats -0.0 as 0.0 - case d: Double if d == -0.0d => 0.0d - case f: Float if f == -0.0f => 0.0f case o => o }) }