From 858cac099a8bde0da3ccab787fd90ba4115ae7a5 Mon Sep 17 00:00:00 2001 From: Artem Selishchev Date: Mon, 11 Aug 2025 22:47:45 -0700 Subject: [PATCH] [presto] Move out M2Y from RegressionState for regr_slope and regr_intercept functions (#25475) Summary: ## Context Currently we don't enforce intermediate/return type are the same in Coordinator and Prestissimo Worker. Velox creates vectors for intermediate/return results based on a plan that comes from Coordinator. Then Prestissimo tries to use those vector and not crash. In practise we had a crash some time ago due to such a mismatch (D74199165). And I added validation to Velox to catch such kind of mismatches early: https://github.com/facebookincubator/velox/pull/13322 But we wasn't able to enable it in prod, because the validation failed for "regr_slope" and "regr_intercept" functions. ## What's changed? In this diff I'm fixing "regr_slope" and "regr_intercept" intermediate types. Basically in Java `AggregationState` for all these functions is the same: ``` AggregationFunction("regr_slope") AggregationFunction("regr_intercept") AggregationFunction("regr_sxy") AggregationFunction("regr_sxx") AggregationFunction("regr_syy") AggregationFunction("regr_r2") AggregationFunction("regr_count") AggregationFunction("regr_avgy") AggregationFunction("regr_avgx") ``` But in Prestissimo the state storage is more optimal: ``` AggregationFunction("regr_slope") AggregationFunction("regr_intercept") ``` These 2 aggregation functions don't have M2Y field. And this is more efficient, because we don't waste memory and CPU on the field, that aren't needed. So I moved M2Y to extended class, the same as it works in Velox: https://github.com/facebookincubator/velox/blob/main/velox/functions/prestosql/aggregates/CovarianceAggregates.cpp?fbclid=IwY2xjawLRTetleHRuA2FlbQIxMQBicmlkETFiT0N3UFR0M2VKOHl6MHRhAR6KRQ1VUQdCkZXzwj14sMQrVZ-R9QBH1utuGJb5U_lyGzDwt8PwV317QRVNJg_aem_-ePxZ-fHO5MNgfUmayVJFA#L326-L337 No major changes, mostly just reorganized the code. ## Next steps In this diff I'm trying to apply the same optimization to Java. With this fix, the signatures will become the same in Java and Prestissimo and we will be able to enable the validation Differential Revision: D77625566 --- ...uiltInTypeAndFunctionNamespaceManager.java | 4 + .../aggregation/AggregationUtils.java | 25 ++- .../DoubleRegressionAggregation.java | 103 ------------ .../DoubleRegressionExtendedAggregation.java | 149 ++++++++++++++++++ .../RealRegressionAggregation.java | 103 ------------ .../RealRegressionExtendedAggregation.java | 149 ++++++++++++++++++ .../state/ExtendedRegressionState.java | 22 +++ .../aggregation/state/RegressionState.java | 4 - 8 files changed, 345 insertions(+), 214 deletions(-) create mode 100644 presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java create mode 100644 presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java create mode 100644 presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 897f6657f4f3a..8594d56a63b79 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -68,6 +68,7 @@ import com.facebook.presto.operator.aggregation.DoubleCovarianceAggregation; import com.facebook.presto.operator.aggregation.DoubleHistogramAggregation; import com.facebook.presto.operator.aggregation.DoubleRegressionAggregation; +import com.facebook.presto.operator.aggregation.DoubleRegressionExtendedAggregation; import com.facebook.presto.operator.aggregation.DoubleSumAggregation; import com.facebook.presto.operator.aggregation.EntropyAggregation; import com.facebook.presto.operator.aggregation.GeometricMeanAggregations; @@ -85,6 +86,7 @@ import com.facebook.presto.operator.aggregation.RealGeometricMeanAggregations; import com.facebook.presto.operator.aggregation.RealHistogramAggregation; import com.facebook.presto.operator.aggregation.RealRegressionAggregation; +import com.facebook.presto.operator.aggregation.RealRegressionExtendedAggregation; import com.facebook.presto.operator.aggregation.RealSumAggregation; import com.facebook.presto.operator.aggregation.ReduceAggregationFunction; import com.facebook.presto.operator.aggregation.SumDataSizeForStats; @@ -742,7 +744,9 @@ private List getBuiltInFunctions(FunctionsConfig function .aggregates(DoubleCovarianceAggregation.class) .aggregates(RealCovarianceAggregation.class) .aggregates(DoubleRegressionAggregation.class) + .aggregates(DoubleRegressionExtendedAggregation.class) .aggregates(RealRegressionAggregation.class) + .aggregates(RealRegressionExtendedAggregation.class) .aggregates(DoubleCorrelationAggregation.class) .aggregates(RealCorrelationAggregation.class) .aggregates(BitwiseOrAggregation.class) diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java index 578e782bc8d91..c78186ebb2417 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java @@ -22,6 +22,7 @@ import com.facebook.presto.operator.aggregation.state.CentralMomentsState; import com.facebook.presto.operator.aggregation.state.CorrelationState; import com.facebook.presto.operator.aggregation.state.CovarianceState; +import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState; import com.facebook.presto.operator.aggregation.state.RegressionState; import com.facebook.presto.operator.aggregation.state.VarianceState; import com.facebook.presto.spi.function.AggregationFunctionImplementation; @@ -145,9 +146,14 @@ public static double getCorrelation(CorrelationState state) public static void updateRegressionState(RegressionState state, double x, double y) { double oldMeanX = state.getMeanX(); - double oldMeanY = state.getMeanY(); updateCovarianceState(state, x, y); state.setM2X(state.getM2X() + (x - oldMeanX) * (x - state.getMeanX())); + } + + public static void updateExtendedRegressionState(ExtendedRegressionState state, double x, double y) + { + double oldMeanY = state.getMeanY(); + updateRegressionState(state, x, y); state.setM2Y(state.getM2Y() + (y - oldMeanY) * (y - state.getMeanY())); } @@ -189,12 +195,12 @@ public static double getRegressionSxy(RegressionState state) return state.getC2(); } - public static double getRegressionSyy(RegressionState state) + public static double getRegressionSyy(ExtendedRegressionState state) { return state.getM2Y(); } - public static double getRegressionR2(RegressionState state) + public static double getRegressionR2(ExtendedRegressionState state) { if (state.getM2X() != 0 && state.getM2Y() == 0) { return 1.0; @@ -311,10 +317,21 @@ public static void mergeRegressionState(RegressionState state, RegressionState o long na = state.getCount(); long nb = otherState.getCount(); state.setM2X(state.getM2X() + otherState.getM2X() + na * nb * Math.pow(state.getMeanX() - otherState.getMeanX(), 2) / (double) (na + nb)); - state.setM2Y(state.getM2Y() + otherState.getM2Y() + na * nb * Math.pow(state.getMeanY() - otherState.getMeanY(), 2) / (double) (na + nb)); updateCovarianceState(state, otherState); } + public static void mergeExtendedRegressionState(ExtendedRegressionState state, ExtendedRegressionState otherState) + { + if (otherState.getCount() == 0) { + return; + } + + long na = state.getCount(); + long nb = otherState.getCount(); + state.setM2Y(state.getM2Y() + otherState.getM2Y() + na * nb * Math.pow(state.getMeanY() - otherState.getMeanY(), 2) / (double) (na + nb)); + mergeRegressionState(state, otherState); + } + public static String generateAggregationName(String baseName, TypeSignature outputType, List inputTypes) { StringBuilder sb = new StringBuilder(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java index 24d1c6e61fcf5..db3ad26ec5d6d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java @@ -24,15 +24,8 @@ import com.facebook.presto.spi.function.SqlType; import static com.facebook.presto.common.type.DoubleType.DOUBLE; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; import static com.facebook.presto.operator.aggregation.AggregationUtils.mergeRegressionState; import static com.facebook.presto.operator.aggregation.AggregationUtils.updateRegressionState; @@ -78,100 +71,4 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB out.appendNull(); } } - - @AggregationFunction("regr_sxy") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_sxx") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_syy") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrSyy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSyy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_r2") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrR2(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionR2(state); - if (Double.isFinite(result)) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_count") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrCount(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionCount(state); - if (Double.isFinite(result) && result > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgy") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgx") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java new file mode 100644 index 0000000000000..3550cd0936949 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; + +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.mergeExtendedRegressionState; +import static com.facebook.presto.operator.aggregation.AggregationUtils.updateExtendedRegressionState; + +@AggregationFunction +public class DoubleRegressionExtendedAggregation +{ + private DoubleRegressionExtendedAggregation() {} + + @InputFunction + public static void input(@AggregationState ExtendedRegressionState state, @SqlType(StandardTypes.DOUBLE) double dependentValue, @SqlType(StandardTypes.DOUBLE) double independentValue) + { + updateExtendedRegressionState(state, independentValue, dependentValue); + } + + @CombineFunction + public static void combine(@AggregationState ExtendedRegressionState state, @AggregationState ExtendedRegressionState otherState) + { + mergeExtendedRegressionState(state, otherState); + } + + @AggregationFunction("regr_sxy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSxy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_sxx") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSxx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_syy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSyy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSyy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_r2") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrR2(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionR2(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_count") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrCount(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionCount(state); + if (Double.isFinite(result) && result > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrAvgy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgx") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrAvgx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java index 1fe5d006da1a9..a75222bfa93c4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java @@ -24,15 +24,8 @@ import com.facebook.presto.spi.function.SqlType; import static com.facebook.presto.common.type.RealType.REAL; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; import static java.lang.Float.floatToRawIntBits; import static java.lang.Float.intBitsToFloat; @@ -78,100 +71,4 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB out.appendNull(); } } - - @AggregationFunction("regr_sxy") - @OutputFunction(StandardTypes.REAL) - public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_sxx") - @OutputFunction(StandardTypes.REAL) - public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_syy") - @OutputFunction(StandardTypes.REAL) - public static void regrSyy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSyy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_r2") - @OutputFunction(StandardTypes.REAL) - public static void regrR2(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionR2(state); - if (Double.isFinite(result)) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_count") - @OutputFunction(StandardTypes.REAL) - public static void regrCount(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionCount(state); - if (Double.isFinite(result) && result > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgy") - @OutputFunction(StandardTypes.REAL) - public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgx") - @OutputFunction(StandardTypes.REAL) - public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java new file mode 100644 index 0000000000000..2d0335ae9aca6 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; + +import static com.facebook.presto.common.type.RealType.REAL; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; +import static java.lang.Float.floatToRawIntBits; +import static java.lang.Float.intBitsToFloat; + +@AggregationFunction +public class RealRegressionExtendedAggregation +{ + private RealRegressionExtendedAggregation() {} + + @InputFunction + public static void input(@AggregationState ExtendedRegressionState state, @SqlType(StandardTypes.REAL) long dependentValue, @SqlType(StandardTypes.REAL) long independentValue) + { + DoubleRegressionExtendedAggregation.input(state, intBitsToFloat((int) dependentValue), intBitsToFloat((int) independentValue)); + } + + @CombineFunction + public static void combine(@AggregationState ExtendedRegressionState state, @AggregationState ExtendedRegressionState otherState) + { + DoubleRegressionExtendedAggregation.combine(state, otherState); + } + + @AggregationFunction("regr_sxy") + @OutputFunction(StandardTypes.REAL) + public static void regrSxy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_sxx") + @OutputFunction(StandardTypes.REAL) + public static void regrSxx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_syy") + @OutputFunction(StandardTypes.REAL) + public static void regrSyy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSyy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_r2") + @OutputFunction(StandardTypes.REAL) + public static void regrR2(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionR2(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_count") + @OutputFunction(StandardTypes.REAL) + public static void regrCount(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionCount(state); + if (Double.isFinite(result) && result > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgy") + @OutputFunction(StandardTypes.REAL) + public static void regrAvgy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgx") + @OutputFunction(StandardTypes.REAL) + public static void regrAvgx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java new file mode 100644 index 0000000000000..64a9883174158 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.state; + +public interface ExtendedRegressionState + extends RegressionState +{ + double getM2Y(); + + void setM2Y(double value); +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java index 79837f90c0c11..ae3af6f46dc43 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java @@ -19,8 +19,4 @@ public interface RegressionState double getM2X(); void setM2X(double value); - - double getM2Y(); - - void setM2Y(double value); }