diff --git a/presto-docs/src/main/sphinx/functions/math.rst b/presto-docs/src/main/sphinx/functions/math.rst index 52f805c8e9d81..36d165091e234 100644 --- a/presto-docs/src/main/sphinx/functions/math.rst +++ b/presto-docs/src/main/sphinx/functions/math.rst @@ -104,6 +104,18 @@ Mathematical Functions Returns a pseudo-random number between 0 and n (exclusive). +.. function:: secure_rand() -> double + + This is an alias for :func:`secure_random()`. + +.. function:: secure_random() -> double + + Returns a cryptographically secure random value in the range 0.0 <= x < 1.0. + +.. function:: secure_random(lower, upper) -> [same as input] + + Returns a cryptographically secure random value in the range lower <= x < upper, where lower < upper. + .. function:: round(x) -> [same as input] Returns ``x`` rounded to the nearest integer. diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java index 2efeb0e1a4c18..8b545374eec95 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java @@ -40,6 +40,8 @@ import java.math.BigDecimal; import java.math.BigInteger; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; import java.util.concurrent.ThreadLocalRandom; import static com.facebook.presto.common.type.Decimals.longTenToNth; @@ -56,6 +58,7 @@ import static com.facebook.presto.common.type.UnscaledDecimal128Arithmetic.unscaledDecimalToUnscaledLong; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static com.facebook.presto.spi.function.FunctionKind.SCALAR; import static com.facebook.presto.type.DecimalOperators.modulusScalarFunction; @@ -93,6 +96,13 @@ public final class MathFunctions } } + private static final String SECURE_RANDOM_ALGORITHM; + + static { + String os = System.getProperty("os.name"); + SECURE_RANDOM_ALGORITHM = os.startsWith("Windows") ? "SHA1PRNG" : "NativePRNGNonBlocking"; + } + private MathFunctions() {} @Description("absolute value") @@ -685,6 +695,105 @@ public static long random(@SqlType(StandardTypes.BIGINT) long value) return ThreadLocalRandom.current().nextLong(value); } + @Description("a cryptographically secure random number between 0 and 1 (exclusive)") + @ScalarFunction(value = "secure_random", alias = "secure_rand", deterministic = false) + @SqlType(StandardTypes.DOUBLE) + public static double secure_random() + { + try { + SecureRandom random = SecureRandom.getInstance(SECURE_RANDOM_ALGORITHM); + return random.nextDouble(); + } + catch (NoSuchAlgorithmException e) { + throw new PrestoException(NOT_SUPPORTED, SECURE_RANDOM_ALGORITHM + " is not supported in your OS", e); + } + } + + @Description("a cryptographically secure random number between lower and upper (exclusive)") + @ScalarFunction(value = "secure_random", alias = "secure_rand", deterministic = false) + @SqlType(StandardTypes.DOUBLE) + public static double secure_random(@SqlType(StandardTypes.DOUBLE) double lower, @SqlType(StandardTypes.DOUBLE) double upper) + { + checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); + try { + SecureRandom random = SecureRandom.getInstance(SECURE_RANDOM_ALGORITHM); + return random.doubles(lower, upper) + .findFirst() + .getAsDouble(); + } + catch (NoSuchAlgorithmException e) { + throw new PrestoException(NOT_SUPPORTED, SECURE_RANDOM_ALGORITHM + " is not supported in your OS", e); + } + } + + @Description("a cryptographically secure random number between lower and upper (exclusive)") + @ScalarFunction(value = "secure_random", alias = "secure_rand", deterministic = false) + @SqlType(StandardTypes.TINYINT) + public static long secureRandomTinyint(@SqlType(StandardTypes.TINYINT) long lower, @SqlType(StandardTypes.TINYINT) long upper) + { + checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); + try { + SecureRandom random = SecureRandom.getInstance(SECURE_RANDOM_ALGORITHM); + return random.ints((int) lower, (int) upper) + .findFirst() + .getAsInt(); + } + catch (NoSuchAlgorithmException e) { + throw new PrestoException(NOT_SUPPORTED, SECURE_RANDOM_ALGORITHM + " is not supported in your OS", e); + } + } + + @Description("a cryptographically secure random number between lower and upper (exclusive)") + @ScalarFunction(value = "secure_random", alias = "secure_rand", deterministic = false) + @SqlType(StandardTypes.SMALLINT) + public static long secureRandomSmallint(@SqlType(StandardTypes.SMALLINT) long lower, @SqlType(StandardTypes.SMALLINT) long upper) + { + checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); + try { + SecureRandom random = SecureRandom.getInstance(SECURE_RANDOM_ALGORITHM); + return random.ints((int) lower, (int) upper) + .findFirst() + .getAsInt(); + } + catch (NoSuchAlgorithmException e) { + throw new PrestoException(NOT_SUPPORTED, SECURE_RANDOM_ALGORITHM + " is not supported in your OS", e); + } + } + + @Description("a cryptographically secure random number between lower and upper (exclusive)") + @ScalarFunction(value = "secure_random", alias = "secure_rand", deterministic = false) + @SqlType(StandardTypes.INTEGER) + public static long secureRandomInteger(@SqlType(StandardTypes.INTEGER) long lower, @SqlType(StandardTypes.INTEGER) long upper) + { + checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); + try { + SecureRandom random = SecureRandom.getInstance(SECURE_RANDOM_ALGORITHM); + return random.ints((int) lower, (int) upper) + .findFirst() + .getAsInt(); + } + catch (NoSuchAlgorithmException e) { + throw new PrestoException(NOT_SUPPORTED, SECURE_RANDOM_ALGORITHM + " is not supported in your OS", e); + } + } + + @Description("a cryptographically secure random number between lower and upper (exclusive)") + @ScalarFunction(value = "secure_random", alias = "secure_rand", deterministic = false) + @SqlType(StandardTypes.BIGINT) + public static long secureRandomBigint(@SqlType(StandardTypes.BIGINT) long lower, @SqlType(StandardTypes.BIGINT) long upper) + { + checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); + try { + SecureRandom random = SecureRandom.getInstance(SECURE_RANDOM_ALGORITHM); + return random.longs(lower, upper) + .findFirst() + .getAsLong(); + } + catch (NoSuchAlgorithmException e) { + throw new PrestoException(NOT_SUPPORTED, SECURE_RANDOM_ALGORITHM + " is not supported in your OS", e); + } + } + @Description("inverse of normal cdf given a mean, std, and probability") @ScalarFunction @SqlType(StandardTypes.DOUBLE) diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java index dd9f073d8153d..377d837f8ccf1 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java @@ -693,6 +693,22 @@ public void testRandom() assertInvalidFunction("rand(-3000000000)", "bound must be positive"); } + @Test + public void testSecureRandom() + { + // secure_random is non-deterministic + functionAssertions.tryEvaluateWithAll("secure_rand()", DOUBLE, TEST_SESSION); + functionAssertions.tryEvaluateWithAll("secure_random()", DOUBLE, TEST_SESSION); + functionAssertions.tryEvaluateWithAll("secure_random(0, 1000)", INTEGER, TEST_SESSION); + functionAssertions.tryEvaluateWithAll("secure_random(0, 3000000000)", BIGINT, TEST_SESSION); + functionAssertions.tryEvaluateWithAll("secure_random(-3000000000, -1)", BIGINT, TEST_SESSION); + functionAssertions.tryEvaluateWithAll("secure_rand(-3000000000, 3000000000)", BIGINT, TEST_SESSION); + functionAssertions.tryEvaluateWithAll("secure_random(DECIMAL '0.0', DECIMAL '1.0')", DOUBLE, TEST_SESSION); + + assertInvalidFunction("secure_random(1, 1)", "upper bound must be greater than lower bound"); + assertInvalidFunction("secure_random(DECIMAL '5.0', DECIMAL '-5.0')", "upper bound must be greater than lower bound"); + } + @Test public void testRound() {