Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ public static float getFloat(Object object, long offset) {
}

public static void putFloat(Object object, long offset, float value) {
  // Canonicalize before writing so that logically-equal floats share a single
  // binary representation (SPARK-26021): every NaN collapses to the canonical
  // Float.NaN bit pattern, and -0.0f is stored as +0.0f. Note that the
  // (value == -0.0f) comparison also matches +0.0f, which is a harmless no-op.
  float canonical = Float.isNaN(value) ? Float.NaN : (value == -0.0f ? 0.0f : value);
  _UNSAFE.putFloat(object, offset, canonical);
}

Expand All @@ -128,6 +133,11 @@ public static double getDouble(Object object, long offset) {
}

public static void putDouble(Object object, long offset, double value) {
  // Canonicalize before writing so that logically-equal doubles share a single
  // binary representation (SPARK-26021): every NaN collapses to the canonical
  // Double.NaN bit pattern, and -0.0d is stored as +0.0d. Note that the
  // (value == -0.0d) comparison also matches +0.0d, which is a harmless no-op.
  double canonical = Double.isNaN(value) ? Double.NaN : (value == -0.0d ? 0.0d : value);
  _UNSAFE.putDouble(object, offset, canonical);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,18 @@ public void heapMemoryReuse() {
Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7);
Assert.assertEquals(obj3, onheap4.getBaseObject());
}

@Test
// SPARK-26021
public void writeMinusZeroIsReplacedWithZero() {
  // Write -0.0 through Platform and read it back: the stored bit pattern must
  // be that of +0.0, because Platform.putDouble/putFloat canonicalize -0.0.
  byte[] doubleBuffer = new byte[Double.BYTES];
  byte[] floatBuffer = new byte[Float.BYTES];

  Platform.putDouble(doubleBuffer, Platform.BYTE_ARRAY_OFFSET, -0.0d);
  Platform.putFloat(floatBuffer, Platform.BYTE_ARRAY_OFFSET, -0.0f);

  long observedDoubleBits =
      Double.doubleToLongBits(Platform.getDouble(doubleBuffer, Platform.BYTE_ARRAY_OFFSET));
  int observedFloatBits =
      Float.floatToIntBits(Platform.getFloat(floatBuffer, Platform.BYTE_ARRAY_OFFSET));

  // Compare raw bits, not values: 0.0 == -0.0 evaluates to true under ==, so a
  // plain value comparison could not detect a lingering negative zero.
  Assert.assertEquals(Double.doubleToLongBits(0.0d), observedDoubleBits);
  Assert.assertEquals(Float.floatToIntBits(0.0f), observedFloatBits);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,6 @@ public void setLong(int ordinal, long value) {
public void setDouble(int ordinal, double value) {
  // Write a double field at the given ordinal, clearing its null bit first.
  // NaN/-0.0 canonicalization is intentionally NOT repeated here: as of
  // SPARK-26021 Platform.putDouble canonicalizes every write, so the former
  // per-call-site normalization block was redundant and has been removed.
  assertIndexIsValid(ordinal);
  setNotNullAt(ordinal);
  Platform.putDouble(baseObject, getFieldOffset(ordinal), value);
}

Expand Down Expand Up @@ -255,9 +252,6 @@ public void setByte(int ordinal, byte value) {
public void setFloat(int ordinal, float value) {
  // Write a float field at the given ordinal, clearing its null bit first.
  // NaN/-0.0 canonicalization is intentionally NOT repeated here: as of
  // SPARK-26021 Platform.putFloat canonicalizes every write, so the former
  // per-call-site normalization block was redundant and has been removed.
  assertIndexIsValid(ordinal);
  setNotNullAt(ordinal);
  Platform.putFloat(baseObject, getFieldOffset(ordinal), value);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,10 @@ protected final void writeLong(long offset, long value) {
}

protected final void writeFloat(long offset, float value) {
  // Write a float into the buffer at the given offset. NaN canonicalization
  // now happens centrally inside Platform.putFloat (SPARK-26021), so the
  // former local normalization was redundant and has been removed.
  Platform.putFloat(getBuffer(), offset, value);
}

protected final void writeDouble(long offset, double value) {
  // Write a double into the buffer at the given offset. NaN canonicalization
  // now happens centrally inside Platform.putDouble (SPARK-26021), so the
  // former local normalization was redundant and has been removed.
  Platform.putDouble(getBuffer(), offset, value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -723,4 +723,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
"grouping expressions: [current_date(None)], value: [key: int, value: string], " +
"type: GroupBy]"))
}

test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") {
  val colName = "i"
  // 0.0 and -0.0 must fall into the same group, so grouping three values
  // (two positive zeros, one negative zero) must yield one row with count 3.
  val groupedDoubles = Seq(0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().collect()
  val groupedFloats = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect()

  assert(groupedDoubles.length == 1)
  assert(groupedFloats.length == 1)
  // compare() distinguishes -0.0 from 0.0 (unlike ==, for which they are
  // equal), so this proves the surviving group key is positive zero.
  assert(java.lang.Double.compare(groupedDoubles(0).getDouble(0), 0.0d) == 0)
  assert(java.lang.Float.compare(groupedFloats(0).getFloat(0), 0.0f) == 0)
  assert(groupedDoubles(0).getLong(1) == 3)
  assert(groupedFloats(0).getLong(1) == 3)
}
}
5 changes: 4 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ object QueryTest {
def prepareRow(row: Row): Row = {
Row.fromSeq(row.toSeq.map {
case null => null
case d: java.math.BigDecimal => BigDecimal(d)
case bd: java.math.BigDecimal => BigDecimal(bd)
// Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+
case seq: Seq[_] => seq.map {
case b: java.lang.Byte => b.byteValue
Expand All @@ -303,6 +303,9 @@ object QueryTest {
// Convert array to Seq for easy equality check.
case b: Array[_] => b.toSeq
case r: Row => prepareRow(r)
// spark treats -0.0 as 0.0
case d: Double if d == -0.0d => 0.0d
case f: Float if f == -0.0f => 0.0f
case o => o
})
}
Expand Down