diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/DoubleType.java b/presto-common/src/main/java/com/facebook/presto/common/type/DoubleType.java index 9216ab7afc76d..1e82f914a1297 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/DoubleType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/DoubleType.java @@ -56,6 +56,12 @@ public boolean isOrderable() @Override public Object getObjectValue(SqlFunctionProperties properties, Block block, int position) + { + return getObject(block, position); + } + + @Override + public Object getObject(Block block, int position) { if (block.isNull(position)) { return null; @@ -63,6 +69,12 @@ public Object getObjectValue(SqlFunctionProperties properties, Block block, int return longBitsToDouble(block.getLong(position)); } + @Override + public void writeObject(BlockBuilder blockBuilder, Object value) + { + writeDouble(blockBuilder, ((Number) value).doubleValue()); + } + @Override public boolean equalTo(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition) { diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/RealType.java b/presto-common/src/main/java/com/facebook/presto/common/type/RealType.java index 118caf4c6e67e..d6a5a3087b488 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/RealType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/RealType.java @@ -36,6 +36,12 @@ private RealType() @Override public Object getObjectValue(SqlFunctionProperties properties, Block block, int position) + { + return getObject(block, position); + } + + @Override + public Object getObject(Block block, int position) { if (block.isNull(position)) { return null; @@ -43,6 +49,12 @@ public Object getObjectValue(SqlFunctionProperties properties, Block block, int return intBitsToFloat(block.getInt(position)); } + @Override + public void writeObject(BlockBuilder blockBuilder, Object value) + { + writeLong(blockBuilder, Float.floatToIntBits(((Number) value).floatValue())); + } + @Override public boolean equalTo(Block leftBlock, int leftPosition, Block rightBlock, int rightPosition) { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNormalizeFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNormalizeFunction.java index b4975a6adb164..629775a1cc6b7 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNormalizeFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayNormalizeFunction.java @@ -35,9 +35,6 @@ @Description("Normalizes an array by dividing each element by the p-norm of the array.") public final class ArrayNormalizeFunction { - private static final ValueAccessor DOUBLE_VALUE_ACCESSOR = new DoubleValueAccessor(); - private static final ValueAccessor REAL_VALUE_ACCESSOR = new RealValueAccessor(); - private ArrayNormalizeFunction() {} @TypeParameter("T") @@ -49,7 +46,7 @@ public static Block normalizeDoubleArray( @SqlType("array(T)") Block block, @SqlType("T") double p) { - return normalizeArray(elementType, block, p, DOUBLE_VALUE_ACCESSOR); + return normalizeArray(elementType, block, p); } @TypeParameter("T") @@ -61,10 +58,10 @@ public static Block normalizeRealArray( @SqlType("array(T)") Block block, @SqlType("T") long p) { - return normalizeArray(elementType, block, Float.intBitsToFloat((int) p), REAL_VALUE_ACCESSOR); + return normalizeArray(elementType, block, Float.intBitsToFloat((int) p)); } - private static Block normalizeArray(Type elementType, Block block, double p, ValueAccessor valueAccessor) + private static Block normalizeArray(Type elementType, Block block, double p) { if (!(elementType instanceof RealType) && !(elementType instanceof DoubleType)) { throw new PrestoException( @@ -83,7 +80,7 @@ private static Block normalizeArray(Type elementType, Block block, double p, Val if (block.isNull(i)) { return null; } - pNorm += Math.pow(Math.abs(valueAccessor.getValue(elementType, block, i)), p); + pNorm += Math.pow(Math.abs(((Number) elementType.getObject(block, i)).doubleValue()), p); } if (pNorm == 0) { return block; @@ -91,47 +88,8 @@ private static Block normalizeArray(Type elementType, Block block, double p, Val pNorm = Math.pow(pNorm, 1.0 / p); BlockBuilder blockBuilder = elementType.createBlockBuilder(null, elementCount); for (int i = 0; i < elementCount; i++) { - valueAccessor.writeValue(elementType, blockBuilder, valueAccessor.getValue(elementType, block, i) / pNorm); + elementType.writeObject(blockBuilder, ((Number) elementType.getObject(block, i)).doubleValue() / pNorm); } return blockBuilder.build(); } - - private interface ValueAccessor - { - double getValue(Type elementType, Block block, int position); - - void writeValue(Type elementType, BlockBuilder blockBuilder, double value); - } - - private static class DoubleValueAccessor - implements ValueAccessor - { - @Override - public double getValue(Type elementType, Block block, int position) - { - return elementType.getDouble(block, position); - } - - @Override - public void writeValue(Type elementType, BlockBuilder blockBuilder, double value) - { - elementType.writeDouble(blockBuilder, value); - } - } - - private static class RealValueAccessor - implements ValueAccessor - { - @Override - public double getValue(Type elementType, Block block, int position) - { - return Float.intBitsToFloat((int) elementType.getLong(block, position)); - } - - @Override - public void writeValue(Type elementType, BlockBuilder blockBuilder, double value) - { - elementType.writeLong(blockBuilder, Float.floatToIntBits((float) value)); - } - } } diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestDoubleType.java b/presto-main/src/test/java/com/facebook/presto/type/TestDoubleType.java index e1aac2377f0fd..59ce4a739c64f 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestDoubleType.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestDoubleType.java @@ -68,4 +68,27 @@ public void testNaNHash() assertEquals(DOUBLE.hash(blockBuilder, 0), DOUBLE.hash(blockBuilder, 2)); assertEquals(DOUBLE.hash(blockBuilder, 0), DOUBLE.hash(blockBuilder, 3)); } + + @Test + public void testGetAndWrite() + { + BlockBuilder blockBuilder = DOUBLE.createFixedSizeBlockBuilder(5); + DOUBLE.writeDouble(blockBuilder, 1.1); + DOUBLE.writeObject(blockBuilder, 1.1); + DOUBLE.writeDouble(blockBuilder, Double.NaN); + DOUBLE.writeObject(blockBuilder, Double.NaN); + // Test passing an integer. + DOUBLE.writeObject(blockBuilder, 4); + + Block block = blockBuilder.build(); + assertEquals(DOUBLE.getDouble(block, 0), 1.1); + assertEquals(DOUBLE.getObject(block, 0), 1.1); + assertEquals(DOUBLE.getDouble(block, 1), 1.1); + assertEquals(DOUBLE.getObject(block, 1), 1.1); + assertEquals(DOUBLE.getDouble(block, 2), Double.NaN); + assertEquals(DOUBLE.getObject(block, 2), Double.NaN); + assertEquals(DOUBLE.getDouble(block, 3), Double.NaN); + assertEquals(DOUBLE.getObject(block, 3), Double.NaN); + assertEquals(DOUBLE.getObject(block, 4), 4.0); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/type/TestRealType.java b/presto-main/src/test/java/com/facebook/presto/type/TestRealType.java index c81577292596b..c43773a51fae0 100644 --- a/presto-main/src/test/java/com/facebook/presto/type/TestRealType.java +++ b/presto-main/src/test/java/com/facebook/presto/type/TestRealType.java @@ -71,4 +71,27 @@ public void testNaNHash() assertEquals(REAL.hash(blockBuilder, 0), REAL.hash(blockBuilder, 2)); assertEquals(REAL.hash(blockBuilder, 0), REAL.hash(blockBuilder, 3)); } + + @Test + public void testGetAndWrite() + { + BlockBuilder blockBuilder = REAL.createFixedSizeBlockBuilder(5); + REAL.writeLong(blockBuilder, floatToIntBits(1.1f)); + REAL.writeObject(blockBuilder, 1.1f); + REAL.writeLong(blockBuilder, floatToIntBits(Float.NaN)); + REAL.writeObject(blockBuilder, Float.NaN); + // Test passing an integer. + REAL.writeObject(blockBuilder, 4); + Block block = blockBuilder.build(); + + assertEquals(intBitsToFloat((int) REAL.getLong(block, 0)), 1.1f); + assertEquals(REAL.getObject(block, 0), 1.1f); + assertEquals(intBitsToFloat((int) REAL.getLong(block, 1)), 1.1f); + assertEquals(REAL.getObject(block, 1), 1.1f); + assertEquals(intBitsToFloat((int) REAL.getLong(block, 2)), Float.NaN); + assertEquals(REAL.getObject(block, 2), Float.NaN); + assertEquals(intBitsToFloat((int) REAL.getLong(block, 3)), Float.NaN); + assertEquals(REAL.getObject(block, 3), Float.NaN); + assertEquals(REAL.getObject(block, 4), 4.0f); + } }