diff --git a/presto-docs/src/main/sphinx/functions/math.rst b/presto-docs/src/main/sphinx/functions/math.rst
index f5f595d52d62d..20b712d47fc3a 100644
--- a/presto-docs/src/main/sphinx/functions/math.rst
+++ b/presto-docs/src/main/sphinx/functions/math.rst
@@ -73,6 +73,18 @@ Mathematical Functions
     The mean and value v must be real values and the standard deviation must be a real
     and positive value.
 
+.. function:: inverse_beta_cdf(a, b, p) -> double
+
+    Compute the inverse of the Beta cdf with given a, b parameters for the cumulative
+    probability (p): P(N < n). The a, b parameters must be positive real values.
+    The probability p must lie on the interval [0, 1].
+
+.. function:: beta_cdf(a, b, v) -> double
+
+    Compute the Beta cdf with given a, b parameters: P(N < v; a, b).
+    The a, b parameters must be positive real numbers and value v must be a real value.
+    The value v must lie on the interval [0, 1].
+
 .. function:: ln(x) -> double
 
     Returns the natural logarithm of ``x``.
diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java
index 3947cd0cc00d7..50d9e68658530 100644
--- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java
+++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java
@@ -30,6 +30,7 @@
 import com.facebook.presto.type.LiteralParameter;
 import com.google.common.primitives.Doubles;
 import io.airlift.slice.Slice;
+import org.apache.commons.math3.distribution.BetaDistribution;
 import org.apache.commons.math3.special.Erf;
 
 import java.math.BigInteger;
@@ -629,6 +630,38 @@ public static double normalCdf(
         return 0.5 * (1 + Erf.erf((value - mean) / (standardDeviation * Math.sqrt(2))));
     }
 
+    @Description("inverse of Beta cdf given a, b parameters and probability")
+    @ScalarFunction
+    @SqlType(StandardTypes.DOUBLE)
+    public static double inverseBetaCdf(
+            @SqlType(StandardTypes.DOUBLE) double a,
+            @SqlType(StandardTypes.DOUBLE) double b,
+            @SqlType(StandardTypes.DOUBLE) double p)
+    {
+        checkCondition(p >= 0 && p <= 1, INVALID_FUNCTION_ARGUMENT, "p must be in the interval [0, 1]");
+        checkCondition(a > 0, INVALID_FUNCTION_ARGUMENT, "a must be > 0");
+        checkCondition(b > 0, INVALID_FUNCTION_ARGUMENT, "b must be > 0");
+        // The distribution is only evaluated here, never sampled, so no random generator is passed (null).
+        BetaDistribution distribution = new BetaDistribution(null, a, b, BetaDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
+        return distribution.inverseCumulativeProbability(p);
+    }
+
+    @Description("Beta cdf given the a, b parameters and value")
+    @ScalarFunction
+    @SqlType(StandardTypes.DOUBLE)
+    public static double betaCdf(
+            @SqlType(StandardTypes.DOUBLE) double a,
+            @SqlType(StandardTypes.DOUBLE) double b,
+            @SqlType(StandardTypes.DOUBLE) double value)
+    {
+        checkCondition(value >= 0 && value <= 1, INVALID_FUNCTION_ARGUMENT, "value must be in the interval [0, 1]");
+        checkCondition(a > 0, INVALID_FUNCTION_ARGUMENT, "a must be > 0");
+        checkCondition(b > 0, INVALID_FUNCTION_ARGUMENT, "b must be > 0");
+        // The distribution is only evaluated here, never sampled, so no random generator is passed (null).
+        BetaDistribution distribution = new BetaDistribution(null, a, b, BetaDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
+        return distribution.cumulativeProbability(value);
+    }
+
     @Description("round to nearest integer")
     @ScalarFunction("round")
     @SqlType(StandardTypes.TINYINT)
diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java
index 9010af5c45fd9..d1dbbaac18071 100644
--- a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java
+++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java
@@ -1343,6 +1343,38 @@ public void testNormalCdf()
         assertInvalidFunction("normal_cdf(0, nan(), 0.1985)", "standardDeviation must > 0");
     }
 
+    @Test
+    public void testInverseBetaCdf()
+    {
+        assertFunction("inverse_beta_cdf(3, 3.6, 0.0)", DOUBLE, 0.0);
+        assertFunction("inverse_beta_cdf(3, 3.6, 1.0)", DOUBLE, 1.0);
+        assertFunction("inverse_beta_cdf(3, 3.6, 0.3)", DOUBLE, 0.3469675485440618);
+        assertFunction("inverse_beta_cdf(3, 3.6, 0.95)", DOUBLE, 0.7600272463100223);
+
+        assertInvalidFunction("inverse_beta_cdf(0, 3, 0.5)", "a must be > 0");
+        assertInvalidFunction("inverse_beta_cdf(3, 0, 0.5)", "b must be > 0");
+        assertInvalidFunction("inverse_beta_cdf(3, 5, -0.1)", "p must be in the interval [0, 1]");
+        assertInvalidFunction("inverse_beta_cdf(3, 5, 1.1)", "p must be in the interval [0, 1]");
+        assertInvalidFunction("inverse_beta_cdf(nan(), 3, 0.5)", "a must be > 0");
+        assertInvalidFunction("inverse_beta_cdf(3, 5, nan())", "p must be in the interval [0, 1]");
+    }
+
+    @Test
+    public void testBetaCdf()
+    {
+        assertFunction("beta_cdf(3, 3.6, 0.0)", DOUBLE, 0.0);
+        assertFunction("beta_cdf(3, 3.6, 1.0)", DOUBLE, 1.0);
+        assertFunction("beta_cdf(3, 3.6, 0.3)", DOUBLE, 0.21764809997679938);
+        assertFunction("beta_cdf(3, 3.6, 0.9)", DOUBLE, 0.9972502881611551);
+
+        assertInvalidFunction("beta_cdf(0, 3, 0.5)", "a must be > 0");
+        assertInvalidFunction("beta_cdf(3, 0, 0.5)", "b must be > 0");
+        assertInvalidFunction("beta_cdf(3, 5, -0.1)", "value must be in the interval [0, 1]");
+        assertInvalidFunction("beta_cdf(3, 5, 1.1)", "value must be in the interval [0, 1]");
+        assertInvalidFunction("beta_cdf(nan(), 3, 0.5)", "a must be > 0");
+        assertInvalidFunction("beta_cdf(3, 5, nan())", "value must be in the interval [0, 1]");
+    }
+
     @Test
     public void testWilsonInterval()
     {