Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ public static float getFloat(Object object, long offset) {
}

/**
 * Writes a {@code float} through {@code _UNSAFE.putFloat(object, offset, ...)} after
 * canonicalizing it.
 *
 * <p>Any NaN input is replaced by {@link Float#NaN}, and a zero of either sign is replaced by
 * positive zero; every other value is passed through unchanged.
 *
 * @param object forwarded unchanged to {@code _UNSAFE.putFloat}
 * @param offset forwarded unchanged to {@code _UNSAFE.putFloat}
 * @param value the value to canonicalize and write
 */
public static void putFloat(Object object, long offset, float value) {
  final float canonical;
  if (Float.isNaN(value)) {
    canonical = Float.NaN;
  } else if (value == 0.0f) {
    // 0.0f and -0.0f compare equal under ==, so this branch catches negative zero as well.
    canonical = 0.0f;
  } else {
    canonical = value;
  }
  _UNSAFE.putFloat(object, offset, canonical);
}

Expand All @@ -128,6 +133,11 @@ public static double getDouble(Object object, long offset) {
}

/**
 * Writes a {@code double} through {@code _UNSAFE.putDouble(object, offset, ...)} after
 * canonicalizing it.
 *
 * <p>Any NaN input is replaced by {@link Double#NaN}, and a zero of either sign is replaced by
 * positive zero; every other value is passed through unchanged.
 *
 * @param object forwarded unchanged to {@code _UNSAFE.putDouble}
 * @param offset forwarded unchanged to {@code _UNSAFE.putDouble}
 * @param value the value to canonicalize and write
 */
public static void putDouble(Object object, long offset, double value) {
  final double canonical;
  if (Double.isNaN(value)) {
    canonical = Double.NaN;
  } else if (value == 0.0d) {
    // 0.0d and -0.0d compare equal under ==, so this branch catches negative zero as well.
    canonical = 0.0d;
  } else {
    canonical = value;
  }
  _UNSAFE.putDouble(object, offset, canonical);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,18 @@ public void heapMemoryReuse() {
Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7);
Assert.assertEquals(obj3, onheap4.getBaseObject());
}

@Test
// SPARK-26021
public void writeMinusZeroIsReplacedWithZero() {
  // Round-trip -0.0d through Platform using a byte[] as the backing object.
  byte[] doubleBuffer = new byte[Double.BYTES];
  Platform.putDouble(doubleBuffer, Platform.BYTE_ARRAY_OFFSET, -0.0d);
  double storedDouble = Platform.getDouble(doubleBuffer, Platform.BYTE_ARRAY_OFFSET);

  // Same round-trip for -0.0f.
  byte[] floatBuffer = new byte[Float.BYTES];
  Platform.putFloat(floatBuffer, Platform.BYTE_ARRAY_OFFSET, -0.0f);
  float storedFloat = Platform.getFloat(floatBuffer, Platform.BYTE_ARRAY_OFFSET);

  // Compare raw bit patterns: 0.0 == -0.0 is true, so a plain value comparison
  // could not tell the two zeros apart.
  long expectedDoubleBits = Double.doubleToLongBits(0.0d);
  int expectedFloatBits = Float.floatToIntBits(0.0f);
  Assert.assertEquals(expectedDoubleBits, Double.doubleToLongBits(storedDouble));
  Assert.assertEquals(expectedFloatBits, Float.floatToIntBits(storedFloat));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,6 @@ public void setLong(int ordinal, long value) {
/**
 * Writes {@code value} into the field at {@code ordinal}.
 *
 * <p>Calls {@code assertIndexIsValid(ordinal)} and {@code setNotNullAt(ordinal)} first, then
 * hands the value to {@code Platform.putDouble} at {@code getFieldOffset(ordinal)} within
 * {@code baseObject}. Any NaN input is replaced by {@link Double#NaN} before the write.
 *
 * @param ordinal index of the field to write
 * @param value the value to store
 */
public void setDouble(int ordinal, double value) {
  assertIndexIsValid(ordinal);
  setNotNullAt(ordinal);
  // Collapse every NaN input to the single Double.NaN constant before writing.
  final double normalized = Double.isNaN(value) ? Double.NaN : value;
  Platform.putDouble(baseObject, getFieldOffset(ordinal), normalized);
}

Expand Down Expand Up @@ -255,9 +252,6 @@ public void setByte(int ordinal, byte value) {
/**
 * Writes {@code value} into the field at {@code ordinal}.
 *
 * <p>Calls {@code assertIndexIsValid(ordinal)} and {@code setNotNullAt(ordinal)} first, then
 * hands the value to {@code Platform.putFloat} at {@code getFieldOffset(ordinal)} within
 * {@code baseObject}. Any NaN input is replaced by {@link Float#NaN} before the write.
 *
 * @param ordinal index of the field to write
 * @param value the value to store
 */
public void setFloat(int ordinal, float value) {
  assertIndexIsValid(ordinal);
  setNotNullAt(ordinal);
  // Collapse every NaN input to the single Float.NaN constant before writing.
  final float normalized = Float.isNaN(value) ? Float.NaN : value;
  Platform.putFloat(baseObject, getFieldOffset(ordinal), normalized);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,10 @@ protected final void writeLong(long offset, long value) {
}

/**
 * Writes {@code value} at {@code offset} in the object returned by {@code getBuffer()},
 * via {@code Platform.putFloat}. Any NaN input is replaced by {@link Float#NaN} first.
 *
 * @param offset forwarded unchanged to {@code Platform.putFloat}
 * @param value the value to store
 */
protected final void writeFloat(long offset, float value) {
  // Collapse every NaN input to the single Float.NaN constant before writing.
  final float normalized = Float.isNaN(value) ? Float.NaN : value;
  Platform.putFloat(getBuffer(), offset, normalized);
}

/**
 * Writes {@code value} at {@code offset} in the object returned by {@code getBuffer()},
 * via {@code Platform.putDouble}. Any NaN input is replaced by {@link Double#NaN} first.
 *
 * @param offset forwarded unchanged to {@code Platform.putDouble}
 * @param value the value to store
 */
protected final void writeDouble(long offset, double value) {
  // Collapse every NaN input to the single Double.NaN constant before writing.
  final double normalized = Double.isNaN(value) ? Double.NaN : value;
  Platform.putDouble(getBuffer(), offset, normalized);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -723,4 +723,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
"grouping expressions: [current_date(None)], value: [key: int, value: string], " +
"type: GroupBy]"))
}

test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") {
val colName = "i"
val doubles = Seq(0.0d, 0.0d, -0.0d).toDF(colName).groupBy(colName).count().collect()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this change, doubles already has the expected result:

Seq(0.0d, 0.0d, -0.0d).toDF(colName).groupBy(colName).count().show()

+---+-----+
|  i|count|
+---+-----+
|0.0|    3|
+---+-----+

Do you know why?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually yes: if codegen is enabled, a generated FastHashMap is used for the partial grouping before the shuffle, and this map doesn't distinguish between 0.0 and -0.0. In addition, the unit test runs with 2 threads for 2 partitions. With the element order in the doubles Seq, each partition's partial grouping ends up keyed on 0.0, so after the shuffle the groups are merged.
In the floats case the order of the elements in the Seq is different, so after the partial grouping one partition is keyed on 0.0 and the other on -0.0; after the shuffle they are treated as different groups (before the fix).
Now I remember this is the reason I originally disabled codegen in the test.
I think I'll just reorder the doubles Seq as well so it manifests the bug.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it is better to reorder the doubles.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Example:

Seq(0.0d, 0.0d, -0.0d).toDF(colName).groupBy(colName).count().show()
+---+-----+
|  i|count|
+---+-----+
|0.0|    3|
+---+-----+
Seq(0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().show()
+----+-----+
|   i|count|
+----+-----+
| 0.0|    1|
|-0.0|    2|
+----+-----+
Seq(-0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().show()
+----+-----+
|   i|count|
+----+-----+
|-0.0|    3|
+----+-----+

val floats = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect()

assert(doubles.length == 1)
assert(floats.length == 1)
// using compare since 0.0 == -0.0 is true
assert(java.lang.Double.compare(doubles(0).getDouble(0), 0.0d) == 0)
assert(java.lang.Float.compare(floats(0).getFloat(0), 0.0f) == 0)
assert(doubles(0).getLong(1) == 3)
assert(floats(0).getLong(1) == 3)
}
}