diff --git a/presto-docs/src/main/sphinx/functions/bitwise.rst b/presto-docs/src/main/sphinx/functions/bitwise.rst index 9c55858e2a275..5c4c464def033 100644 --- a/presto-docs/src/main/sphinx/functions/bitwise.rst +++ b/presto-docs/src/main/sphinx/functions/bitwise.rst @@ -31,7 +31,7 @@ Bitwise Functions .. function:: bitwise_shift_left(x, shift, bits) -> bigint Left shift operation on ``x`` (treated as ``bits``-bit integer) - shifted by ``shift``. + shifted by ``shift``:: SELECT bitwise_shift_left(7, 2, 4); -- 12 SELECT bitwise_shift_left(7, 2, 64); -- 28 @@ -39,16 +39,45 @@ Bitwise Functions .. function:: bitwise_logical_shift_right(x, shift, bits) -> bigint Logical right shift operation on ``x`` (treated as ``bits``-bit integer) - shifted by ``shift``. + shifted by ``shift``:: SELECT bitwise_logical_shift_right(7, 2, 4); -- 1 SELECT bitwise_logical_shift_right(-8, 2, 5); -- 6 .. function:: bitwise_arithmetic_shift_right(x, shift) -> bigint - Arithmetic right shift operation on ``x`` shifted by ``shift`` in 2's complement representation. + Arithmetic right shift operation on ``x`` shifted by ``shift`` in 2's complement representation:: SELECT bitwise_arithmetic_shift_right(-8, 2); -- -2 SELECT bitwise_arithmetic_shift_right(7, 2); -- 1 +Generic Shift Functions +----------------------- + +These three functions accept values of integral value types ``TINYINT``, ``SMALLINT``, ``INTEGER`` and ``BIGINT``, +and shift them by the amount given by ``shift``, returning a value of the same integral type. For all three +functions, the amount to shift is given by the bottom bits of the ``shift`` parameter, and higher bits of the +``shift`` parameter are ignored. + +.. function:: bitwise_left_shift(value, shift) -> [same as value] + + Returns the left shifted value of ``value``:: + + SELECT bitwise_left_shift(TINYINT '7', 2); -- 28 + SELECT bitwise_left_shift(TINYINT '-7', 2); -- -28 + +.. function:: bitwise_right_shift(value, shift, digits) -> [same as value] + + Returns the logical right shifted value of ``value``:: + + SELECT bitwise_right_shift(TINYINT '7', 2); -- 1 + SELECT bitwise_right_shift(SMALLINT -8, 2); -- 16382 + +.. function:: bitwise_right_shift_arithmetic(value, shift) -> [same as value] + + Returns the arithmetic right shifted value of ``value``:: + + SELECT bitwise_right_shift_arithmetic(BIGINT '-8', 2); -- -2 + SELECT bitwise_right_shift_arithmetic(SMALLINT '7', 2); -- 1 + See also :func:`bitwise_and_agg` and :func:`bitwise_or_agg`. diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/BitwiseFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/BitwiseFunctions.java index 24e25621cdb59..e2b549b6acc92 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/BitwiseFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/BitwiseFunctions.java @@ -24,6 +24,12 @@ public final class BitwiseFunctions { private static final int MAX_BITS = 64; + private static final long TINYINT_MASK = 0b1111_1111L; + private static final long TINYINT_SIGNED_BIT = 0b1000_0000L; + private static final long SMALLINT_MASK = 0b1111_1111_1111_1111L; + private static final long SMALLINT_SIGNED_BIT = 0b1000_0000_0000_0000L; + private static final long INTEGER_MASK = 0x00_00_00_00_ff_ff_ff_ffL; + private static final long INTEGER_SIGNED_BIT = 0x00_00_00_00_00_80_00_00_00L; private BitwiseFunctions() {} @@ -133,4 +139,178 @@ public static long bitwiseArithmeticShiftRight(@SqlType(StandardTypes.BIGINT) lo return number >> shift; } + + @Description("bitwise left shift") + @ScalarFunction("bitwise_left_shift") + @SqlType(StandardTypes.TINYINT) + public static long bitwiseLeftShiftTinyint(@SqlType(StandardTypes.TINYINT) long value, @SqlType(StandardTypes.INTEGER) long shift) + { + if (shift >= MAX_BITS) { + return 0L; + } + long shifted = (value << shift); + return preserveSign(shifted, TINYINT_MASK, TINYINT_SIGNED_BIT); + } + + @Description("bitwise left shift") + @ScalarFunction("bitwise_left_shift") + @SqlType(StandardTypes.SMALLINT) + public static long bitwiseLeftShiftSmallint(@SqlType(StandardTypes.SMALLINT) long value, @SqlType(StandardTypes.INTEGER) long shift) + { + if (shift >= MAX_BITS) { + return 0L; + } + long shifted = (value << shift); + return preserveSign(shifted, SMALLINT_MASK, SMALLINT_SIGNED_BIT); + } + + @Description("bitwise left shift") + @ScalarFunction("bitwise_left_shift") + @SqlType(StandardTypes.INTEGER) + public static long bitwiseLeftShiftInteger(@SqlType(StandardTypes.INTEGER) long value, @SqlType(StandardTypes.INTEGER) long shift) + { + if (shift >= MAX_BITS) { + return 0L; + } + long shifted = (value << shift); + return preserveSign(shifted, INTEGER_MASK, INTEGER_SIGNED_BIT); + } + + @Description("bitwise left shift") + @ScalarFunction("bitwise_left_shift") + @SqlType(StandardTypes.BIGINT) + public static long bitwiseLeftShiftBigint(@SqlType(StandardTypes.BIGINT) long value, @SqlType(StandardTypes.INTEGER) long shift) + { + if (shift >= MAX_BITS) { + return 0L; + } + return value << shift; + } + + private static long preserveSign(long shiftedValue, long mask, long signedBit) + { + if ((shiftedValue & signedBit) != 0) { + // Preserve the sign in 2's complement format + return shiftedValue | ~mask; + } + + return shiftedValue & mask; + } + + @Description("bitwise logical right shift") + @ScalarFunction("bitwise_right_shift") + @SqlType(StandardTypes.TINYINT) + public static long bitwiseRightShiftTinyint(@SqlType(StandardTypes.TINYINT) long value, @SqlType(StandardTypes.INTEGER) long shift) + { + if (shift >= MAX_BITS) { + return 0L; + } + if (shift == 0) { + return value; + } + return (value & TINYINT_MASK) >>> shift; + } + + @Description("bitwise logical right shift") + @ScalarFunction("bitwise_right_shift") + @SqlType(StandardTypes.SMALLINT) + public static long bitwiseRightShiftSmallint(@SqlType(StandardTypes.SMALLINT) long value, @SqlType(StandardTypes.INTEGER) long shift) + { + if (shift >= MAX_BITS) { + return 0L; + } + if (shift == 0) { + return value; + } + return (value & SMALLINT_MASK) >>> shift; + } + + @Description("bitwise logical right shift") + @ScalarFunction("bitwise_right_shift") + @SqlType(StandardTypes.INTEGER) + public static long bitwiseRightShiftInteger(@SqlType(StandardTypes.INTEGER) long value, @SqlType(StandardTypes.INTEGER) long shift) + { + if (shift >= MAX_BITS) { + return 0L; + } + if (shift == 0) { + return value; + } + return (value & INTEGER_MASK) >>> shift; + } + + @Description("bitwise logical right shift") + @ScalarFunction("bitwise_right_shift") + @SqlType(StandardTypes.BIGINT) + public static long bitwiseRightShiftBigint(@SqlType(StandardTypes.BIGINT) long value, @SqlType(StandardTypes.INTEGER) long shift) + { + if (shift >= MAX_BITS) { + return 0L; + } + return value >>> shift; + } + + @Description("bitwise arithmetic right shift") + @ScalarFunction("bitwise_right_shift_arithmetic") + @SqlType(StandardTypes.TINYINT) + public static long bitwiseRightShiftArithmeticTinyint(@SqlType(StandardTypes.TINYINT) long value, @SqlType(StandardTypes.INTEGER) long shift) + { + if (shift >= MAX_BITS) { + if (value >= 0) { + return 0L; + } + else { + return -1L; + } + } + return preserveSign(value, TINYINT_MASK, TINYINT_SIGNED_BIT) >> shift; + } + + @Description("bitwise arithmetic right shift") + @ScalarFunction("bitwise_right_shift_arithmetic") + @SqlType(StandardTypes.SMALLINT) + public static long bitwiseRightShiftArithmeticSmallint(@SqlType(StandardTypes.SMALLINT) long value, @SqlType(StandardTypes.INTEGER) long shift) + { + if (shift >= MAX_BITS) { + if (value >= 0) { + return 0L; + } + else { + return -1L; + } + } + return preserveSign(value, SMALLINT_MASK, SMALLINT_SIGNED_BIT) >> shift; + } + + @Description("bitwise arithmetic right shift") + @ScalarFunction("bitwise_right_shift_arithmetic") + @SqlType(StandardTypes.INTEGER) + public static long bitwiseRightShiftArithmeticInteger(@SqlType(StandardTypes.INTEGER) long value, @SqlType(StandardTypes.INTEGER) long shift) + { + if (shift >= MAX_BITS) { + if (value >= 0) { + return 0L; + } + else { + return -1L; + } + } + return preserveSign(value, INTEGER_MASK, INTEGER_SIGNED_BIT) >> shift; + } + + @Description("bitwise arithmetic right shift") + @ScalarFunction("bitwise_right_shift_arithmetic") + @SqlType(StandardTypes.BIGINT) + public static long bitwiseRightShiftArithmeticBigint(@SqlType(StandardTypes.BIGINT) long value, @SqlType(StandardTypes.INTEGER) long shift) + { + if (shift >= MAX_BITS) { + if (value >= 0) { + return 0L; + } + else { + return -1L; + } + } + return value >> shift; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestBitwiseFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestBitwiseFunctions.java index c8b73f89becf7..c0377547b1ee3 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestBitwiseFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestBitwiseFunctions.java @@ -16,6 +16,9 @@ import org.testng.annotations.Test; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.SmallintType.SMALLINT; +import static com.facebook.presto.common.type.TinyintType.TINYINT; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static java.lang.String.format; @@ -131,4 +134,77 @@ public void testBitwiseSra() assertInvalidFunction("bitwise_arithmetic_shift_right(7, -3)", INVALID_FUNCTION_ARGUMENT); } + + @Test + public void testBitwiseLeftShift() + { + assertFunction("bitwise_left_shift(TINYINT'7', 2)", TINYINT, (byte) (7 << 2)); + assertFunction("bitwise_left_shift(TINYINT '-7', 2)", TINYINT, (byte) (-7 << 2)); + assertFunction("bitwise_left_shift(TINYINT '1', 7)", TINYINT, (byte) (1 << 7)); + assertFunction("bitwise_left_shift(TINYINT '-128', 1)", TINYINT, (byte) 0); + assertFunction("bitwise_left_shift(TINYINT '-65', 1)", TINYINT, (byte) (-65 << 1)); + assertFunction("bitwise_left_shift(TINYINT '-7', 64)", TINYINT, (byte) 0); + assertFunction("bitwise_left_shift(TINYINT '-128', 0)", TINYINT, (byte) -128); + assertFunction("bitwise_left_shift(SMALLINT '7', 2)", SMALLINT, (short) (7 << 2)); + assertFunction("bitwise_left_shift(SMALLINT '-7', 2)", SMALLINT, (short) (-7 << 2)); + assertFunction("bitwise_left_shift(SMALLINT '1', 7)", SMALLINT, (short) (1 << 7)); + assertFunction("bitwise_left_shift(SMALLINT '-32768', 1)", SMALLINT, (short) 0); + assertFunction("bitwise_left_shift(SMALLINT '-65', 1)", SMALLINT, (short) (-65 << 1)); + assertFunction("bitwise_left_shift(SMALLINT '-7', 64)", SMALLINT, (short) 0); + assertFunction("bitwise_left_shift(SMALLINT '-32768', 0)", SMALLINT, (short) -32768); + assertFunction("bitwise_left_shift(INTEGER '7', 2)", INTEGER, 7 << 2); + assertFunction("bitwise_left_shift(INTEGER '-7', 2)", INTEGER, -7 << 2); + assertFunction("bitwise_left_shift(INTEGER '1', 7)", INTEGER, 1 << 7); + assertFunction("bitwise_left_shift(INTEGER '-2147483648', 1)", INTEGER, 0); + assertFunction("bitwise_left_shift(INTEGER '-65', 1)", INTEGER, -65 << 1); + assertFunction("bitwise_left_shift(INTEGER '-7', 64)", INTEGER, 0); + assertFunction("bitwise_left_shift(INTEGER '-2147483648', 0)", INTEGER, -2147483648); + assertFunction("bitwise_left_shift(BIGINT '7', 2)", BIGINT, 7L << 2); + assertFunction("bitwise_left_shift(BIGINT '-7', 2)", BIGINT, -7L << 2); + assertFunction("bitwise_left_shift(BIGINT '-7', 64)", BIGINT, 0L); + } + + @Test + public void testBitwiseRightShift() + { + assertFunction("bitwise_right_shift(TINYINT '7', 2)", TINYINT, (byte) (7 >>> 2)); + assertFunction("bitwise_right_shift(TINYINT '-7', 2)", TINYINT, (byte) 62); + assertFunction("bitwise_right_shift(TINYINT '-7', 64)", TINYINT, (byte) 0); + assertFunction("bitwise_right_shift(TINYINT '-128', 0)", TINYINT, (byte) -128); + assertFunction("bitwise_right_shift(SMALLINT '7', 2)", SMALLINT, (short) (7 >>> 2)); + assertFunction("bitwise_right_shift(SMALLINT '-7', 2)", SMALLINT, (short) 16382); + assertFunction("bitwise_right_shift(SMALLINT '-7', 64)", SMALLINT, (short) 0); + assertFunction("bitwise_right_shift(SMALLINT '-32768', 0)", SMALLINT, (short) -32768); + assertFunction("bitwise_right_shift(INTEGER '7', 2)", INTEGER, 7 >>> 2); + assertFunction("bitwise_right_shift(INTEGER '-7', 2)", INTEGER, 1073741822); + assertFunction("bitwise_right_shift(INTEGER '-7', 64)", INTEGER, 0); + assertFunction("bitwise_right_shift(INTEGER '-2147483648', 0)", INTEGER, -2147483648); + assertFunction("bitwise_right_shift(BIGINT '7', 2)", BIGINT, 7L >>> 2); + assertFunction("bitwise_right_shift(BIGINT '-7', 2)", BIGINT, -7L >>> 2); + assertFunction("bitwise_right_shift(BIGINT '-7', 64)", BIGINT, 0L); + } + + @Test + public void testBitwiseRightShiftArithmetic() + { + assertFunction("bitwise_right_shift_arithmetic(TINYINT '7', 2)", TINYINT, (byte) (7 >> 2)); + assertFunction("bitwise_right_shift_arithmetic(TINYINT '-7', 2)", TINYINT, (byte) (-7 >> 2)); + assertFunction("bitwise_right_shift_arithmetic(TINYINT '7', 64)", TINYINT, (byte) 0); + assertFunction("bitwise_right_shift_arithmetic(TINYINT '-7', 64)", TINYINT, (byte) -1); + assertFunction("bitwise_right_shift_arithmetic(TINYINT '-128', 0)", TINYINT, (byte) -128); + assertFunction("bitwise_right_shift_arithmetic(SMALLINT '7', 2)", SMALLINT, (short) (7 >> 2)); + assertFunction("bitwise_right_shift_arithmetic(SMALLINT '-7', 2)", SMALLINT, (short) (-7 >> 2)); + assertFunction("bitwise_right_shift_arithmetic(SMALLINT '7', 64)", SMALLINT, (short) 0); + assertFunction("bitwise_right_shift_arithmetic(SMALLINT '-7', 64)", SMALLINT, (short) -1); + assertFunction("bitwise_right_shift_arithmetic(SMALLINT '-32768', 0)", SMALLINT, (short) -32768); + assertFunction("bitwise_right_shift_arithmetic(INTEGER '7', 2)", INTEGER, (7 >> 2)); + assertFunction("bitwise_right_shift_arithmetic(INTEGER '-7', 2)", INTEGER, -7 >> 2); + assertFunction("bitwise_right_shift_arithmetic(INTEGER '7', 64)", INTEGER, 0); + assertFunction("bitwise_right_shift_arithmetic(INTEGER '-7', 64)", INTEGER, -1); + assertFunction("bitwise_right_shift_arithmetic(INTEGER '-2147483648', 0)", INTEGER, -2147483648); + assertFunction("bitwise_right_shift_arithmetic(BIGINT '7', 2)", BIGINT, 7L >> 2); + assertFunction("bitwise_right_shift_arithmetic(BIGINT '-7', 2)", BIGINT, -7L >> 2); + assertFunction("bitwise_right_shift_arithmetic(BIGINT '7', 64)", BIGINT, 0L); + assertFunction("bitwise_right_shift_arithmetic(BIGINT '-7', 64)", BIGINT, -1L); + } }