diff --git a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java index 57bf946963433..b06021f5dc160 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java @@ -196,6 +196,9 @@ public static double getRegressionSyy(RegressionState state) public static double getRegressionR2(RegressionState state) { + if (state.getM2X() != 0 && state.getM2Y() == 0) { + return 1.0; + } return Math.pow(state.getC2(), 2) / (state.getM2X() * state.getM2Y()); } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrR2Aggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrR2Aggregation.java index db0337a000791..d0cb1b5540551 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrR2Aggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestDoubleRegrR2Aggregation.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator.aggregation; import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.testng.annotations.Test; import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; import static com.google.common.base.Preconditions.checkArgument; @@ -45,6 +46,19 @@ else if (length == 1) { } } + @Test + public void testTwoSpecialCase() + { + // when m2x = 0, result is null + Double[] y = new Double[] {1.0, 1.0, 1.0, 1.0, 1.0}; + Double[] x = new Double[] {1.0, 1.0, 1.0, 1.0, 1.0}; + testAggregation(null, createDoublesBlock(y), createDoublesBlock(x)); + + // when m2x != 0 and m2y = 0, result is 1.0 + x = new Double[] {1.0, 2.0, 3.0, 4.0, 5.0}; + testAggregation(1.0, createDoublesBlock(y), createDoublesBlock(x)); + } + @Override protected void testNonTrivialAggregation(Double[] y, Double[] x) { diff --git a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrR2Aggregation.java b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrR2Aggregation.java index 54cd9db4016d6..a93b5d18a3dd0 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrR2Aggregation.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/aggregation/TestRealRegrR2Aggregation.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator.aggregation; import org.apache.commons.math3.stat.regression.SimpleRegression; +import org.testng.annotations.Test; import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; import static com.google.common.base.Preconditions.checkArgument; @@ -45,6 +46,19 @@ else if (length == 1) { } } + @Test + public void testTwoSpecialCase() + { + // when m2x = 0, result is null + Float[] y = new Float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + Float[] x = new Float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + testAggregation(null, createBlockOfReals(y), createBlockOfReals(x)); + + // when m2x != 0 and m2y = 0, result is 1.0 + x = new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + testAggregation(1.0f, createBlockOfReals(y), createBlockOfReals(x)); + } + @Override protected void testNonTrivialAggregation(Float[] y, Float[] x) {